worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License

Speeding the training on mixed data set - categorical data, numerical and text. #26

Closed vinay-k12 closed 1 year ago

vinay-k12 commented 1 year ago

I'm trying to train the model on custom data that has several categorical features with very high cardinality (e.g., City), along with text features and numerical features. The dataset is small: 380K rows.

But training never starts! It has been stuck at this point for a few hours:

[screenshot: training progress bar stalled]

How to improve the training?

avsolatorio commented 1 year ago

Hi @vinay-k12, I think the cardinality of your data is preventing the bootstrapping step from progressing. What you could try is not using the automated termination based on the bootstrap statistic; instead, use a validation sample.

Try:

from realtabformer import REaLTabFormer

# Use 20% of the data as a validation set for early stopping.
rtf_model = REaLTabFormer(
    model_type="tabular",
    gradient_accumulation_steps=4,
    logging_steps=100,
    train_size=0.8,
)

# Fit the model without the sensitivity-bootstrap stopping criterion.
rtf_model.fit(df, n_critic=0)

This will fit the data directly. Hope this helps!
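A complementary preprocessing step (not from this thread, just a common workaround) is to reduce the cardinality of columns like City before training, which shrinks the vocabulary the model must learn. A minimal pandas sketch; the helper name cap_cardinality and the labels are illustrative, not part of REaLTabFormer:

```python
import pandas as pd

def cap_cardinality(series: pd.Series, top_k: int = 50,
                    other_label: str = "OTHER") -> pd.Series:
    # Keep the top_k most frequent categories; lump the rest into one label.
    top = series.value_counts().nlargest(top_k).index
    return series.where(series.isin(top), other_label)

cities = pd.Series(["NYC"] * 5 + ["LA"] * 3 + ["Boise", "Reno"])
capped = cap_cardinality(cities, top_k=2)
# Only "NYC", "LA", and "OTHER" remain.
```

Applying this to the highest-cardinality columns of the 380K-row frame before calling rtf_model.fit may also make the bootstrap-based stopping viable again.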