sdv-dev / SDV

Synthetic data generation for tabular data
https://docs.sdv.dev/sdv
Other

Add ability to use the `gaussian_kde` in HMA #1602

Closed: JanJacekJaniszewski closed this issue 2 months ago.

JanJacekJaniszewski commented 9 months ago


Error Description

Calling `set_table_parameters` on an HMASynthesizer causes a `KeyError` on a `num_rows` column (here `'__depts__Store__num_rows'`) when fitting the model.

Steps to reproduce

Input

from sdv.multi_table import HMASynthesizer
from sdv.datasets.demo import download_demo

real_data, real_metadata = download_demo(
    modality='multi_table',
    dataset_name='walmart'
)

synthesizer = HMASynthesizer(real_metadata)

# The synthesizer runs perfectly fine when this is not included
synthesizer.set_table_parameters(
    table_name='depts',
    table_parameters={
        'enforce_min_max_values': True,
        'enforce_rounding': True,
        'default_distribution': 'beta',
        'numerical_distributions': {
            'Weekly_Sales': 'gaussian_kde'
        }
    }
)
################################

synthesizer.fit(real_data)

Output

Preprocess Tables: 100%|██████████| 3/3 [00:03<00:00,  1.03s/it]

Learning relationships:
(1/2) Tables 'stores' and 'features' ('Store'): 100%|██████████| 45/45 [00:11<00:00,  3.91it/s]
(2/2) Tables 'stores' and 'depts' ('Store'): 100%|██████████| 45/45 [01:12<00:00,  1.61s/it]
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/McK_Internal/ai-powered-dataops/venv/lib/python3.9/site-packages/pandas/core/indexes/base.py:3653, in Index.get_loc(self, key)
   3652 try:
-> 3653     return self._engine.get_loc(casted_key)
   3654 except KeyError as err:

File ~/McK_Internal/ai-powered-dataops/venv/lib/python3.9/site-packages/pandas/_libs/index.pyx:147, in pandas._libs.index.IndexEngine.get_loc()

File ~/McK_Internal/ai-powered-dataops/venv/lib/python3.9/site-packages/pandas/_libs/index.pyx:176, in pandas._libs.index.IndexEngine.get_loc()

File pandas/_libs/hashtable_class_helper.pxi:7080, in pandas._libs.hashtable.PyObjectHashTable.get_item()

File pandas/_libs/hashtable_class_helper.pxi:7088, in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: '__depts__Store__num_rows'

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
Cell In[8], line 26
     14 synthesizer.set_table_parameters(
     15     table_name='depts',
     16     table_parameters={
   (...)
     22     }
     23 )
     24 ################################
---> 26 synthesizer.fit(real_data) # Add your data here (perhaps you have to load it first)
     28 # Generate a sample now (scale is the multiplier for how many samples should be generated; i.e., len(synthetic_data) = scale * len(YOURDATA)
     29 #synthetic_data = synthesizer.sample(scale=1)

File ~/McK_Internal/ai-powered-dataops/venv/lib/python3.9/site-packages/sdv/multi_table/base.py:330, in BaseMultiTableSynthesizer.fit(self, data)
    328 processed_data = self.preprocess(data)
    329 self._print(text='\n', end='')
--> 330 self.fit_processed_data(processed_data)

File ~/McK_Internal/ai-powered-dataops/venv/lib/python3.9/site-packages/sdv/multi_table/base.py:313, in BaseMultiTableSynthesizer.fit_processed_data(self, processed_data)
    306 def fit_processed_data(self, processed_data):
    307     """Fit this model to the transformed data.
    308 
    309     Args:
    310         processed_data (dict):
    311             Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
    312     """
--> 313     augmented_data = self._augment_tables(processed_data)
    314     self._model_tables(augmented_data)
    315     self._fitted = True

File ~/McK_Internal/ai-powered-dataops/venv/lib/python3.9/site-packages/sdv/multi_table/hma.py:236, in HMASynthesizer._augment_tables(self, processed_data)
    234 for table_name in processed_data:
    235     if not parent_map.get(table_name):
--> 236         self._augment_table(augmented_data[table_name], augmented_data, table_name)
    238 LOGGER.info('Augmentation Complete')
    239 return augmented_data

File ~/McK_Internal/ai-powered-dataops/venv/lib/python3.9/site-packages/sdv/multi_table/hma.py:166, in HMASynthesizer._augment_table(self, table, tables, table_name)
    164 table = table.merge(extension, how='left', right_index=True, left_index=True)
    165 num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
--> 166 table[num_rows_key] = table[num_rows_key].fillna(0)
    167 self._max_child_rows[num_rows_key] = table[num_rows_key].max()
    168 tables[table_name] = table

File ~/McK_Internal/ai-powered-dataops/venv/lib/python3.9/site-packages/pandas/core/frame.py:3761, in DataFrame.__getitem__(self, key)
   3759 if self.columns.nlevels > 1:
   3760     return self._getitem_multilevel(key)
-> 3761 indexer = self.columns.get_loc(key)
   3762 if is_integer(indexer):
   3763     indexer = [indexer]

File ~/McK_Internal/ai-powered-dataops/venv/lib/python3.9/site-packages/pandas/core/indexes/base.py:3655, in Index.get_loc(self, key)
   3653     return self._engine.get_loc(casted_key)
   3654 except KeyError as err:
-> 3655     raise KeyError(key) from err
   3656 except TypeError:
   3657     # If we have a listlike key, _check_indexing_error will raise
   3658     #  InvalidIndexError. Otherwise we fall through and re-raise
   3659     #  the TypeError.
   3660     self._check_indexing_error(key)

KeyError: '__depts__Store__num_rows'
npatki commented 9 months ago

Hi @JanJacekJaniszewski, thanks for filing this issue with the detailed code and stack trace.

The 'gaussian_kde' distribution is not actually supported by the HMASynthesizer due to an algorithmic incompatibility. (The HMA algorithm is designed to work only with parametric distributions that have a predetermined, fixed number of parameters.)
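Why a fixed number of parameters matters can be sketched with NumPy and SciPy (neither appears in the original code). The traceback's `_augment_table` step suggests that HMA summarizes each parent row's child rows as extra columns on the parent table; assuming that, a parametric fit yields the same number of values for every parent, while a Gaussian KDE is defined by all of its samples, so its size varies with the number of child rows. The store names and values below are hypothetical, and this illustrates the idea only; it is not SDV's implementation.

```python
import numpy as np
from scipy import stats

# Hypothetical child-table values grouped by parent key (e.g. one group per store).
# Group sizes differ, as they typically do in real parent/child tables.
rng = np.random.default_rng(0)
child_groups = {
    "store_1": rng.normal(100.0, 10.0, size=5),
    "store_2": rng.normal(200.0, 20.0, size=12),
    "store_3": rng.normal(150.0, 15.0, size=30),
}

def parametric_summary(values):
    """Fit a normal distribution: always exactly 2 numbers (loc, scale)."""
    loc, scale = stats.norm.fit(values)
    return [loc, scale]

def kde_summary(values):
    """A Gaussian KDE keeps every sample it was built from, so the amount
    of data needed to describe it grows with the number of child rows."""
    kde = stats.gaussian_kde(values)
    return list(kde.dataset.ravel())

param_lengths = {k: len(parametric_summary(v)) for k, v in child_groups.items()}
kde_lengths = {k: len(kde_summary(v)) for k, v in child_groups.items()}

print(param_lengths)  # {'store_1': 2, 'store_2': 2, 'store_3': 2}
print(kde_lengths)    # {'store_1': 5, 'store_2': 12, 'store_3': 30}
```

With the parametric summary, every parent row can be given the same fixed set of columns; the KDE summary would need a different number of columns per parent row.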

We've just updated the HMA docs with this clarification.


I suggest using one of the other distributions instead, such as 'beta' or 'norm'.
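A minimal sketch of that workaround, applied to the parameters from the original report: the helper `replace_kde_distributions` is hypothetical (not part of SDV) and simply swaps every 'gaussian_kde' entry for a parametric fallback before the dictionary is passed to `set_table_parameters`. The SDV call itself is left as a comment so the snippet runs without SDV installed.

```python
import copy

def replace_kde_distributions(table_parameters, fallback="beta"):
    """Hypothetical helper (not part of SDV): replace 'gaussian_kde' with a
    parametric distribution such as 'beta' or 'norm'."""
    params = copy.deepcopy(table_parameters)  # leave the caller's dict untouched
    if params.get("default_distribution") == "gaussian_kde":
        params["default_distribution"] = fallback
    numerical = params.get("numerical_distributions", {})
    for column, distribution in numerical.items():
        if distribution == "gaussian_kde":
            numerical[column] = fallback
    return params

# The table parameters from the original report.
depts_parameters = {
    "enforce_min_max_values": True,
    "enforce_rounding": True,
    "default_distribution": "beta",
    "numerical_distributions": {"Weekly_Sales": "gaussian_kde"},
}

fixed = replace_kde_distributions(depts_parameters, fallback="norm")
print(fixed["numerical_distributions"])  # {'Weekly_Sales': 'norm'}

# With SDV installed, pass the result to the same call used in the report:
# synthesizer.set_table_parameters(table_name='depts', table_parameters=fixed)
```

Deep-copying keeps the original dictionary intact, which makes it easy to compare a KDE-based single-table setup with the parametric fallback used for HMA.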

Next Steps

  1. The error message you're seeing does not point to the actual root cause. We can file a new issue to raise a more descriptive error message when `HMASynthesizer.set_table_parameters` is called with a Gaussian KDE.
  2. I suggest we repurpose this issue as a feature request for supporting Gaussian KDE with HMA. To help us prioritize, could you share your use case? Are you trying to improve the synthetic data in some way, and how would this help your overall project?
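The kind of early check described in step 1 could look roughly like the sketch below. Everything here (`find_unsupported_columns`, `validate_hma_table_parameters`, `InvalidHMAParameterError`) is hypothetical and is not SDV's actual implementation; the point is to fail at parameter-setting time with a message that names the root cause, instead of a later `KeyError` on a `__num_rows` column during `fit`.

```python
UNSUPPORTED_FOR_HMA = {"gaussian_kde"}

class InvalidHMAParameterError(ValueError):
    """Hypothetical error for distributions HMA cannot model."""

def find_unsupported_columns(table_parameters):
    """Return the sorted columns (or the default) that request an unsupported distribution."""
    requested = dict(table_parameters.get("numerical_distributions", {}))
    default = table_parameters.get("default_distribution")
    if default is not None:
        requested["<default_distribution>"] = default
    return sorted(col for col, dist in requested.items() if dist in UNSUPPORTED_FOR_HMA)

def validate_hma_table_parameters(table_name, table_parameters):
    """Raise a descriptive error before fitting, rather than a KeyError during fit."""
    bad = find_unsupported_columns(table_parameters)
    if bad:
        raise InvalidHMAParameterError(
            f"Table '{table_name}': 'gaussian_kde' is not supported by HMASynthesizer "
            f"(columns: {', '.join(bad)}). Use a parametric distribution such as "
            "'beta' or 'norm'."
        )

try:
    validate_hma_table_parameters(
        "depts", {"numerical_distributions": {"Weekly_Sales": "gaussian_kde"}}
    )
except InvalidHMAParameterError as err:
    message = str(err)

print(message)
```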
npatki commented 9 months ago

Update: I am repurposing this issue as a feature request.

The linked related issue has now been closed. In the upcoming SDV release, we will provide a better error message when 'gaussian_kde' is used with the HMASynthesizer.

npatki commented 2 months ago

An update on this issue: due to the nature of the HMA algorithm, we will not be able to support using `gaussian_kde` for certain columns of the HMASynthesizer. Supporting it would also increase compute time, which can already be high for certain schemas in HMA.

Instead, we recommend using the HSASynthesizer. The HSA algorithm can handle complex schemas as well as non-parametric KDE distributions for individual columns.

Note that this synthesizer is only available with paid SDV plans. To learn more, you can visit our support page. Thanks.