worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License

Model save only works for Tabular model type. #8

Closed: echatzikyriakidis closed this issue 1 year ago

echatzikyriakidis commented 1 year ago

@avsolatorio Hi!

I tried using the checkpoints_dir parameter with a relational model, and it seems that no checkpoints are saved for this model type. Also, in the code I see that only the Tabular model type is handled. Is this a bug?

https://github.com/avsolatorio/REaLTabFormer/blob/bf1a38ef8f202372956ac57a363289c505967982/src/realtabformer/realtabformer.py#L1450

avsolatorio commented 1 year ago

Hello @echatzikyriakidis, the checkpoints_dir argument is directly passed to the training arguments of the Transformers trainer, so I'm unsure why you do not see the checkpoints:

https://github.com/avsolatorio/REaLTabFormer/blob/bf1a38ef8f202372956ac57a363289c505967982/src/realtabformer/realtabformer.py#L199

The model itself should be saved, see:

https://github.com/avsolatorio/REaLTabFormer/blob/bf1a38ef8f202372956ac57a363289c505967982/src/realtabformer/realtabformer.py#L1448

The line you're referring to saves the artefacts (checkpoints) generated by the sensitivity mechanism, which is only available to the tabular model.
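To check whether the trainer actually wrote checkpoints, you can list the sub-directories of the folder you passed as checkpoints_dir. The snippet below is a minimal sketch, not part of REaLTabFormer: `list_trainer_checkpoints` is a hypothetical helper, and it relies only on the Hugging Face Trainer convention of naming checkpoint folders `checkpoint-<global_step>` inside its output directory.

```python
import re
from pathlib import Path

# Hugging Face Trainer names each checkpoint folder "checkpoint-<global_step>".
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def list_trainer_checkpoints(checkpoints_dir):
    """Return checkpoint folders sorted by training step.

    Hypothetical helper for illustration; not part of REaLTabFormer.
    """
    root = Path(checkpoints_dir)
    if not root.is_dir():
        # Nothing was written (or the path is wrong).
        return []

    found = []
    for child in root.iterdir():
        match = _CHECKPOINT_RE.match(child.name)
        # Keep only directories whose name matches the Trainer convention.
        if match and child.is_dir():
            found.append((int(match.group(1)), child))

    # Sort numerically by step so checkpoint-1000 comes after checkpoint-500.
    found.sort(key=lambda item: item[0])
    return [path for _, path in found]


if __name__ == "__main__":
    # "my_checkpoints" is a placeholder; use the value you passed as checkpoints_dir.
    for path in list_trainer_checkpoints("my_checkpoints"):
        print(path)
```

If the list is empty after a short run, training may simply have ended before the first save step was reached, depending on the save settings in effect.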

echatzikyriakidis commented 1 year ago

Hi @avsolatorio,

Yes, you are right. The checkpoints from HF are indeed saved in the directory. At some point I thought they were not being saved; my mistake.

Thank you!