
Commit dd69da6

Transformer Pre-Training: Add early stop + fix eval set 👷‍♀️ (#409)
1 parent a038bab commit dd69da6

File tree

4 files changed, +48 -48 lines changed

src/otc/models/fttransformer.py (+4 -3)
@@ -614,9 +614,10 @@ def forward(
         if self.dropout is not None:
             attention_probs = self.dropout(attention_probs)

-        self.save_attn(attention_probs)
-        if attention_probs.requires_grad:
-            attention_probs.register_hook(self.save_attn_gradients)
+        # comment out for training
+        # self.save_attn(attention_probs)
+        # if attention_probs.requires_grad:
+        #     attention_probs.register_hook(self.save_attn_gradients)

         x = attention_probs @ self._reshape(v)
         x = (
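
The lines commented out above implement the common pattern of caching attention probabilities and registering a backward hook to capture their gradients for later attention-map analysis; disabling them during training avoids holding these tensors for every batch. Below is a minimal, self-contained sketch of that pattern; the class is hypothetical and only the method names (save_attn, save_attn_gradients) are taken from the diff.

```python
from typing import Optional

import torch
from torch import nn


class AttentionProbe(nn.Module):
    """Hypothetical module illustrating the save_attn / save_attn_gradients pattern."""

    def __init__(self) -> None:
        super().__init__()
        self.attn: Optional[torch.Tensor] = None
        self.attn_gradients: Optional[torch.Tensor] = None

    def save_attn(self, attention_probs: torch.Tensor) -> None:
        # cache the attention probabilities for later inspection
        self.attn = attention_probs

    def save_attn_gradients(self, grad: torch.Tensor) -> None:
        # called by autograd through the registered hook during backward()
        self.attn_gradients = grad

    def forward(self, attention_probs: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # exactly these two steps are commented out for training in the diff above
        self.save_attn(attention_probs)
        if attention_probs.requires_grad:
            attention_probs.register_hook(self.save_attn_gradients)
        return attention_probs @ v
```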

src/otc/models/objective.py (+3 -0)
@@ -157,6 +157,7 @@ def __init__(

         self._clf: BaseEstimator
         self._callbacks = CallbackContainer([SaveCallback(), PrintCallback()])
+        self._pretrain = pretrain

         super().__init__(x_train, y_train, x_val, y_val, name, pretrain)

@@ -235,6 +236,7 @@ def __call__(self, trial: optuna.Trial) -> float:
             "feature_tokenizer": FeatureTokenizer(**feature_tokenizer_kwargs),  # type: ignore # noqa: E501
             "cat_features": self._cat_features,
             "cat_cardinalities": self._cat_cardinalities,
+            "d_token": d_token,
         }

         optim_params = {"lr": lr, "weight_decay": weight_decay}
@@ -245,6 +247,7 @@ def __call__(self, trial: optuna.Trial) -> float:
             optim_params=optim_params,
             dl_params=dl_params,
             callbacks=self._callbacks,  # type: ignore # noqa: E501
+            pretrain=self._pretrain,
         )

         self._clf.fit(
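
For context, the objective now stores the pretrain flag in __init__ and forwards it, together with the suggested d_token, to the classifier. A stripped-down, hypothetical sketch of that flow; the suggest range and the placeholder return value are illustrative, not taken from the study.

```python
import optuna


class ObjectiveSketch:
    """Hypothetical reduction of the objective to the two values added in this commit."""

    def __init__(self, pretrain: bool) -> None:
        # stored once so that every trial can hand it to the classifier
        self._pretrain = pretrain

    def __call__(self, trial: optuna.Trial) -> float:
        # illustrative search space; the real one lives in objective.py
        d_token = trial.suggest_int("d_token", 64, 256, step=64)
        module_params = {"d_token": d_token}
        # the real objective constructs the classifier with module_params and
        # pretrain=self._pretrain here and returns the fitted model's validation score
        return float(d_token)  # placeholder score for the sketch
```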

src/otc/models/train_model.py (+1 -22)
@@ -75,12 +75,6 @@
 @click.option(
     "--pretrain/--no-pretrain", default=False, help="Flag to activate pretraining."
 )
-@click.option(
-    "--sample",
-    type=click.FloatRange(0, 1),
-    default=1,
-    help="Sampling factor applied to train and validation set.",
-)
 def main(
     trials: int,
     seed: int,
@@ -89,7 +83,6 @@ def main(
     id: str,
     dataset: str,
     pretrain: bool,
-    sample: float,
 ) -> None:
     """
     Start study.
@@ -102,7 +95,6 @@ def main(
         id (str): id of study.
         dataset (str): name of data set.
         pretrain (bool): whether to pretrain model.
-        sample (float): sampling factor.
     """
     logger = logging.getLogger(__name__)
     warnings.filterwarnings("ignore", category=ExperimentalWarning)
@@ -171,19 +163,6 @@ def main(
     y_val = x_val["buy_sell"]
     x_val.drop(columns=["buy_sell"], inplace=True)

-    if sample < 1.0:
-        # sample down train data
-        x_train = x_train.sample(frac=sample, random_state=set_seed(seed)).reset_index(
-            drop=True
-        )
-        y_train = y_train.iloc[x_train.index]
-
-        # sample down validation data
-        x_val = x_val.sample(frac=sample, random_state=set_seed(seed)).reset_index(
-            drop=True
-        )
-        y_val = y_val.iloc[x_val.index]
-
     # pretrain training activated
     has_label = (y_train != 0).all()
     if pretrain and has_label:
@@ -251,7 +230,7 @@ def main(
             "dataset": dataset,
             "seed": seed,
             "pretrain": pretrain,
-            "sample": sample,
+            "sample": 1.0,
         }
     )

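With the --sample option removed, the study is launched with the remaining flags only. A hypothetical invocation via click's test runner follows; every option name except --pretrain/--no-pretrain is inferred from the main() signature rather than shown in this diff, and the id/dataset values are placeholders.

```python
from click.testing import CliRunner

from otc.models.train_model import main

runner = CliRunner()
result = runner.invoke(
    main,
    # note: no --sample anymore; the logged config now records "sample": 1.0 instead
    ["--trials", "10", "--seed", "42", "--id", "demo-study", "--dataset", "example", "--pretrain"],
)
print(result.exit_code, result.output)
```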

src/otc/models/transformer_classifier.py (+40 -23)
@@ -15,8 +15,7 @@
 import pandas as pd
 import torch
 from sklearn.base import BaseEstimator, ClassifierMixin
-from sklearn.utils.multiclass import check_classification_targets
-from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
+from sklearn.utils.validation import check_array, check_is_fitted
 from torch import nn, optim

 from otc.data.dataloader import TabDataLoader
@@ -38,7 +37,7 @@ class TransformerClassifier(BaseEstimator, ClassifierMixin):
     """

     epochs_pretrain = 20
-    epochs_finetune = 1
+    epochs_finetune = 20

     def __init__(
         self,
@@ -246,6 +245,7 @@ def fit( # noqa: C901

         if self.pretrain:

+            print("start pre-training...")
             mask = y == 0

             # isolate unlabelled
@@ -260,10 +260,14 @@ def fit( # noqa: C901
             train_loader_pretrain = self.array_to_dataloader_pretrain(
                 X_unlabelled, y_unlabelled
             )
-            val_loader_pretrain = self.array_to_dataloader_pretrain(
-                X_unlabelled, y_unlabelled
+
+            # use in-sample instead of validation set, if None is provided
+            X_val, y_val = (
+                eval_set if eval_set is not None else (X_unlabelled, y_unlabelled)
             )

+            val_loader_pretrain = self.array_to_dataloader_pretrain(X_val, y_val)
+
             # free up memory
             del X_unlabelled, y_unlabelled
             gc.collect()
@@ -314,24 +318,30 @@ def fit( # noqa: C901
                 optimizer=optimizer, warmup=warmup_steps, max_iters=max_steps
             )

-            criterion = nn.BCEWithLogitsLoss()
+            # keep track of val loss and do early stopping
+            early_stopping = EarlyStopping(patience=10)
+
+            # mean bce with logits loss
+            criterion = nn.BCEWithLogitsLoss(reduction="mean")

             step = 0
+            best_accuracy = -1.0
+
             for epoch in range(self.epochs_pretrain):

                 # perform training
                 loss_in_epoch_train = 0

                 batch = 0

-                for x_cat, x_cont, masks in train_loader_pretrain:
+                for x_cat, x_cont, mask in train_loader_pretrain:

                     self.clf.train()
                     optimizer.zero_grad()

                     with torch.autocast(device_type="cuda", dtype=torch.float16):
                         logits = self.clf(x_cat, x_cont)
-                        train_loss = criterion(logits, masks.float())  # type: ignore[union-attr] # noqa: E501
+                        train_loss = criterion(logits, mask.float())  # type: ignore[union-attr] # noqa: E501

                     scaler.scale(train_loss).backward()
                     scaler.step(optimizer)
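
EarlyStopping(patience=10) is introduced here, but its implementation is not part of this diff. Below is a minimal sketch of what such a patience-based helper usually looks like, with an assumed interface that tracks the validation loss and reports when training should stop; the actual class in the repository may differ.

```python
class EarlyStopping:
    """Minimal sketch of a patience-based early-stopping helper (assumed interface)."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss: float) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            # improvement: remember the new best loss and reset the counter
            self.best_loss = val_loss
            self.counter = 0
        else:
            # no improvement: stop once `patience` epochs have passed without one
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
```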
@@ -353,31 +363,45 @@ def fit( # noqa: C901
                 correct = 0

                 with torch.no_grad():
-                    for x_cat, x_cont, masks in val_loader_pretrain:
+                    for x_cat, x_cont, mask in val_loader_pretrain:

                         # for my implementation
                         logits = self.clf(x_cat, x_cont)
-                        val_loss = criterion(logits, masks.float())  # type: ignore[union-attr] # noqa: E501
-
+                        val_loss = criterion(logits, mask.float())  # type: ignore[union-attr] # noqa: E501
                         loss_in_epoch_val += val_loss.item()

+                        # accuracy
+                        # adapted from here, but over columns + rows https://github.com/puhsu/tabular-dl-pretrain-objectives/blob/3f503d197867c341b4133efcafd3243eb5bb93de/bin/mask.py#L440 # noqa: E501
+                        hard_predictions = torch.zeros_like(logits, dtype=torch.long)
+                        hard_predictions[logits > 0] = 1
+                        # sum columns and rows
+                        correct += (hard_predictions.bool() == mask).sum()
+
                         batch += 1

                 # loss average over all batches
                 train_loss_all = loss_in_epoch_train / len(train_loader_pretrain)
                 val_loss_all = loss_in_epoch_val / len(val_loader_pretrain)
+                # correct / (rows * columns)
+                val_accuracy = correct / (X_val.shape[0] * X_val.shape[1])
+
+                print(f"train loss: {train_loss}")
+                print(f"val loss: {val_loss}")
+                print(f"val accuracy: {val_accuracy}")

                 self._stats_pretrain_epoch.append(
                     {
                         "train_loss": train_loss_all,
                         "val_loss": val_loss_all,
+                        "val_accuracy": val_accuracy,
                         "step": step,
                         "epoch": epoch,
                     }
                 )

-                print(f"train loss: {train_loss}")
-                print(f"val loss: {val_loss}")
+                if best_accuracy < val_accuracy:
+                    self._checkpoint_write()
+                    best_accuracy = val_accuracy

                 # https://discuss.huggingface.co/t/clear-gpu-memory-of-transformers-pipeline/18310/2
                 del train_loader_pretrain, val_loader_pretrain
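
The validation accuracy above thresholds the logits at zero (equivalent to sigmoid(logit) > 0.5), compares the hard predictions against the binary mask cell by cell, and normalises by rows times columns. Pulled out into a standalone helper, the same computation could look like this sketch; tensor shapes are purely illustrative.

```python
import torch


def mask_accuracy(logits: torch.Tensor, mask: torch.Tensor) -> float:
    """Share of correctly reconstructed mask cells, over rows and columns."""
    # logit > 0  <=>  sigmoid(logit) > 0.5, so threshold at zero
    hard_predictions = logits > 0
    # count matches against the boolean mask, cell by cell
    correct = (hard_predictions == mask.bool()).sum()
    # normalise by the total number of cells (rows * columns)
    return (correct / mask.numel()).item()


# illustrative shapes only: 8 samples, 5 features
logits = torch.randn(8, 5)
mask = torch.randint(0, 2, (8, 5))
print(f"val accuracy: {mask_accuracy(logits, mask):.3f}")
```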
@@ -389,17 +413,10 @@ def fit( # noqa: C901
         self.clf.to(self.dl_params["device"])

         # start finetuning beneath
-        check_classification_targets(y)
-        X, y = check_X_y(X, y, multi_output=False, accept_sparse=False)
+        print("start finetuning...")

         # use in-sample instead of validation set, if None is provided
-        if eval_set:
-            X_val, y_val = eval_set
-            X_val, y_val = check_X_y(
-                X_val, y_val, multi_output=False, accept_sparse=False
-            )
-        else:
-            X_val, y_val = X, y
+        X_val, y_val = eval_set if eval_set is not None else (X, y)

         # save for accuracy calculation
         len_x_val = len(X_val)
@@ -529,8 +546,8 @@ def fit( # noqa: C901
                         )
                         loss_in_epoch_val += val_loss.item()

-                        # print(f"[{epoch}-{val_batch}] val loss: {val_loss}")
                         val_batch += 1
+
                 # loss average over all batches
                 train_loss_all = loss_in_epoch_train / len(train_loader_finetune)
                 val_loss_all = loss_in_epoch_val / len(val_loader_finetune)
