worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License
200 stars, 23 forks

Generating Balanced Synthesized Data #79

Open erland-ramadhan opened 3 months ago

erland-ramadhan commented 3 months ago

Hey, is it possible to generate balanced synthetic data even though the REaLTabFormer model was trained on imbalanced data (with a class ratio as skewed as 4:96)? How do I do that?

CTGAN, TVAE, and even be_great can do this simply with `model.sample(n_samples, start_col=target_col, start_col_dist={'Yes': 0.5, 'No': 0.5})`.

avsolatorio commented 2 months ago

Hello @erland-ramadhan, can you check whether the `seed_input` parameter of REaLTabFormer's `model.sample` method satisfies your needs?

By the way, there is a prerequisite for using this: the `target_col` you want to condition on must be the first column of the table you are synthesizing.
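Because the conditioning column has to come first, one way to meet this prerequisite is to reorder the training DataFrame before fitting the model. This is a minimal sketch that assumes the training data is a pandas DataFrame. `move_column_first` is a hypothetical helper, not part of REaLTabFormer, and the column names are placeholders.

```python
import pandas as pd


def move_column_first(df: pd.DataFrame, col: str) -> pd.DataFrame:
    """Return a new DataFrame with `col` as the first column (hypothetical helper)."""
    if col not in df.columns:
        raise KeyError(f"column {col!r} not found")
    # Keep the relative order of the remaining columns unchanged.
    rest = [c for c in df.columns if c != col]
    return df[[col] + rest]


if __name__ == "__main__":
    df = pd.DataFrame({"age": [31, 45], "income": [50, 80], "target": ["No", "Yes"]})
    reordered = move_column_first(df, "target")
    print(list(reordered.columns))  # ['target', 'age', 'income']
```

The reordered frame is then what you would pass to `model.fit(...)`, so that the model learns to generate `target` before the other columns.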

It could look something like this:

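```python
import pandas as pd

# Each call conditions generation on target_col, which must be the first column.
yes_samples = model.sample(n_samples // 2, seed_input={target_col: "Yes"})
no_samples = model.sample(n_samples // 2, seed_input={target_col: "No"})

samples = pd.concat([yes_samples, no_samples], ignore_index=True)
```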
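The same idea extends to an arbitrary target distribution, similar in spirit to be_great's `start_col_dist`. This is a minimal sketch that assumes `model.sample(n, seed_input={target_col: value})` behaves as in the snippet above and returns a pandas DataFrame. `sample_with_distribution` is a hypothetical helper, not part of REaLTabFormer. It splits `n_samples` across classes with largest-remainder rounding, so the per-class counts add up to exactly `n_samples` even when `n_samples` is odd.

```python
import math

import pandas as pd


def sample_with_distribution(model, n_samples, target_col, dist):
    """Draw n_samples rows whose target_col values follow `dist` (hypothetical helper).

    Assumes model.sample(n, seed_input={target_col: value}) returns a DataFrame
    and that target_col was the first column of the training data.
    """
    total = sum(dist.values())
    exact = {value: n_samples * weight / total for value, weight in dist.items()}
    counts = {value: math.floor(x) for value, x in exact.items()}
    # Largest-remainder rounding: give leftover rows to the largest fractional parts.
    leftover = n_samples - sum(counts.values())
    by_remainder = sorted(exact, key=lambda v: exact[v] - counts[v], reverse=True)
    for value in by_remainder[:leftover]:
        counts[value] += 1

    # One conditioned sampling call per class; skip classes with a zero count.
    parts = [
        model.sample(count, seed_input={target_col: value})
        for value, count in counts.items()
        if count > 0
    ]
    # ignore_index avoids duplicate row labels across the per-class batches.
    return pd.concat(parts, ignore_index=True)
```

For a 50/50 split this would be called as `sample_with_distribution(model, n_samples, target_col, {"Yes": 0.5, "No": 0.5})`. The weights do not need to sum to 1 because they are normalized internally.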