worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License

Possible Improvements for CPU inference #49

Open australDream opened 10 months ago

australDream commented 10 months ago

Hi, I am currently trying to improve inference time. For a generation batch size of 512 samples, inference on the GPU takes twice as long as on the CPU. Any idea why?

```python
child_samples = model.sample(
    n_samples=512,
    input_unique_ids=query[self.join_on],
    input_df=query.drop(self.join_on, axis=1),
    gen_batch=512,
    device=self.device,
)
```

Note that the model is relational and no frozen encoder is provided. Also, if there are any general tips for CPU inference with REaLTabFormer, I am eager to learn. Thanks for the neat repo. Cheers
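For reference, a few settings that commonly help CPU inference with PyTorch-based generators like REaLTabFormer. This is a hedged sketch, not the library's documented API: `tune_cpu_inference` is a hypothetical helper, and the commented `model.sample(...)` usage assumes a fitted model and the `query`/`join_on` objects from the snippet above. Thread counts and `gen_batch` values are illustrative and should be tuned per machine.

```python
import os
from typing import Optional

import torch


def tune_cpu_inference(num_threads: Optional[int] = None) -> int:
    """Pin PyTorch's intra-op thread pool for CPU inference.

    Oversubscription (more threads than physical cores) often slows
    CPU inference down, so capping the pool at the core count is a
    common first step. Returns the thread count actually in effect.
    """
    n = num_threads or os.cpu_count() or 1
    torch.set_num_threads(n)
    return torch.get_num_threads()


# Illustrative usage (model, query, and join_on as in the snippet above):
#
# tune_cpu_inference()
# with torch.inference_mode():  # drops autograd bookkeeping during generation
#     child_samples = model.sample(
#         n_samples=512,
#         input_unique_ids=query[join_on],
#         input_df=query.drop(join_on, axis=1),
#         gen_batch=64,   # smaller batches can be faster on CPU than one
#         device="cpu",   # large batch; worth benchmarking a few sizes
#     )
```

Wrapping generation in `torch.inference_mode()` and experimenting with smaller `gen_batch` values are the two cheapest things to try; beyond that, CPU-oriented export paths (e.g. ONNX Runtime) are an option but would require changes outside `model.sample`.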