sdv-dev / SDV

Synthetic data generation for tabular data
https://docs.sdv.dev/sdv

Error when applying `FixedCombinations` constraint on a child table with multiple parents in `HMASynthesizer` #2087

Status: Closed (by pvk-developer, 1 week ago)

pvk-developer commented 1 week ago

Jun 21: This bug was first found and described on Slack by Sam Wachtel -- thanks!


Error Description

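When the `HMASynthesizer` is used with a `FixedCombinations` constraint on a child table that has multiple parent tables, the synthesizer fails while assigning parent IDs to the child rows. In `_get_likelihoods` it tries to drop every transformed column from the child rows, including the joint column created by the constraint (`department#office`). That column does not exist in the child rows, so the call raises the following error: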

File ~/Projects/sdv-dev/SDV/sdv/multi_table/hma.py:660, in HMASynthesizer._find_parent_ids(self, child_table, parent_table, child_name, parent_name, foreign_key)
    657 parent_table = parent_table.set_index(primary_key)
    658 num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].copy()
--> 660 likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key)
    661 return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows)

File ~/Projects/sdv-dev/SDV/sdv/multi_table/hma.py:619, in HMASynthesizer._get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key)
    616 if transformed.index.name:
    617     table_rows = table_rows.set_index(transformed.index.name)
--> 619 table_rows = pd.concat([transformed, table_rows.drop(columns=transformed.columns)], axis=1)
    620 for parent_id, row in parent_rows.iterrows():
    621     parameters = self._extract_parameters(row, table_name, foreign_key)

....
KeyError: "['department#office'] not found in axis"
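The failing line can be reproduced with pandas alone. The sketch below uses small stand-in frames (assumed shapes, with values taken from the reproduction data): `table_rows` holds the original child columns, while `transformed` holds the joint `department#office` column in place of `department` and `office`. `DataFrame.drop` raises `KeyError` for absent labels unless `errors='ignore'` is passed. The tolerant variant only illustrates the mechanism; it is a hypothetical guard, not necessarily how the issue was fixed in SDV.

```python
import pandas as pd

# Stand-in for the child rows: the original, un-transformed columns.
table_rows = pd.DataFrame({
    'record_id': ['record_a', 'record_b'],
    'score': [85, 92],
    'department': ['HR', 'IT'],
    'office': ['Boston HQ', 'NYC Office'],
})

# Stand-in for the transformed rows: the constraint has replaced
# 'department' and 'office' with a single joint column.
transformed = pd.DataFrame({
    'score': [85, 92],
    'department#office': ['HR#Boston HQ', 'IT#NYC Office'],
})


def concat_strict(transformed, table_rows):
    # Mirrors the failing line: drop every transformed column from the raw rows.
    return pd.concat([transformed, table_rows.drop(columns=transformed.columns)], axis=1)


def concat_tolerant(transformed, table_rows):
    # Hypothetical guard: silently skip transformed columns the raw rows lack.
    return pd.concat(
        [transformed, table_rows.drop(columns=transformed.columns, errors='ignore')],
        axis=1,
    )


try:
    concat_strict(transformed, table_rows)
except KeyError as error:
    # The message names the missing joint column, as in the traceback.
    print(error)

result = concat_tolerant(transformed, table_rows)
print(list(result.columns))
```

Whether keeping both the joint column and the original columns is correct for the likelihood computation is a separate question; the sketch only shows why the strict `drop` fails.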

Steps to reproduce

import pandas as pd

# Creating a dataset
data = {
    'users': pd.DataFrame({
        'user_id': [1, 2, 3],
        'name': ['Alice', 'Bob', 'John'],
    }),
    'records': pd.DataFrame({
        'record_id': ['record_a', 'record_b', 'record_c', 'record_d'],
        'user_id': [1, 2, 2, 1],
        'score': [85, 92, 78, 88],
        'location_id': ['A', 'B', 'C', 'D'],
        'department': ['HR', 'IT', 'HR', 'Finance'],
        'office': ['Boston HQ', 'NYC Office', 'LA Office', 'Chicago HQ']
    }),
    'locations': pd.DataFrame({
        'location_id': ['A', 'B', 'C', 'D'],
        'city': ['Boston', 'New York', 'Los Angeles', 'Chicago'],
        'country': ['USA', 'USA', 'USA', 'USA']
    })
}

# Create metadata for the dataset
from sdv.metadata import MultiTableMetadata
metadata = MultiTableMetadata()
metadata.detect_from_dataframes(data)

metadata.update_column('users', 'user_id', sdtype='id')
metadata.update_column('records', 'record_id', sdtype='id')
metadata.update_column('records', 'user_id', sdtype='id')
metadata.update_column('records', 'location_id', sdtype='id')
metadata.update_column('locations', 'location_id', sdtype='id')
metadata.set_primary_key('users', 'user_id')
metadata.set_primary_key('locations', 'location_id')
metadata.add_relationship('users', 'records', 'user_id', 'user_id')
metadata.add_relationship('locations', 'records', 'location_id', 'location_id')

# Creating HMASynthesizer
from sdv.multi_table import HMASynthesizer
synthesizer = HMASynthesizer(metadata)

synthesizer.add_constraints(constraints=[
    {
        'constraint_class': 'FixedCombinations',
        'table_name': 'records',
        'constraint_parameters': {
            'column_names': ['department', 'office']
        }
    }
])

synthesizer.fit(data)
synthesizer.sample(1)  # fails with KeyError: "['department#office'] not found in axis"
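The column named in the `KeyError`, `department#office`, looks like the two constrained columns joined with `#`. The following simplified sketch assumes the constraint swaps the listed columns for one joint column on transform and restores them on reverse transform. `combine_columns` and `split_columns` are hypothetical helpers, not SDV's API, and SDV's actual encoding of the joint values may differ; only the column-name format is taken from the error message. It shows why the joint column exists in the transformed table but not in the raw child rows.

```python
import pandas as pd


def combine_columns(df, columns, separator='#'):
    """Hypothetical helper: replace `columns` with one joint column.

    Each distinct combination seen in `df` gets an integer code, so a model
    trained on the joint column can only emit combinations that exist.
    """
    joint_name = separator.join(columns)
    combos = df[columns].drop_duplicates().reset_index(drop=True)
    code_of = {row: code for code, row in enumerate(combos.itertuples(index=False, name=None))}
    out = df.drop(columns=columns)
    out[joint_name] = [code_of[row] for row in df[columns].itertuples(index=False, name=None)]
    return out, joint_name, combos


def split_columns(df, joint_name, combos):
    """Hypothetical helper: map codes back to the original columns."""
    restored = combos.iloc[df[joint_name].to_numpy()].reset_index(drop=True)
    restored.index = df.index
    return pd.concat([df.drop(columns=[joint_name]), restored], axis=1)


records = pd.DataFrame({
    'record_id': ['record_a', 'record_b', 'record_c', 'record_d'],
    'department': ['HR', 'IT', 'HR', 'Finance'],
    'office': ['Boston HQ', 'NYC Office', 'LA Office', 'Chicago HQ'],
})

encoded, joint, combos = combine_columns(records, ['department', 'office'])
print(list(encoded.columns))  # ['record_id', 'department#office']

decoded = split_columns(encoded, joint, combos)
print(list(decoded.columns))  # ['record_id', 'department', 'office']
```

Mapping combinations to codes rather than splitting strings on `#` avoids ambiguity when the values themselves contain the separator.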