worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License

Is the calculation of num_train_epochs correct? #12

Closed: echatzikyriakidis closed this issue 1 year ago.

echatzikyriakidis commented 1 year ago

Hi @avsolatorio,

I am looking at the code of _train_with_sensitivity() and I can't understand why we calculate num_train_epochs this way:

https://github.com/avsolatorio/REaLTabFormer/blob/bf1a38ef8f202372956ac57a363289c505967982/src/realtabformer/realtabformer.py#L692

Assuming we run for 100 epochs and n_critic is 5, we get the following [p_epoch, num_train_epochs] pairs:

p_epoch, num_train_epochs
0, 5
5, 10
10, 15
15, 20
...
80, 85
85, 90
90, 95
95, 100
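The pairs above can be reproduced with a small loop. This is a hypothetical reconstruction, not the code in _train_with_sensitivity(): it assumes each round's num_train_epochs is simply the cumulative target p_epoch + n_critic, capped at the total, and the names critic_rounds and total_epochs are illustrative.

```python
def critic_rounds(total_epochs: int, n_critic: int) -> list[tuple[int, int]]:
    """Return (p_epoch, num_train_epochs) pairs, one per critic round.

    Hypothetical sketch: num_train_epochs is the cumulative target
    p_epoch + n_critic, capped at total_epochs.
    """
    pairs = []
    for p_epoch in range(0, total_epochs, n_critic):
        num_train_epochs = min(p_epoch + n_critic, total_epochs)
        pairs.append((p_epoch, num_train_epochs))
    return pairs


rounds = critic_rounds(total_epochs=100, n_critic=5)
print(rounds[:3], rounds[-1])  # [(0, 5), (5, 10), (10, 15)] (95, 100)
```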

In the following two lines we set the num_train_epochs:

https://github.com/avsolatorio/REaLTabFormer/blob/bf1a38ef8f202372956ac57a363289c505967982/src/realtabformer/realtabformer.py#L698

https://github.com/avsolatorio/REaLTabFormer/blob/bf1a38ef8f202372956ac57a363289c505967982/src/realtabformer/realtabformer.py#L705

Is that correct? On the first iteration, where p_epoch=0 and num_train_epochs=5, it is OK to train the model for 5 epochs. But in the next iteration, where p_epoch=5 and num_train_epochs=10, why should we continue training the model for 10 epochs? Shouldn't we just continue training it for 5 more epochs? At the extreme, in the last iteration where p_epoch=95, do we train the model for num_train_epochs=100 epochs?
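To make the concern concrete, here is the arithmetic under the schedule in the table above (n_critic=5, 100 epochs). If every round trained num_train_epochs fresh epochs, the model would see 1050 epochs in total; if every round only adds the epochs not yet completed, it sees exactly 100. The variable names are illustrative.

```python
total_epochs, n_critic = 100, 5

# Cumulative target for each round: 5, 10, ..., 100.
targets = list(range(n_critic, total_epochs + 1, n_critic))
# Starting epoch for each round: 0, 5, ..., 95.
starts = [0] + targets[:-1]

# Reading 1 (the worry): each round trains num_train_epochs epochs from scratch.
epochs_if_fresh_each_round = sum(targets)

# Reading 2: each round trains only the epochs between p_epoch and its target.
epochs_if_incremental = sum(t - p for p, t in zip(starts, targets))

print(epochs_if_fresh_each_round, epochs_if_incremental)  # 1050 100
```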

Thanks.

avsolatorio commented 1 year ago

Hello @echatzikyriakidis, great question! I have subclassed the Hugging Face Trainer, see:

https://github.com/avsolatorio/REaLTabFormer/blob/bf1a38ef8f202372956ac57a363289c505967982/src/realtabformer/rtf_trainer.py#L53

It accepts a target_epoch argument: the number of epochs you would normally train the model for if sensitivity training were not performed.

When we train the model with sensitivity, we pause training intermittently, at the end of each critic round of n_critic epochs, and a checkpoint is saved. The ResumableTrainer then resumes from this checkpoint and simply skips the epochs already processed, see:

https://github.com/avsolatorio/REaLTabFormer/blob/bf1a38ef8f202372956ac57a363289c505967982/src/realtabformer/realtabformer.py#L708

So, in your example with p_epoch=95, the trainer goes through the data 95 times without training the model. Then it trains the model for 5 epochs to reach the target of 100 epochs.
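The skip-then-train behaviour can be simulated without any library. This is a framework-free sketch: run_round and the checkpoint dict are hypothetical stand-ins for a ResumableTrainer call and its saved checkpoint, not the REaLTabFormer or Hugging Face API. Each round receives the cumulative target, passes over the epochs already recorded in the checkpoint without training, and trains only the remaining ones.

```python
def run_round(checkpoint: dict, target_epochs: int, log: list) -> dict:
    """Hypothetical stand-in for one resumable training call."""
    start = checkpoint.get("epochs_done", 0)
    for epoch in range(target_epochs):
        if epoch < start:
            # Completed in an earlier round: pass over it without updating the model.
            continue
        log.append(epoch)  # placeholder for one real training epoch
    # Save a new "checkpoint" at the end of the critic round.
    return {"epochs_done": target_epochs}


checkpoint: dict = {}
trained_epochs: list = []
n_critic, total_epochs = 5, 100
for p_epoch in range(0, total_epochs, n_critic):
    checkpoint = run_round(checkpoint, p_epoch + n_critic, trained_epochs)

print(len(trained_epochs), trained_epochs[-5:])  # 100 [95, 96, 97, 98, 99]
```

Every epoch is trained exactly once, even though the final round is asked to reach num_train_epochs=100.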

I implemented it this way because I noticed that the learning rate schedule depends on the total number of epochs. Without this approach, training the model intermittently produces noisy training.
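One way to see why the total matters: with a linearly decaying schedule, the learning rate at each step is computed from the total number of training steps. The sketch below assumes a base LR of 5e-5, 10 steps per epoch, and no warmup, and uses the same linear-decay shape as Hugging Face's get_linear_schedule_with_warmup. It compares (a) one schedule spanning all 100 epochs, which resuming toward a fixed target preserves, with (b) a fresh 5-epoch schedule restarted every round, which is an assumed illustration of the intermittent alternative.

```python
def linear_lr(step: int, total_steps: int, base_lr: float = 5e-5) -> float:
    """Linear decay to zero with no warmup."""
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps))


steps_per_epoch, n_critic, total_epochs = 10, 5, 100
total_steps = steps_per_epoch * total_epochs
round_steps = steps_per_epoch * n_critic

# (a) One schedule over every round: the LR decays smoothly.
single = [linear_lr(s, total_steps) for s in range(total_steps)]

# (b) A short schedule restarted at the start of each critic round.
restarted = [linear_lr(s % round_steps, round_steps) for s in range(total_steps)]

# First step of the second round (start of epoch 5).
s = round_steps
print(f"{single[s]:.2e} vs {restarted[s]:.2e}")  # 4.75e-05 vs 5.00e-05
```

In case (b) the learning rate drops to about 1e-06 at the end of each round and then jumps back to 5e-05, a sawtooth pattern, while case (a) decreases monotonically.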

echatzikyriakidis commented 1 year ago

Hi again!

Awesome 😎 I knew I was missing something in the implementation, as otherwise this would be a major bug 😂

Thanks!