Closed JMGaljaard closed 1 year ago
Dear @JMGaljaard
Thank you for suggesting an improvement to our framework! I especially appreciate that you reported it in such an understandable way.
We will be happy to receive & accept a PR :)
Is someone working on this? Otherwise I would be happy to create a PR based on the suggestion :)
@Madnex please go ahead! We will be happy to receive your PR
Dear authors,
Let me start by thanking you for the open-source release of GReaT. I found an implementation detail that slows down the generation of samples, especially on larger datasets.
Problem Description
Looking at GPU utilization with `nvtop`, I found that the CPU workload (everything outside of sampling from the model) takes increasingly long: GPU utilization gets worse as the number of sampling iterations grows.
Proposed Solution
Digging into the code, I found that the accumulator (`df_gen`) and the generated (`pd.DataFrame(td)`) data frames are concatenated in each iteration.
https://github.com/kathrinse/be_great/blob/c568617763ba954fb39fc6b6e222e3abaef0886a/be_great/great.py#LL147C21-L147C21
https://github.com/kathrinse/be_great/blob/c568617763ba954fb39fc6b6e222e3abaef0886a/be_great/great_utils.py#L97
This incurs `O(N^2)` overhead, because each iteration allocates memory for a new `DataFrame` that holds all rows generated so far. This can be resolved by collecting the data frames in a list and concatenating them once at the end of the generation process. For `GReaT.sample` this would require only a minor change.
`_convert_text_to_tabular_data` can be improved similarly by making it return a DataFrame that is constructed from a list of dictionaries.
This way, for a dataset containing 20K+ samples, generation time went from 40+ minutes to about 3 minutes. Smaller datasets also seem to benefit, but less noticeably, since the per-iteration overhead grows linearly with the number of sampling iterations.
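To make the suggestion concrete, here is a minimal sketch of the pattern, not the actual `be_great` code. The helpers `generate_batch` and `text_to_rows` and the standalone `sample` function are hypothetical stand-ins for the model sampling step, `_convert_text_to_tabular_data`, and `GReaT.sample`. The "feature is value" text format is assumed for illustration only.

```python
import pandas as pd


def generate_batch(batch_size: int, start: int) -> list[str]:
    """Hypothetical stand-in for sampling and decoding one batch from the model."""
    return [
        f"age is {20 + (start + i) % 50}, income is {1000 * (start + i)}"
        for i in range(batch_size)
    ]


def text_to_rows(texts: list[str]) -> list[dict]:
    """Parse each generated string into a dict (one dict per row).

    Building plain dicts avoids creating and concatenating a DataFrame per row.
    """
    rows = []
    for text in texts:
        row = {}
        for feature in text.split(","):
            parts = feature.strip().split(" is ")
            if len(parts) == 2:  # skip malformed fragments
                row[parts[0]] = parts[1]
        rows.append(row)
    return rows


def sample(n_samples: int, batch_size: int = 4) -> pd.DataFrame:
    """Collect per-batch frames in a list and concatenate exactly once."""
    frames = []  # accumulator list instead of a growing DataFrame
    generated = 0
    while generated < n_samples:
        texts = generate_batch(batch_size, generated)
        batch_df = pd.DataFrame(text_to_rows(texts))  # built from a list of dicts
        frames.append(batch_df)  # appending to a list is cheap
        generated += len(batch_df)
    # A single concat copies each row once, instead of once per iteration.
    return pd.concat(frames, ignore_index=True).head(n_samples)


df = sample(10)
```

With `k` iterations of `b` rows each, concatenating inside the loop copies roughly `b + 2b + ... + kb = b*k*(k+1)/2` rows in total, whereas the single `pd.concat` at the end copies only `b*k` rows.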
Example implementation
Looking at related work, the REaLTabFormer implementation seems to provide an example of this.
https://github.com/worldbank/REaLTabFormer/blob/bf1a38ef8f202372956ac57a363289c505967982/src/realtabformer/rtf_sampler.py#L610-L674
Side note
Likely this could also (slightly) improve GReaT's performance in Appendix B.5 of your paper for inference/generation.