worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License

CPU OOM during tokenization - Tabular format #23

Closed mohanvrk closed 1 year ago

mohanvrk commented 1 year ago

I am running into the following issue when training the model on a tabular dataset (30 million samples).

I used REaLTabFormer to generate synthetic data in tabular format. The following configuration allowed us to train on and generate 1M samples, and it works extremely well.

CPU: 16 cores, 60 GB RAM; GPU: NVIDIA T4, 16 GB

However, when I try to use it on larger datasets, for example 30M samples, on the same machine, the process runs out of RAM (OOM) during the tokenization stage itself.
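For reference, a minimal sketch of the tabular workflow in use (the exact settings from the working 1M-sample run are not listed above, so the file name, parameter values, and sample count below are illustrative, following the library's documented API):

import pandas as pd
from realtabformer import REaLTabFormer

# Illustrative workflow sketch: train on a tabular DataFrame and
# generate synthetic rows. File name and sample count are placeholders.
df = pd.read_csv("training_data.csv")

rtf_model = REaLTabFormer(model_type="tabular")
rtf_model.fit(df)

synthetic_df = rtf_model.sample(n_samples=1_000_000)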

avsolatorio commented 1 year ago

Hi @mohanvrk, I have a hunch that this may be related to the computation of the sensitivity statistic via bootstrapping. Since you are working with a large dataset, you can instead use a fixed percentage of the data as a validation set for early stopping.

You can try 10% to 20% of the data. Below is a snippet that uses 90% for training and 10% for validation.

from realtabformer import REaLTabFormer

# Use 10% of the data as a validation set for early stopping.
rtf_model = REaLTabFormer(
    model_type="tabular",
    gradient_accumulation_steps=4,
    logging_steps=100,
    train_size=0.9,
)

# Fit the model without the sensitivity bootstrapping.
rtf_model.fit(df, n_critic=0)
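
Since a fit on 30M rows can take a while, it may also be worth saving the trained model before sampling, so a crash later does not force a retrain. A sketch using the save/load helpers (the directory names are illustrative):

# Persist the fitted model. A directory with the model's experiment id
# (e.g., idXXXX) will be created under the given path.
rtf_model.save("rtf_model/")

# Reload it later for sampling; point to the experiment id directory
# actually created by save().
# rtf_model = REaLTabFormer.load_from_dir(path="rtf_model/idXXXX")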
mohanvrk commented 1 year ago

Hi @avsolatorio,

Thanks for the reply and suggestions.

The model training is still failing even with the suggested settings.

I added some print statements in the data_utilities file and found that the training fails just before the token mapping step.
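
One check that might help narrow this down (a sketch, not something run yet) is fitting on a random subsample of the 30M rows to confirm that the tokenization OOM scales with row count:

# Sketch: fit on a random subsample of the full 30M-row DataFrame (df)
# to check whether the tokenization OOM scales with the number of rows.
# The subsample size is arbitrary.
df_small = df.sample(n=5_000_000, random_state=42)
rtf_model.fit(df_small, n_critic=0)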