sdv-dev / SDV

Synthetic data generation for tabular data
https://docs.sdv.dev/sdv

HMA Synthesizer's `scale` parameter doesn't work for small values #2045

Closed srinify closed 3 months ago

srinify commented 4 months ago

Error Description

Sampling from HMA Synthesizer with a small scale value can result in an error. This seems to happen when scale * the root table's row count results in a fractional number of rows (e.g. 0.01 * 10 rows = 0.1 rows requested).
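
To make the arithmetic concrete, here is a minimal sketch. It assumes the requested row count is simply scale multiplied by the root table's row count. `requested_rows` is a hypothetical helper for illustration, not part of SDV, and this issue does not show how SDV converts the result to a whole number.

```python
def requested_rows(scale, root_rows):
    """Raw (possibly fractional) number of root rows implied by `scale`.

    Hypothetical helper: it only mirrors the arithmetic described above.
    """
    return scale * root_rows


if __name__ == "__main__":
    # 10 rows matches the example in the description above.
    for scale in (1.0, 0.5, 0.01):
        raw = requested_rows(scale, 10)
        # int() truncation is used only for illustration here; SDV's own
        # conversion to a whole number is not shown in this issue.
        print(f"scale={scale}: raw={raw}, whole rows={int(raw)}")
```

Any scale that puts the raw value below 1 leaves no whole row to sample from the root table.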

Workaround

Until this is fixed, we recommend increasing the scale parameter to a higher value.

For example, if you have 100 rows in your parent table, we recommend using a scale value greater than 0.01 (so you get at least 1 row back).
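
One way to pick a threshold for your own data is sketched below. `smallest_safe_scale` is a hypothetical helper, not an SDV function, and the table sizes are made-up examples. It assumes you know the row counts of your root (parent) tables and that one requested row per root table is enough to avoid the error.

```python
def smallest_safe_scale(root_table_sizes):
    """Return the scale at which the smallest root table still yields 1 row.

    `root_table_sizes` maps root table names to their row counts.
    Hypothetical helper for illustration only.
    """
    smallest = min(root_table_sizes.values())
    if smallest <= 0:
        raise ValueError("Root tables must contain at least one row.")
    return 1 / smallest


if __name__ == "__main__":
    # Made-up sizes: the smallest root table has 100 rows.
    threshold = smallest_safe_scale({"parent_a": 100, "parent_b": 2500})
    print(f"Choose a scale above {threshold}")
```

Choosing a scale above the returned value keeps every root table at one row or more under these assumptions.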

Proposed Solution

  1. Neha's proposal in the original issue was to set the minimum size of root tables used for sampling to 1 row. So even if scale is very low and the resulting requested row count is under 1, the user will still receive 1 row from the root (parent) table.

  2. Additionally, she pointed out that cardinality won't be accurate in many cases if scale is this low. So we should also show a warning and encourage the user to increase the scale parameter:

Warning: The 'scale' parameter is too small. Some tables may have only 0 or 1 rows. For better quality data,
please choose a larger scale.
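
The two points above could be sketched roughly as follows. This is not the actual SDV change: `root_rows_to_sample` and `SCALE_WARNING` are hypothetical names, and rounding larger values with `round()` is an assumption made for the sketch.

```python
import warnings

# Warning text taken from the proposal above.
SCALE_WARNING = (
    "The 'scale' parameter is too small. Some tables may have only 0 or 1 rows. "
    "For better quality data, please choose a larger scale."
)


def root_rows_to_sample(table_size, scale, min_rows=1):
    """Number of rows to sample from a root table (hypothetical sketch).

    If scale * table_size falls below `min_rows`, warn the user and return
    `min_rows` instead, so the root table is never empty.
    """
    requested = table_size * scale
    if requested < min_rows:
        warnings.warn(SCALE_WARNING)
        return min_rows
    # Assumption: larger requests are rounded to the nearest whole row.
    return round(requested)


if __name__ == "__main__":
    print(root_rows_to_sample(10, 0.01))  # clamped to the 1-row minimum, with a warning
    print(root_rows_to_sample(10, 0.5))   # large enough, no warning
```

Clamping only the root tables keeps the change small: child tables are sampled from the parent rows, so a non-empty parent gives them something to attach to, while the warning signals that cardinality may still be off.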

Steps to reproduce

from sdv.datasets.demo import download_demo
from sdv.multi_table import HMASynthesizer

data, metadata = download_demo(
    modality='multi_table',
    dataset_name='fake_hotels'
)

synthesizer = HMASynthesizer(metadata)
synthesizer.fit(data)

synthetic_data = synthesizer.sample(scale=0.01)

Running this code throws a KeyError:

KeyError                                  Traceback (most recent call last)
Cell In[1], line 12
      9 synthesizer = HMASynthesizer(metadata)
     10 synthesizer.fit(data)
---> 12 synthetic_data = synthesizer.sample(scale=0.01)

File ~/.pyenv/versions/3.10.0/envs/sdv_latest/lib/python3.10/site-packages/sdv/multi_table/base.py:486, in BaseMultiTableSynthesizer.sample(self, scale)
    482     raise SynthesizerInputError(
    483         f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0.")
    485 with self._set_temp_numpy_seed(), disable_single_table_logger():
--> 486     sampled_data = self._sample(scale=scale)
    488 total_rows = 0
    489 total_columns = 0

File ~/.pyenv/versions/3.10.0/envs/sdv_latest/lib/python3.10/site-packages/sdv/sampling/hierarchical_sampler.py:281, in BaseHierarchicalSampler._sample(self, scale)
    279     LOGGER.info(f'Sampling {num_rows} rows from table {table}')
    280     sampled_data[table] = self._sample_rows(synthesizer, num_rows)
--> 281     self._sample_children(table_name=table, sampled_data=sampled_data, scale=scale)
    283 added_relationships = set()
    284 for relationship in self.metadata.relationships:

File ~/.pyenv/versions/3.10.0/envs/sdv_latest/lib/python3.10/site-packages/sdv/sampling/hierarchical_sampler.py:192, in BaseHierarchicalSampler._sample_children(self, table_name, sampled_data, scale)
    179 """Recursively sample the children of a table.
    180
    181 This method will loop through the children of a table and sample rows for that child for
   (...)
    189         A dictionary mapping table names to sampled tables (pd.DataFrame).
    190 """
    191 for child_name in self.metadata._get_child_map()[table_name]:
--> 192     self._enforce_table_size(child_name, table_name, scale, sampled_data)
    194     if child_name not in sampled_data:  # Sample based on only 1 parent
    195         for _, row in sampled_data[table_name].iterrows():

File ~/.pyenv/versions/3.10.0/envs/sdv_latest/lib/python3.10/site-packages/sdv/sampling/hierarchical_sampler.py:146, in BaseHierarchicalSampler._enforce_table_size(self, child_name, table_name, scale, sampled_data)
    144 min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key]
    145 max_rows = self._max_child_rows[num_rows_key]
--> 146 key_data = sampled_data[table_name][num_rows_key].fillna(0).round()
    147 sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows).astype(int)
    149 while sum(sampled_data[table_name][num_rows_key]) != total_num_rows:

File ~/.pyenv/versions/3.10.0/envs/sdv_latest/lib/python3.10/site-packages/pandas/core/frame.py:4102, in DataFrame.__getitem__(self, key)
   4100 if self.columns.nlevels > 1:
   4101     return self._getitem_multilevel(key)
-> 4102 indexer = self.columns.get_loc(key)
   4103 if is_integer(indexer):
   4104     indexer = [indexer]

File ~/.pyenv/versions/3.10.0/envs/sdv_latest/lib/python3.10/site-packages/pandas/core/indexes/range.py:417, in RangeIndex.get_loc(self, key)
    415         raise KeyError(key) from err
    416 if isinstance(key, Hashable):
--> 417     raise KeyError(key)
    418 self._check_indexing_error(key)
    419 raise KeyError(key)

KeyError: '__guests__hotel_id__num_rows'