kathrinse / be_great

A novel approach for synthesizing tabular data using pretrained large language models
MIT License

About how the GReaT model handles the generation of missing values (NaN) #39

Closed mjbooo closed 1 year ago

mjbooo commented 1 year ago

Hello, authors! I wanted to express my gratitude for the excellent work you've done!

I do have a question regarding the generation of missing values. I'm a bit puzzled about how the GReaT model handles the generation of missing values (NaN) in the current implementation.

When I input a DataFrame, GReaT automatically converts it into text like 'column1 is value1, column2 is value2, ...'. If value1 happens to be a null value, the resulting text becomes 'column1 is None, column2 is value2, ...'. Then, it goes into the LLM backbone (GPT2).
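The row-to-text conversion described above can be sketched like this (a simplified stand-in for the library's internal encoding, not its exact code; `row_to_text` is a hypothetical helper):

```python
import pandas as pd

def row_to_text(row: pd.Series) -> str:
    # Encode a row as "column1 is value1, column2 is value2, ...".
    # A missing value is rendered as the literal string "None".
    return ", ".join(
        f"{col} is {'None' if pd.isna(val) else val}" for col, val in row.items()
    )

df = pd.DataFrame({"age": [42, None], "city": ["Berlin", "Hamburg"]})
print(row_to_text(df.iloc[1]))  # "age is None, city is Hamburg"
```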

However, I've noticed that the 'column1 is None' part gets dropped by the code below (# Remove rows with flawed numerical values). This happens because when pd.to_numeric is applied with errors='coerce' to the string 'None', the string cannot be parsed, so the value is silently coerced to NaN (no error is raised). The subsequent .notnull() then keeps only the rows that parsed successfully, so every row containing a missing value is dropped at this point.
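The filtering behavior can be reproduced in isolation (a minimal sketch, assuming only pandas):

```python
import pandas as pd

df_gen = pd.DataFrame({"column1": ["None", "3.5", "abc"], "column2": ["a", "b", "c"]})

# errors='coerce' turns any unparseable string, including the literal "None",
# into NaN rather than raising an error.
mask = pd.to_numeric(df_gen["column1"], errors="coerce").notnull()
print(mask.tolist())  # [False, True, False]

# Filtering on this mask therefore drops genuine missing values
# along with truly flawed entries such as "abc".
print(df_gen[mask]["column1"].tolist())  # ['3.5']
```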

I recently experimented with the .fit() and .sample() functions on a DataFrame with several columns containing missing values (e.g., the 'sick' dataset), and, as expected from the above, the generated samples contained no missing values.

If I've made any mistakes in my understanding, please let me know.

Once again, thank you for your assistance!


        with tqdm(total=n_samples) as pbar:
            already_generated = 0
            _cnt = 0
            try:
                while n_samples > already_generated:
                    start_tokens = great_start.get_start_tokens(k)
                    start_tokens = torch.tensor(start_tokens).to(device)

                    # Generate tokens
                    tokens = self.model.generate(
                        input_ids=start_tokens,
                        max_length=max_length,
                        do_sample=True,
                        temperature=temperature,
                        pad_token_id=50256,
                    )

                    # Convert tokens back to tabular data
                    text_data = _convert_tokens_to_text(tokens, self.tokenizer)
                    df_gen = _convert_text_to_tabular_data(text_data, self.columns)

                    # Remove rows with flawed numerical values
                    for i_num_cols in self.num_cols:
                        df_gen = df_gen[
                            pd.to_numeric(df_gen[i_num_cols], errors="coerce").notnull()
                        ]

                    df_gen[self.num_cols] = df_gen[self.num_cols].astype(float)

                    # Remove rows with missing values
                    df_gen = df_gen.drop(df_gen[df_gen.isna().any(axis=1)].index)

                    dfs.append(df_gen)
                    already_generated += len(dfs[-1])

                    # Update progress bar
                    pbar.update(len(dfs[-1]))

                    # Abort if no synthetic samples have been generated after several attempts
                    _cnt += 1
                    if _cnt > 13 and already_generated == 0:
                        raise Exception("Breaking the generation loop!")
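One possible direction for a fix, sketched here for illustration only (this is not the repository's actual patch, and `filter_flawed_numeric` is a hypothetical helper): treat the literal string 'None' as a valid missing value, and only drop rows whose numeric fields are neither parseable nor missing.

```python
import pandas as pd

def filter_flawed_numeric(df_gen: pd.DataFrame, num_cols: list) -> pd.DataFrame:
    # Keep a row if each numeric column is either parseable as a number
    # or an explicit missing value (the literal string "None" or NaN).
    for col in num_cols:
        is_missing = df_gen[col].isin(["None"]) | df_gen[col].isna()
        is_numeric = pd.to_numeric(df_gen[col], errors="coerce").notnull()
        df_gen = df_gen[is_numeric | is_missing].copy()
        # Map the "None" placeholder back to NaN so a later astype(float) succeeds
        df_gen[col] = pd.to_numeric(df_gen[col], errors="coerce")
    return df_gen

df = pd.DataFrame({"x": ["1.0", "None", "oops"], "y": ["a", "b", "c"]})
out = filter_flawed_numeric(df, ["x"])
print(out["y"].tolist())  # ['a', 'b'] -- only the truly flawed row is dropped
```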
unnir commented 1 year ago

Dear @mjbooo,

Thank you for spotting and reporting the issue.

Certainly, we need to address this. I will work on resolving it when I have time. Alternatively, I'm open to receiving a PR.

unnir commented 1 year ago

Thank you again for reporting the issue!

The code has been updated. GReaT now generates samples with missing values if they exist in the original dataset. To drop rows with missing values instead, pass drop_nan=True to the sample function:

synthetic_data = model.sample(n_samples=200, drop_nan=True)