worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License

Out of memory exception on tabular model with 25k rows and 37 columns #59

Closed: efstathios-chatzikyriakidis closed this issue 6 months ago

efstathios-chatzikyriakidis commented 7 months ago

Hi @avsolatorio,

I have a dataset with ~25k rows and 37 columns of mixed types (categorical and numerical), some with high cardinality and others with low cardinality. Some columns also have a large number of NAs.
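Cardinality and missingness are the two data properties singled out here. The following is a minimal stdlib sketch for profiling them before training. It assumes the rows are available as a list of dicts, with `None` marking a missing value. A pandas user would more likely call `df.nunique()` and `df.isna().sum()`. `profile_columns` and its `high_card` threshold are hypothetical and are not part of REaLTabFormer.

```python
from typing import Any, Dict, List


def profile_columns(rows: List[Dict[str, Any]], high_card: int = 100) -> Dict[str, Dict[str, Any]]:
    """Per-column distinct count, missing count and a high-cardinality flag.

    A key absent from a row, or present with value None, counts as missing.
    `high_card` is an illustrative threshold, not a library setting.
    """
    columns = sorted({key for row in rows for key in row})
    profile = {}
    for col in columns:
        values = [row.get(col) for row in rows]
        present = [v for v in values if v is not None]
        n_distinct = len(set(present))  # missing values are excluded from the distinct count
        profile[col] = {
            "distinct": n_distinct,
            "missing": len(values) - len(present),
            "high_cardinality": n_distinct > high_card,
        }
    return profile


rows = [
    {"city": "Athens", "income": 30_000},
    {"city": "Patras", "income": None},
    {"city": "Athens", "income": 41_000},
]
print(profile_columns(rows, high_card=1))
# {'city': {'distinct': 2, 'missing': 0, 'high_cardinality': True},
#  'income': {'distinct': 2, 'missing': 1, 'high_cardinality': True}}
```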

When I train the tabular model I get an out-of-memory error because it needs more than 50 GB. The bootstrap threshold estimation is also very slow.
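Distance-based checks on ~25k rows often exhaust RAM because dense pairwise arrays grow quadratically with the number of rows. The arithmetic below is illustrative only. It assumes float64 values and a naive implementation that materializes either an n × n distance matrix or an n × n × d array of per-column differences. Whether REaLTabFormer's bootstrap threshold estimation materializes arrays like these is an assumption; this thread does not confirm it. `dense_bytes` is a hypothetical helper.

```python
def dense_bytes(*shape: int, itemsize: int = 8) -> int:
    """Bytes needed for a dense array of the given shape (float64 by default)."""
    total = itemsize
    for dim in shape:
        total *= dim
    return total


GIB = 2 ** 30
n_rows, n_cols = 25_000, 37  # dataset size reported in this issue

pairwise = dense_bytes(n_rows, n_rows)            # n x n distance matrix
broadcast = dense_bytes(n_rows, n_rows, n_cols)   # n x n x d differences before summing

print(f"{pairwise / GIB:.1f} GiB")   # 4.7 GiB
print(f"{broadcast / GIB:.1f} GiB")  # 172.3 GiB
```

Halving the number of rows compared cuts both figures by a factor of four. For that reason, subsampling or chunked computation is the usual remedy for this kind of blow-up.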

Do you have any insight into why this happens or how it can be solved? Is there a hyperparameter I can use to fix this?

Thanks!

efstathios-chatzikyriakidis commented 7 months ago

As a follow-up, here is the code I use:

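```python
from realtabformer import REaLTabFormer

# MODEL_RUN_DIRECTORY_PATH (a pathlib.Path), table_name, entity_tables_dataframes
# (a dict of pandas DataFrames) and get_device() are defined elsewhere in the pipeline.

training_execution_params = {
    "table_training": {
        "n_epochs": 2,
        "n_bootstrap_rounds": 100,
        "batch_size": 8,
        "n_gradient_accumulation_steps": 2,
    }
}

table_training_params = training_execution_params["table_training"]

parent_model = REaLTabFormer(
    model_type="tabular",
    batch_size=table_training_params["batch_size"],
    epochs=table_training_params["n_epochs"],
    # The remaining keyword arguments are passed through as training arguments.
    gradient_accumulation_steps=table_training_params["n_gradient_accumulation_steps"],
    logging_steps=1000,
    save_strategy="epoch",
    checkpoints_dir=MODEL_RUN_DIRECTORY_PATH / f"{table_name}_checkpoints",
)

# num_bootstrap sets the number of bootstrap rounds used for the threshold estimation.
trainer = parent_model.fit(
    df=entity_tables_dataframes[table_name],
    num_bootstrap=table_training_params["n_bootstrap_rounds"],
    device=get_device(),
)
```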
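For reference, this is how the batch settings in this config work out. The sketch below is a rough version of Hugging Face-style step arithmetic. It assumes that each of the ~25k rows is one training example and that `batch_size` maps to the per-device micro-batch. `training_schedule` is a hypothetical helper, and the exact rounding may differ between `transformers` versions. The point it illustrates is that peak activation memory per step is driven by `batch_size`. `gradient_accumulation_steps` only raises the effective batch size.

```python
import math


def training_schedule(n_rows: int, batch_size: int, accumulation: int, epochs: int) -> dict:
    """Rough optimizer-step arithmetic; a partial final window counts as a step."""
    micro_batches = math.ceil(n_rows / batch_size)             # forward/backward passes per epoch
    steps_per_epoch = math.ceil(micro_batches / accumulation)  # optimizer updates per epoch
    return {
        "effective_batch": batch_size * accumulation,
        "micro_batches_per_epoch": micro_batches,
        "optimizer_steps_per_epoch": steps_per_epoch,
        "total_optimizer_steps": steps_per_epoch * epochs,
    }


# Values from the config above: ~25k rows, batch_size=8, 2 accumulation steps, 2 epochs.
print(training_schedule(25_000, batch_size=8, accumulation=2, epochs=2))
# {'effective_batch': 16, 'micro_batches_per_epoch': 3125,
#  'optimizer_steps_per_epoch': 1563, 'total_optimizer_steps': 3126}
```

Under these assumptions, `logging_steps=1000` would produce only about three log entries over the whole run.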
avsolatorio commented 6 months ago

Hello @efstathios-chatzikyriakidis, I assume this is related to https://github.com/worldbank/REaLTabFormer/issues/60, and I think you have already solved it with the hold-out validation dataset. :)
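The thread does not show how the hold-out set was wired into training; issue #60 is referenced but not reproduced here. The following is a generic sketch only. It assumes the idea is to evaluate the model on rows it never trained on instead of running many bootstrap rounds. It shows a reproducible hold-out split. `holdout_split` and its defaults are hypothetical and are not a REaLTabFormer API.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def holdout_split(rows: Sequence[T], holdout_frac: float = 0.2, seed: int = 42) -> Tuple[List[T], List[T]]:
    """Shuffle row indices once with a fixed seed and split off a hold-out set."""
    if not 0.0 < holdout_frac < 1.0:
        raise ValueError("holdout_frac must be in (0, 1)")
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # local RNG: reproducible, no global state touched
    n_holdout = int(round(len(rows) * holdout_frac))
    holdout = [rows[i] for i in indices[:n_holdout]]
    train = [rows[i] for i in indices[n_holdout:]]
    return train, holdout


train_rows, holdout_rows = holdout_split(list(range(25_000)), holdout_frac=0.2)
print(len(train_rows), len(holdout_rows))  # 20000 5000
```

A fixed seed keeps the hold-out rows identical across runs. That matters when comparing checkpoints or hyperparameter settings against the same validation data.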

efstathios-chatzikyriakidis commented 6 months ago

Yes, thank you!