sdv-dev / SDV

Synthetic data generation for tabular data
https://docs.sdv.dev/sdv

HMA sampling crashes when unknown sdtype detected for numerical column #2064

Closed amontanez24 closed 5 hours ago

amontanez24 commented 2 weeks ago

Error Description

If a numerical column is detected as unknown in the metadata, sampling fails, apparently because the value generated for that column is a string with the prefix sdv-pii-, which cannot be cast to an int.
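The failing cast can be reproduced in isolation with pandas alone. The sketch below is not SDV's internal code. It only assumes that the sampled column holds strings of the form sdv-pii-... (the value is copied from the traceback in this report) and that the column is then cast back to an integer dtype.

```python
import pandas as pd

# Value copied from the traceback in this report. The object-dtype Series
# is an assumed stand-in for what the sampled column holds.
sampled = pd.Series(['sdv-pii-cs7y0'], dtype=object)

try:
    # Casting the string back to an integer dtype is the step that fails.
    sampled.astype('int64')
    cast_error = None
except ValueError as error:
    cast_error = error

print(type(cast_error).__name__, cast_error)
```

The printed exception is the same kind of ValueError that appears at the bottom of the traceback.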

Steps to reproduce

import pandas as pd
import faker
import numpy as np

fake = faker.Faker()

table1 = pd.DataFrame({
    'name': [fake.name() for i in range(20)],
    'salary': np.random.randint(20_000, 250_000, 20),
    'age': np.random.randint(18, 70, 20),
    'address': [fake.address() for i in range(20)]
})
table2 = pd.DataFrame({
    'company': [fake.company() for i in range(20)],
    'employee_count': np.random.randint(15, 4000, 20),
    'revenue': np.random.randint(100_000, 4_000_000_000, 20)  # size added so each row gets its own value
})

tables_dict = {'people': table1, 'company': table2}

from sdv.metadata import MultiTableMetadata

metadata = MultiTableMetadata()
metadata.detect_from_dataframes(tables_dict)

from sdv.multi_table import HMASynthesizer

synth = HMASynthesizer(metadata)
synth.fit(tables_dict)
synth.sample(1)

The code above produces the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-4-f0bab0e0379c> in <cell line: 5>()
      3 synth = HMASynthesizer(metadata)
      4 synth.fit(tables_dict)
----> 5 synth.sample(1)

9 frames
/usr/local/lib/python3.10/dist-packages/pandas/core/dtypes/astype.py in _astype_nansafe(arr, dtype, copy, skipna)
    136     if copy or is_object_dtype(arr.dtype) or is_object_dtype(dtype):
    137         # Explicit copy, or required since NumPy can't view from / to object.
--> 138         return arr.astype(dtype, copy=True)
    139 
    140     return arr.astype(dtype, copy=copy)

ValueError: invalid literal for int() with base 10: 'sdv-pii-cs7y0'
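One way to spot affected columns before fitting is to compare the detected sdtypes against the actual pandas dtypes. The helper below is a hypothetical sketch, not part of SDV. It works on a plain dictionary and assumes the layout {'tables': {table: {'columns': {column: {'sdtype': ...}}}}}, which is believed to match what metadata.to_dict() returns. The example metadata and data are hand-written for illustration.

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype


def find_numeric_unknown_columns(metadata_dict, tables):
    """Return (table, column) pairs marked 'unknown' whose data is numeric.

    Hypothetical helper: the dictionary layout is an assumption.
    """
    flagged = []
    for table_name, table_meta in metadata_dict.get('tables', {}).items():
        data = tables.get(table_name)
        if data is None:
            continue
        for column_name, column_meta in table_meta.get('columns', {}).items():
            if column_meta.get('sdtype') != 'unknown':
                continue
            if column_name in data.columns and is_numeric_dtype(data[column_name]):
                flagged.append((table_name, column_name))
    return flagged


# Hand-written stand-in for detected metadata (illustrative values only).
example_metadata = {
    'tables': {
        'people': {
            'columns': {
                'name': {'sdtype': 'unknown', 'pii': True},
                'salary': {'sdtype': 'unknown', 'pii': True},
                'age': {'sdtype': 'numerical'},
            }
        }
    }
}
example_tables = {
    'people': pd.DataFrame({
        'name': ['a', 'b', 'c'],
        'salary': [50_000, 61_000, 72_000],
        'age': [30, 41, 52],
    })
}

flagged = find_numeric_unknown_columns(example_metadata, example_tables)
print(flagged)
```

With a real MultiTableMetadata object, each flagged column could presumably be set to numerical with metadata.update_column(table_name=..., column_name=..., sdtype='numerical') before fitting. This report does not confirm whether that avoids the crash.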