kathrinse / be_great

A novel approach for synthesizing tabular data using pretrained large language models
MIT License

Improving generation speed #23

Closed JMGaljaard closed 1 year ago

JMGaljaard commented 1 year ago

Dear authors,

Let me start by thanking you for the open-source release of GReaT. While using it, I found an implementation detail that slows down the generation of samples, especially on larger datasets.

Problem Description

Watching GPU utilization with nvtop, I found that the CPU workload (everything outside of sampling from the model) takes increasingly longer per iteration, so GPU utilization degrades as the number of sampling iterations grows.

Proposed Solution

Digging into the code, I found that the accumulator (df_gen) and the newly generated rows (pd.DataFrame(td)) are concatenated in each iteration.

https://github.com/kathrinse/be_great/blob/c568617763ba954fb39fc6b6e222e3abaef0886a/be_great/great.py#LL147C21-L147C21

https://github.com/kathrinse/be_great/blob/c568617763ba954fb39fc6b6e222e3abaef0886a/be_great/great_utils.py#L97

This incurs O(N^2) overhead: on every iteration, memory is allocated for a new DataFrame large enough to hold all rows generated so far. This can be resolved by collecting the per-batch data frames in a list and concatenating them once at the end of the generation process.
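To illustrate the asymptotics with a minimal, pandas-free sketch (the function names are hypothetical, and plain Python lists stand in for DataFrames):

```python
import time

def concat_per_iteration(batches):
    # Mirrors the current pattern: every iteration builds a brand-new
    # container holding all rows generated so far (like calling
    # pd.concat inside the sampling loop).
    acc = []
    for batch in batches:
        acc = acc + batch  # copies the full accumulator: O(total rows) per step
    return acc

def concat_once(batches):
    # Proposed pattern: collect the batches and combine them once at the
    # end (like appending DataFrames to a list and a single final pd.concat).
    parts = []
    for batch in batches:
        parts.append(batch)  # O(1) per step
    return [row for part in parts for row in part]

if __name__ == "__main__":
    batches = [[i] * 50 for i in range(2000)]

    t0 = time.perf_counter()
    a = concat_per_iteration(batches)
    t1 = time.perf_counter()
    b = concat_once(batches)
    t2 = time.perf_counter()

    assert a == b
    print(f"per-iteration concat: {t1 - t0:.3f}s, concat once: {t2 - t1:.3f}s")
```

The gap between the two variants widens with the number of batches, which matches the observation that larger sampling runs slow down disproportionately.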

For GReaT.sample, a minor change similar to the following would suffice:

# Accumulate each generated batch in a list instead of concatenating per iteration
dfs = []
...
while n > already_generated:
    ...
    # With the change below, this returns only the newly generated rows
    df_gen = _convert_text_to_tabular_data(text_data, df_gen)
    ...
    dfs.append(df_gen)
    already_generated += len(dfs[-1])
    pbar.update(len(dfs[-1]))

# Concatenate all batches once, after the loop
df_gen = pd.concat(dfs)
df_gen = df_gen.reset_index(drop=True)
...

The _convert_text_to_tabular_data function can be improved similarly by having it collect the parsed rows in a list of dictionaries and construct a single DataFrame from that list:

def _convert_text_to_tabular_data(text: tp.List[str], df_gen: pd.DataFrame) -> pd.DataFrame:
    ...
    generated = []
    ...
    for t in text:
        ...
        generated.append(td)
    # Build the DataFrame once from the list of row dictionaries
    return pd.DataFrame(generated)

With this change, generation time for a dataset containing 20K+ samples went from 40+ minutes to about 3 minutes. Smaller datasets also benefit, but less noticeably, since the per-iteration overhead grows linearly with the number of iterations already performed.

Example implementation

Looking at related work, the REaLTabFormer implementation provides an example of this approach.

https://github.com/worldbank/REaLTabFormer/blob/bf1a38ef8f202372956ac57a363289c505967982/src/realtabformer/rtf_sampler.py#L610-L674

Side note

This would likely also (slightly) improve GReaT's reported inference/generation performance in Appendix B.5 of your paper.

unnir commented 1 year ago

Dear @JMGaljaard

Thank you for suggesting an improvement to our framework! In particular, thank you for reporting it in such an understandable way; I appreciate it.

We will be happy to receive & accept a PR :)

Madnex commented 1 year ago

Is someone working on this? Otherwise, I would be happy to create a PR based on the suggestion :)

unnir commented 1 year ago

@Madnex please go ahead! We will be happy to receive your PR