sdv-dev / SDV

Synthetic data generation for tabular data
https://docs.sdv.dev/sdv

PAR DiagnosticReport not 1.0 with float categorical columns #1910

Closed frances-h closed 4 days ago

frances-h commented 3 months ago

Error Description

When PAR is fit on categorical columns whose values are floats, it does not stick to the original categories when sampling: the synthetic data contains values that never appear in the real column. This leads to a very low diagnostic score for 'Data Validity' because the CategoryAdherence metric fails.
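To make the failure concrete, here is a minimal sketch of the idea behind a category adherence check: the fraction of synthetic values that also occur among the real categories. This is not the SDMetrics implementation. The function name `category_adherence` and its handling of missing values are assumptions made for illustration.

```python
import math


def _is_missing(value):
    """Treat None and float NaN as missing values."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def category_adherence(real_values, synthetic_values):
    """Hypothetical helper: share of synthetic values found in the real categories.

    Missing values are ignored on both sides. Returns 1.0 when there is
    nothing to check.
    """
    allowed = {v for v in real_values if not _is_missing(v)}
    checked = [v for v in synthetic_values if not _is_missing(v)]
    if not checked:
        return 1.0
    return sum(v in allowed for v in checked) / len(checked)


# Real categories are 100.0 and 50.0; two of the four synthetic values are new.
score = category_adherence([100.0, 50.0, 100.0], [100.0, 50.0, 73.4, 61.2])
print(score)
```

A score below 1.0 on the float categorical column is the symptom described in this issue.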

Steps to reproduce

from sdv.datasets.demo import download_demo
from sdv.sequential import PARSynthesizer
from sdv.evaluation.single_table import run_diagnostic

data, metadata = download_demo('sequential', 'nasdaq100_2019')
data['category'] = [100.0 if i % 2 == 0 else 50.0 for i in data.index]
metadata.add_column('category', sdtype='categorical')

synth = PARSynthesizer(metadata)
synth.fit(data)
sampled = synth.sample(2)

report = run_diagnostic(data, sampled, metadata)

npatki commented 3 months ago

Workaround

If anyone is running into this, here is a suggested workaround:

  1. Identify any categorical columns (in the metadata) that are actually represented as numbers in your data (ints, floats, etc.)
  2. Cast these columns as objects before inputting them into the PARSynthesizer.
  3. At the end when you get synthetic data, cast them back as ints, floats, etc.

Here is a code snippet that implements the steps above. Replace the list CAT_COLUMN_NAMES with your own column names.

from sdv.sequential import PARSynthesizer

CAT_COLUMN_NAMES = ['ColA', 'ColB', ... ]

data = <your pandas DataFrame>
metadata = <your SingleTableMetadata object>

# cast the categorical columns to the object dtype
for col_name in CAT_COLUMN_NAMES:
  data[col_name] = data[col_name].astype('object')

# now proceed with modeling and sampling as usual
synthesizer = PARSynthesizer(metadata)
synthesizer.fit(data)
synthetic_data = synthesizer.sample(num_sequences=10)

# (optional) cast the categorical columns back to floats
for col_name in CAT_COLUMN_NAMES:
  try:
    synthetic_data[col_name] = synthetic_data[col_name].astype('float')
  except (ValueError, TypeError):
    # leave the column unchanged if its values cannot be converted
    print('Column name', col_name, 'could not be converted back to a float')
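The snippet above always casts back to float. If your categorical columns mix ints and floats, a variation is to record each column's original dtype before casting and restore it afterwards. The sketch below assumes you have a plain dict mapping column names to sdtypes (called `sdtypes` here). How you build that dict from your metadata object is left open. The helper names are hypothetical and not part of SDV.

```python
import pandas as pd


def numeric_categorical_columns(data, sdtypes):
    """Step 1: columns marked 'categorical' that pandas stores as numbers.

    `sdtypes` is an assumed plain dict of column name -> sdtype string.
    Boolean columns are skipped because pandas also counts them as numeric.
    """
    return [
        name
        for name, sdtype in sdtypes.items()
        if sdtype == 'categorical'
        and name in data.columns
        and pd.api.types.is_numeric_dtype(data[name])
        and not pd.api.types.is_bool_dtype(data[name])
    ]


def cast_to_object(data, column_names):
    """Step 2: return a copy with the columns cast to object, plus the old dtypes."""
    original_dtypes = {name: data[name].dtype for name in column_names}
    converted = data.copy()
    for name in column_names:
        converted[name] = converted[name].astype('object')
    return converted, original_dtypes


def restore_dtypes(synthetic_data, original_dtypes):
    """Step 3: cast each column back to its recorded dtype where possible."""
    restored = synthetic_data.copy()
    for name, dtype in original_dtypes.items():
        try:
            restored[name] = restored[name].astype(dtype)
        except (ValueError, TypeError):
            print('Column name', name, 'could not be converted back to', dtype)
    return restored


# Small stand-in table; in practice this would be your real training data.
example = pd.DataFrame({
    'category': [100.0, 50.0, 100.0],
    'sector': ['tech', 'health', 'tech'],
    'volume': [10, 20, 30],
})
example_sdtypes = {'category': 'categorical', 'sector': 'categorical', 'volume': 'numerical'}

columns = numeric_categorical_columns(example, example_sdtypes)
converted, dtypes = cast_to_object(example, columns)
# ... fit the synthesizer on `converted` and sample as usual ...
restored = restore_dtypes(converted, dtypes)
```

Here `restore_dtypes` is applied to `converted` only to show the round trip. In real use you would pass the sampled synthetic data instead.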