import pandas as pd
import torch
from sklearn.base import BaseEstimator, ClassifierMixin
- from sklearn.utils.multiclass import check_classification_targets
- from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
+ from sklearn.utils.validation import check_array, check_is_fitted
from torch import nn, optim

from otc.data.dataloader import TabDataLoader
@@ -38,7 +37,7 @@ class TransformerClassifier(BaseEstimator, ClassifierMixin):
    """

    epochs_pretrain = 20
-     epochs_finetune = 1
+     epochs_finetune = 20

    def __init__(
        self,
@@ -246,6 +245,7 @@ def fit( # noqa: C901

        if self.pretrain:

+             print("start pre-training...")
            mask = y == 0

            # isolate unlabelled
@@ -260,10 +260,14 @@ def fit( # noqa: C901
            train_loader_pretrain = self.array_to_dataloader_pretrain(
                X_unlabelled, y_unlabelled
            )
-             val_loader_pretrain = self.array_to_dataloader_pretrain(
-                 X_unlabelled, y_unlabelled
+
+             # use in-sample instead of validation set, if None is provided
+             X_val, y_val = (
+                 eval_set if eval_set is not None else (X_unlabelled, y_unlabelled)
            )

+             val_loader_pretrain = self.array_to_dataloader_pretrain(X_val, y_val)
+
            # free up memory
            del X_unlabelled, y_unlabelled
            gc.collect()
@@ -314,24 +318,30 @@ def fit( # noqa: C901
                optimizer=optimizer, warmup=warmup_steps, max_iters=max_steps
            )

-             criterion = nn.BCEWithLogitsLoss()
+             # keep track of val loss and do early stopping
+             early_stopping = EarlyStopping(patience=10)
+
+             # mean bce with logits loss
+             criterion = nn.BCEWithLogitsLoss(reduction="mean")

            step = 0
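+             # track the best val accuracy seen so far; used for checkpointing below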
+             best_accuracy = -1.0
+
            for epoch in range(self.epochs_pretrain):

                # perform training
                loss_in_epoch_train = 0

                batch = 0

-                 for x_cat, x_cont, masks in train_loader_pretrain:
+                 for x_cat, x_cont, mask in train_loader_pretrain:

                    self.clf.train()
                    optimizer.zero_grad()

                    with torch.autocast(device_type="cuda", dtype=torch.float16):
                        logits = self.clf(x_cat, x_cont)
-                         train_loss = criterion(logits, masks.float())  # type: ignore[union-attr] # noqa: E501
+                         train_loss = criterion(logits, mask.float())  # type: ignore[union-attr] # noqa: E501

                    scaler.scale(train_loss).backward()
                    scaler.step(optimizer)
@@ -353,31 +363,45 @@ def fit( # noqa: C901
                correct = 0

                with torch.no_grad():
-                     for x_cat, x_cont, masks in val_loader_pretrain:
+                     for x_cat, x_cont, mask in val_loader_pretrain:

                        # for my implementation
                        logits = self.clf(x_cat, x_cont)
-                         val_loss = criterion(logits, masks.float())  # type: ignore[union-attr] # noqa: E501
-
+                         val_loss = criterion(logits, mask.float())  # type: ignore[union-attr] # noqa: E501
                        loss_in_epoch_val += val_loss.item()

+                         # accuracy
+                         # adapted from here, but over columns + rows https://github.com/puhsu/tabular-dl-pretrain-objectives/blob/3f503d197867c341b4133efcafd3243eb5bb93de/bin/mask.py#L440 # noqa: E501
+                         hard_predictions = torch.zeros_like(logits, dtype=torch.long)
+                         hard_predictions[logits > 0] = 1
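+                         # a logit > 0 corresponds to a predicted mask probability > 0.5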
+                         # sum columns and rows
+                         correct += (hard_predictions.bool() == mask).sum()
+
                        batch += 1

                # loss average over all batches
                train_loss_all = loss_in_epoch_train / len(train_loader_pretrain)
                val_loss_all = loss_in_epoch_val / len(val_loader_pretrain)
+                 # correct / (rows * columns)
+                 val_accuracy = correct / (X_val.shape[0] * X_val.shape[1])
+
+                 print(f"train loss: {train_loss}")
+                 print(f"val loss: {val_loss}")
+                 print(f"val accuracy: {val_accuracy}")

                self._stats_pretrain_epoch.append(
                    {
                        "train_loss": train_loss_all,
                        "val_loss": val_loss_all,
+                         "val_accuracy": val_accuracy,
                        "step": step,
                        "epoch": epoch,
                    }
                )

-                 print(f"train loss: {train_loss}")
-                 print(f"val loss: {val_loss}")
+                 if best_accuracy < val_accuracy:
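+                     # checkpoint whenever the val mask-prediction accuracy improves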
+                     self._checkpoint_write()
+                     best_accuracy = val_accuracy

            # https://discuss.huggingface.co/t/clear-gpu-memory-of-transformers-pipeline/18310/2
            del train_loader_pretrain, val_loader_pretrain
@@ -389,17 +413,10 @@ def fit( # noqa: C901
        self.clf.to(self.dl_params["device"])

        # start finetuning beneath
-         check_classification_targets(y)
-         X, y = check_X_y(X, y, multi_output=False, accept_sparse=False)
+         print("start finetuning...")

        # use in-sample instead of validation set, if None is provided
-         if eval_set:
-             X_val, y_val = eval_set
-             X_val, y_val = check_X_y(
-                 X_val, y_val, multi_output=False, accept_sparse=False
-             )
-         else:
-             X_val, y_val = X, y
+         X_val, y_val = eval_set if eval_set is not None else (X, y)

        # save for accuracy calculation
        len_x_val = len(X_val)
@@ -529,8 +546,8 @@ def fit( # noqa: C901
                )
                loss_in_epoch_val += val_loss.item()

-                 # print(f"[{epoch}-{val_batch}] val loss: {val_loss}")
                val_batch += 1
+
            # loss average over all batches
            train_loss_all = loss_in_epoch_train / len(train_loader_finetune)
            val_loss_all = loss_in_epoch_val / len(val_loader_finetune)