kathrinse / be_great

A novel approach for synthesizing tabular data using pretrained large language models
MIT License

Invalid sampling with Heloc and Sick datasets #21

Closed: dionman closed this issue 1 year ago

dionman commented 1 year ago

I'm trying to rerun the paper's experiments on the Heloc and Sick datasets. I'm using distilgpt2, fine-tuned for 200 epochs, with a sampling temperature of T=0.7. For both datasets I am unable to get valid samples: every sampled data point has missing features. I realised that Heloc has a column that is entirely missing, so I removed it from the input data and retrained the model, but this still didn't work. Is any further preprocessing or hyperparameter tuning required for these two datasets beyond what is reported in the paper?
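The two checks described above, dropping a column that is entirely missing (as done for Heloc) and measuring how many sampled rows have missing features, can be sketched in pandas as follows. The helper names and the toy frame are hypothetical; in practice the second helper would be applied to the output of `model.sample(...)`, assuming that output is a pandas DataFrame.

```python
import numpy as np
import pandas as pd


def drop_all_missing_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Drop columns in which every value is missing (NaN/None)."""
    return df.dropna(axis=1, how="all")


def fraction_incomplete_rows(df: pd.DataFrame) -> float:
    """Fraction of rows that have at least one missing feature."""
    if len(df) == 0:
        return 0.0
    return float(df.isna().any(axis=1).mean())


# Hypothetical toy frame standing in for real or sampled data
toy = pd.DataFrame({
    "a": [1.0, 2.0, np.nan],
    "b": ["x", None, "z"],
    "empty": [np.nan, np.nan, np.nan],  # entirely missing, like the Heloc column
})

cleaned = drop_all_missing_columns(toy)
print(list(cleaned.columns))              # ['a', 'b']
print(fraction_incomplete_rows(cleaned))  # 2 of 3 rows have a missing value
```

A value of 1.0 from `fraction_incomplete_rows` on the sampled frame would correspond to the situation reported here, where no sampled row is complete.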

unnir commented 1 year ago

Hi,

Thank you for reporting the issue. Could you please share your code and a data sample with the feature names?

dionman commented 1 year ago
from be_great import GReaT
model = GReaT(llm='distilgpt2', epochs=200, batch_size=16)
model.fit(data)
synthetic_data = model.sample(n_samples=500, k=1000, max_length=100)

With the following data sample:

age,sex,on thyroxine,query on thyroxine,on antithyroid medication,sick,pregnant,thyroid surgery,I131 treatment,query hypothyroid,query hyperthyroid,lithium,goitre,tumor,hypopituitary,psych,TSH measured,TSH,T3 measured,T3,TT4 measured,TT4,T4U measured,T4U,FTI measured,FTI,TBG measured,referral source,binaryClass
41,F,f,f,f,f,f,f,f,f,f,f,f,f,f,f,t,1.3,t,2.5,t,125,t,1.14,t,109,f,SVHC,P
23,F,f,f,f,f,f,f,f,f,f,f,f,f,f,f,t,4.1,t,2,t,102,f,?,f,?,f,other,P
46,M,f,f,f,f,f,f,f,f,f,f,f,f,f,f,t,0.98,f,?,t,109,t,0.91,t,120,f,other,P
70,F,t,f,f,f,f,f,f,f,f,f,f,f,f,f,t,0.16,t,1.9,t,175,f,?,f,?,f,other,P
70,F,f,f,f,f,f,f,f,f,f,f,f,f,f,f,t,0.72,t,1.2,t,61,t,0.87,t,70,f,SVI,P
18,F,t,f,f,f,f,f,f,f,f,f,f,f,f,f,t,0.03,f,?,t,183,t,1.3,t,141,f,other,P
59,F,f,f,f,f,f,f,f,f,f,f,f,f,f,f,f,?,f,?,t,72,t,0.92,t,78,f,other,P
80,F,f,f,f,f,f,f,f,f,f,f,f,f,f,f,t,2.2,t,0.6,t,80,t,0.7,t,115,f,SVI,P

The data is read into a pandas DataFrame with pd.read_csv(data_file).
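The call above passes `max_length=100`. As a rough sanity check, and assuming GReaT's textual encoding in which each row becomes a sentence such as "age is 41, sex is F, ...", the first row of the sample can be serialized and its length compared with that budget. The `serialize_row` helper is hypothetical (not the be_great implementation), and treating `max_length` as a token budget is an assumption; the thread does not establish whether truncation plays a role in the invalid samples.

```python
import csv
import io

# Header and first data row copied from the Sick sample above
SAMPLE = """\
age,sex,on thyroxine,query on thyroxine,on antithyroid medication,sick,pregnant,thyroid surgery,I131 treatment,query hypothyroid,query hyperthyroid,lithium,goitre,tumor,hypopituitary,psych,TSH measured,TSH,T3 measured,T3,TT4 measured,TT4,T4U measured,T4U,FTI measured,FTI,TBG measured,referral source,binaryClass
41,F,f,f,f,f,f,f,f,f,f,f,f,f,f,f,t,1.3,t,2.5,t,125,t,1.14,t,109,f,SVHC,P
"""


def serialize_row(row: dict) -> str:
    # Assumed GReaT-style encoding: "<feature> is <value>", joined by ", "
    return ", ".join(f"{name} is {value}" for name, value in row.items())


row = next(csv.DictReader(io.StringIO(SAMPLE)))
text = serialize_row(row)

# GPT-2's tokenizer turns every whitespace-separated word into at least one
# token, so the word count is a lower bound on the token count
n_words = len(text.split())
print(len(row), "features,", n_words, "words")  # 29 features, 103 words
```

Because multi-word feature names such as "query on thyroxine" each add several words, a single serialized row of this dataset already exceeds 100 words, and therefore more than 100 GPT-2 tokens.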

unnir commented 1 year ago

Thank you for providing the code and data sample!

In our experiments, we observed that the question mark "?" can be quite an informative symbol for pretrained LLMs, so we replaced it with NaN (the pipeline below then encodes NaN as the string 'None').
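For a CSV export like the one posted above, where "?" marks a missing value, pandas keeps "?" as a literal string by default, so it reaches the model as text. A minimal sketch on a trimmed two-row excerpt of that sample (columns cut only for brevity) shows the difference the `na_values` argument of `pd.read_csv` makes:

```python
import io

import pandas as pd

# Two rows from the Sick sample above, trimmed to a few columns
SAMPLE = """\
age,sex,TSH measured,TSH,T3 measured,T3
23,F,t,4.1,t,2
46,M,t,0.98,f,?
"""

# Default parsing: "?" stays a string, so T3 is not a numeric column
raw = pd.read_csv(io.StringIO(SAMPLE))
print(raw["T3"].dtype, raw["T3"].tolist())  # T3 holds the strings '2' and '?'

# Treat "?" as missing: T3 becomes float with NaN in place of "?"
clean = pd.read_csv(io.StringIO(SAMPLE), na_values=["?"])
print(clean["T3"].dtype, clean["T3"].isna().tolist())  # float64 [False, True]
```

After this step the NaN values can be handled the same way as in the ARFF-based pipeline.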

Here is our data preparation pipeline for the Sick dataset:


import numpy as np
import pandas as pd
from scipy.io.arff import loadarff 

raw_data = loadarff('dataset_38_sick.arff')
df = pd.DataFrame(raw_data[0])

cat_col = ['sex', 'on_thyroxine', 'query_on_thyroxine',
           'on_antithyroid_medication', 'sick','pregnant', 
           'thyroid_surgery', 'I131_treatment', 'query_hypothyroid',
           'query_hyperthyroid', 
           'lithium', 'goitre', 'tumor', 'hypopituitary', 'psych', 
           'TSH_measured', 'T3_measured', 
           'TT4_measured','T4U_measured', 'FTI_measured', 'TBG_measured',
           'referral_source', 'Class',
          ]

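# loadarff returns nominal attributes as bytes; decode them to str
for i_col in cat_col:
    df[i_col] = df[i_col].str.decode("utf-8")

df.age = df.age.astype(str)

def age_preproc(x):
    # Turn float-like ages such as "41.0" into ints; leave other values unchanged
    try:
        return int(float(x))
    except (TypeError, ValueError):
        return x

df.age = df.age.apply(age_preproc)
# Encode the remaining missing values (NaN) as the string 'None'
df = df.replace(np.nan, 'None')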
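To see what this pipeline produces without the ARFF file, here is a sketch on a small hand-made frame that mimics what `loadarff` returns (nominal attributes as bytes, numeric attributes as floats with NaN for missing values). The frame is hypothetical, and the final line assumes GReaT's "feature is value" text encoding; it illustrates that missing values reach the model as "None" rather than "?".

```python
import numpy as np
import pandas as pd


def age_preproc(x):
    # Same helper as in the pipeline above
    try:
        return int(float(x))
    except (TypeError, ValueError):
        return x


# Hypothetical stand-in for loadarff output
df = pd.DataFrame({
    "age": [41.0, 23.0, 46.0],
    "sex": [b"F", b"F", b"M"],
    "T3": [2.5, 2.0, np.nan],
})

df["sex"] = df["sex"].str.decode("utf-8")                # bytes -> str
df["age"] = df["age"].astype(str).apply(age_preproc)     # "41.0" -> 41
df = df.replace(np.nan, "None")                          # NaN -> "None"

# Serialize the last row in the assumed "feature is value" style
last = df.iloc[2]
line = ", ".join(f"{col} is {last[col]}" for col in df.columns)
print(line)  # age is 46, sex is M, T3 is None
```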

dionman commented 1 year ago

Thanks. Have you uploaded, or are you planning to upload, the trained models to Hugging Face?

unnir commented 1 year ago

Not yet, but if you want them, please email me and I will send you the weights.

You can find my email here: https://uni-tuebingen.de/fakultaeten/mathematisch-naturwissenschaftliche-fakultaet/fachbereiche/informatik/lehrstuehle/data-science-analytics/team/vadim-borisov/