worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License

Early stopping with sensitivity vs validation loss metric and the effects on synthetic data quality. #60

Closed by efstathios-chatzikyriakidis 9 months ago

efstathios-chatzikyriakidis commented 9 months ago

Hi @avsolatorio,

Hope everything is well!

I have noticed that specifying n_critic=0 when training a tabular model disables training with sensitivity. With my use case and dataset, the threshold estimation for the sensitivity metric was too slow and needed a large amount of memory (more than 200 GB of RAM). So I have tried replacing it with classic early stopping based on validation loss. I currently use the following code:

from realtabformer import REaLTabFormer

# MODEL_RUN_DIRECTORY_PATH (a pathlib.Path), table_name and table_data are defined elsewhere.
parent_model = REaLTabFormer(model_type="tabular",
                             batch_size=8,
                             epochs=30,
                             gradient_accumulation_steps=1,
                             logging_steps=25,
                             save_strategy="epoch",          # CLASSIC EARLY STOPPING
                             evaluation_strategy="epoch",    # CLASSIC EARLY STOPPING
                             train_size=0.8,                 # CLASSIC EARLY STOPPING
                             early_stopping_patience=5,      # CLASSIC EARLY STOPPING
                             early_stopping_threshold=0,     # CLASSIC EARLY STOPPING
                             checkpoints_dir=MODEL_RUN_DIRECTORY_PATH / f'{table_name}_checkpoints')

trainer = parent_model.fit(df=table_data,
                           n_critic=0,    # CLASSIC EARLY STOPPING
                           device='cuda')

trainer.state.save_to_json(MODEL_RUN_DIRECTORY_PATH / f'{table_name}_checkpoints' / "trainer_state.json")

parent_model.save(MODEL_RUN_DIRECTORY_PATH / f"{table_name}_model")
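For readers unfamiliar with patience-based stopping, the rule that early_stopping_patience and early_stopping_threshold describe can be sketched in a few lines of plain Python. This is a simplified illustration, not the code that REaLTabFormer or the underlying trainer actually runs; the function name early_stopping_epoch and the sample losses are made up for the example.

```python
def early_stopping_epoch(val_losses, patience=5, threshold=0.0):
    """Return the 0-based epoch index at which training would stop,
    or None if the patience budget is never exhausted.

    Simplified sketch: an epoch counts as an improvement only if the
    validation loss drops below the best value so far by more than
    `threshold`; `patience` consecutive non-improving epochs trigger a stop.
    """
    best = None
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses):
        if best is None or best - loss > threshold:
            best = loss          # new best validation loss
            bad_epochs = 0       # reset the patience counter
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None


# Hypothetical validation losses, one per epoch.
losses = [1.00, 0.80, 0.70, 0.71, 0.72, 0.70, 0.73, 0.74]
stop_epoch = early_stopping_epoch(losses, patience=5, threshold=0.0)
print(stop_epoch)
```

With a threshold of 0, matching a previous best does not reset the counter, so a long plateau still ends training once the patience is used up.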

My question is: how much quality do we lose if we use classic early stopping instead of the sensitivity-based stopping criterion? I am not asking for an exact number, just a rough idea from your experience. Is training with sensitivity very different from classic early stopping in terms of synthetic data quality? I ask because, for example, the relational model doesn't use training with sensitivity. Furthermore, I have noticed a boost in training speed, since the sensitivity threshold estimation at the beginning of training is very slow on my data (even with many CPU cores when parallelization is used). I am thinking of using classic early stopping and would like some reassurance that it will not significantly decrease the quality of the synthetic data. Of course, I will check it myself, but I wanted your insight first.

Lastly, here is a training vs. validation loss plot from a run using early stopping on validation loss (just an example):

[Image: training-plot]
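A plot like this can be rebuilt from the saved trainer_state.json. The sketch below assumes the file follows the usual Hugging Face TrainerState layout, where log_history is a list of dicts and training entries carry a "loss" key while evaluation entries carry "eval_loss". The helper names loss_curves and load_loss_curves and the sample log are illustrative only.

```python
import json
from pathlib import Path


def loss_curves(log_history):
    """Split a trainer log history into (epoch, loss) pairs for the
    training loss and the validation loss."""
    train = [(e.get("epoch"), e["loss"]) for e in log_history if "loss" in e]
    val = [(e.get("epoch"), e["eval_loss"]) for e in log_history if "eval_loss" in e]
    return train, val


def load_loss_curves(state_path):
    """Read a trainer_state.json file and return its loss curves."""
    state = json.loads(Path(state_path).read_text())
    return loss_curves(state["log_history"])


# Tiny synthetic log history standing in for a real trainer_state.json.
example_log = [
    {"epoch": 0.5, "loss": 1.2, "step": 25},
    {"epoch": 1.0, "loss": 1.0, "step": 50},
    {"epoch": 1.0, "eval_loss": 1.1, "step": 50},
    {"epoch": 2.0, "loss": 0.8, "step": 100},
    {"epoch": 2.0, "eval_loss": 0.9, "step": 100},
]
train_curve, val_curve = loss_curves(example_log)

# Epoch with the lowest validation loss, i.e. the checkpoint worth keeping.
best_epoch, best_val = min(val_curve, key=lambda point: point[1])
print(best_epoch, best_val)
```

The two lists can then be passed to any plotting library to draw the training and validation curves on the same axes.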

Thanks!

avsolatorio commented 9 months ago

Hello @efstathios-chatzikyriakidis, indeed, the critic feature, as currently implemented, can get quite resource-intensive when fitting high-dimensional data or large datasets. The critic adds only marginal value if you have a large dataset, but it is very useful for training models on small datasets.

I am sharing the ablation results below, showing the effect of the overfitting-detection mechanism (critic).

[Image: ablation results]

So, yes, you can definitely use a hold-out validation set when you have a sufficiently large dataset! :)
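The rule of thumb above (critic for small datasets, hold-out validation for large ones) could be wired into a small helper like the following sketch. The function stopping_config and the large_dataset_rows cutoff are hypothetical: the thread does not say what counts as "sufficiently large", so the default of 100,000 rows is only a placeholder to tune per dataset. The keyword values mirror the ones used in the original post.

```python
def stopping_config(n_rows, large_dataset_rows=100_000):
    """Pick keyword arguments for the model constructor and for fit().

    large_dataset_rows is a hypothetical cutoff, not a value recommended
    by the library or its authors.
    """
    if n_rows >= large_dataset_rows:
        # Large dataset: hold-out validation with classic early stopping,
        # and the sensitivity-based critic switched off via n_critic=0.
        return {
            "init": {
                "train_size": 0.8,
                "evaluation_strategy": "epoch",
                "save_strategy": "epoch",
                "early_stopping_patience": 5,
                "early_stopping_threshold": 0,
            },
            "fit": {"n_critic": 0},
        }
    # Small dataset: keep the library defaults, which leave the critic on.
    return {"init": {}, "fit": {}}


config = stopping_config(n_rows=250_000)
print(config["fit"])
```

The returned dictionaries would be unpacked as REaLTabFormer(model_type="tabular", **config["init"]) and model.fit(df, **config["fit"]).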

efstathios-chatzikyriakidis commented 9 months ago

Thank you for your fast reply @avsolatorio.

Yes, indeed. The interesting part for me is that when mr=0 the difference is not very large; in the other cases it is larger. I think the default mr is 0, which is what applies in my case. So, treating what I already have as a baseline, I shouldn't be too far from it. Of course, I will test it.

Thank you so much!