worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License
200 stars 23 forks

_validate_get_device() should also be called in model.sample() and model.predict() #32

Closed by echatzikyriakidis 1 year ago

echatzikyriakidis commented 1 year ago

Hi!

I noticed that _validate_get_device() is called only in model.fit(). It would be nice if it were also called in model.sample() and model.predict(), so that there is no need to pass the device argument to those methods just to fall back to the CPU when CUDA is not available. model.fit() already calls it, so it detects which device to use automatically and no device argument is needed there.
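To illustrate the behaviour being requested, here is a minimal sketch of a device-resolution helper. It is not the actual `_validate_get_device()` implementation from REaLTabFormer; the names `resolve_device` and `_cuda_available` and the `cuda_available` parameter are hypothetical and exist only for this example. The only real library call used is `torch.cuda.is_available()`, and the sketch assumes that a CUDA request should silently fall back to the CPU when CUDA is missing.

```python
from typing import Optional


def _cuda_available() -> bool:
    """Return True only if PyTorch is installed and reports a usable CUDA device."""
    try:
        import torch
    except ImportError:
        # Without PyTorch there is no CUDA device to use.
        return False
    return torch.cuda.is_available()


def resolve_device(
    requested: Optional[str] = None,
    cuda_available: Optional[bool] = None,
) -> str:
    """Hypothetical helper: pick a device string, falling back to the CPU.

    `cuda_available` can be passed explicitly (useful in tests); when it is
    None, availability is detected through PyTorch.
    """
    if cuda_available is None:
        cuda_available = _cuda_available()

    if requested is None:
        # No device passed: detect automatically.
        return "cuda" if cuda_available else "cpu"

    if requested.startswith("cuda") and not cuda_available:
        # Assumption for this sketch: fall back instead of raising an error.
        return "cpu"

    return requested


# On a machine without CUDA, both calls resolve to the CPU.
print(resolve_device(None, cuda_available=False))
print(resolve_device("cuda", cuda_available=False))
```

If fit(), sample(), and predict() all routed their device argument through one helper like this, callers on CPU-only machines would not need to pass a device to any of them.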

avsolatorio commented 1 year ago

@echatzikyriakidis, thanks for noting this! I pushed a patch to use _validate_get_device in the other interfaces. 😀

echatzikyriakidis commented 1 year ago

Thank you!