sdv-dev / SDV

Synthetic data generation for tabular data
https://docs.sdv.dev/sdv

In `PARSynthesizer` I cannot pass in datetime context (`InvalidDataError` during fitting) #1485

Closed. npatki closed this issue 2 months ago.

npatki commented 1 year ago

I'm filing this issue on behalf of a Slack user.

Environment Details

Error Description

I am trying to use PARSynthesizer on multi-sequence data that has a datetime context column (birthdate). Training succeeds, but sampling synthetic data raises an error.

Steps to reproduce

import pandas as pd
from sdv.metadata import SingleTableMetadata
from sdv.sequential import PARSynthesizer

# create 2 fake sequences for different users
# the 'birthdate' column is a context column because it does not change
data = pd.DataFrame(data={
    'user_id': ['ID_00']*5 + ['ID_01']*5,
    'birthdate': ['1995-05-06']*5 + ['1982-01-21']*5,
    'timestamp': ['2023-06-21', '2023-06-22', '2023-06-23', '2023-06-24', '2023-06-25']*2,
    'heartrate': [67, 66, 68, 65, 64, 80, 82, 91, 88, 84]
})

metadata = SingleTableMetadata.load_from_dict({
    'columns': {
        'user_id': { 'sdtype': 'id', 'regex_format': 'ID_[0-9]{2}' },
        'birthdate': { 'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d' },
        'timestamp': { 'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d' },
        'heartrate': { 'sdtype': 'numerical' }
    },
    'sequence_key': 'user_id',
    'sequence_index': 'timestamp'
})

synth = PARSynthesizer(
    metadata,
    epochs=50,
    verbose=True,
    context_columns=['birthdate'])

synth.fit(data)
synth.sample(num_sequences=2)

Output

TypeError: Cannot cast DatetimeArray to dtype float64

Full Stack Trace

stack_trace.txt
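The comment in the reproduction code says `birthdate` is a context column because it does not change within a sequence. That property can be checked with plain pandas before fitting. The sketch below assumes only pandas; `is_constant_per_sequence` is a hypothetical helper, not part of the SDV API.

```python
import pandas as pd

def is_constant_per_sequence(data, sequence_key, column):
    """Hypothetical helper (not SDV API): True if `column` has at most one value per sequence."""
    # nunique() counts the distinct values of `column` inside each sequence
    distinct_per_sequence = data.groupby(sequence_key)[column].nunique()
    return bool((distinct_per_sequence <= 1).all())

# Same data as in the reproduction code above
data = pd.DataFrame(data={
    'user_id': ['ID_00'] * 5 + ['ID_01'] * 5,
    'birthdate': ['1995-05-06'] * 5 + ['1982-01-21'] * 5,
    'timestamp': ['2023-06-21', '2023-06-22', '2023-06-23', '2023-06-24', '2023-06-25'] * 2,
    'heartrate': [67, 66, 68, 65, 64, 80, 82, 91, 88, 84],
})

print(is_constant_per_sequence(data, 'user_id', 'birthdate'))  # True
print(is_constant_per_sequence(data, 'user_id', 'heartrate'))  # False
```

Here `birthdate` passes and `heartrate` does not, which matches how the columns are used in the report: context values describe a whole sequence, while `heartrate` varies along the sequence index.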

npatki commented 1 year ago

Workaround

I can convert the datetime column to a Unix timestamp in nanoseconds (an integer). The column then becomes numerical, and PARSynthesizer accepts it as a context column.

# Input the name of your datetime context column
COLUMN_NAME = 'birthdate'

# create a copy of the data and convert the column into a Unix timestamp in nanoseconds (int64)
data_copy = data.copy()
data_copy[COLUMN_NAME] = pd.to_datetime(data_copy[COLUMN_NAME], format='%Y-%m-%d').astype('datetime64[ns]').astype('int64')

# now make sure the column is specified numerical instead of datetime in the metadata
metadata.update_column(
    column_name=COLUMN_NAME,
    sdtype='numerical')

# now it works
synthesizer = PARSynthesizer(
    metadata,
    context_columns=[COLUMN_NAME])
synthesizer.fit(data_copy)
synthetic_data = synthesizer.sample(num_sequences=2)

# be sure to cast the column back to the desired format!
synthetic_data[COLUMN_NAME] = pd.to_datetime(synthetic_data[COLUMN_NAME], unit='ns').dt.date
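The workaround is a round trip: encode the dates as numbers before fitting, then decode the sampled numbers afterwards. A minimal pandas sketch that wraps both directions is below. `encode_datetime_column` and `decode_datetime_column` are hypothetical helper names, and the default `'%Y-%m-%d'` format is an assumption taken from the metadata in the reproduction.

```python
import pandas as pd

def encode_datetime_column(data, column, datetime_format='%Y-%m-%d'):
    """Hypothetical helper: return a copy with `column` as int64 nanoseconds since the Unix epoch."""
    encoded = data.copy()
    parsed = pd.to_datetime(encoded[column], format=datetime_format)
    # Pin the unit to nanoseconds so the integers always mean the same thing
    encoded[column] = parsed.astype('datetime64[ns]').astype('int64')
    return encoded

def decode_datetime_column(data, column, datetime_format='%Y-%m-%d'):
    """Hypothetical helper: turn nanosecond numbers back into formatted datetime strings."""
    decoded = data.copy()
    # Sampled values may come back as floats, so round before casting to int64
    nanoseconds = decoded[column].round().astype('int64')
    decoded[column] = pd.to_datetime(nanoseconds, unit='ns').dt.strftime(datetime_format)
    return decoded

original = pd.DataFrame({'birthdate': ['1995-05-06', '1982-01-21']})
encoded = encode_datetime_column(original, 'birthdate')
print(encoded['birthdate'].tolist())  # [799718400000000000, 380419200000000000]
decoded = decode_datetime_column(encoded, 'birthdate')
print(decoded['birthdate'].tolist())  # ['1995-05-06', '1982-01-21']
```

Unlike the `.dt.date` cast in the workaround, this sketch returns strings in the original `datetime_format`, so the synthetic column has the same shape as the real one. It does not handle missing values; a sampled `NaN` would make the `int64` cast fail.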

Ng-ms commented 8 months ago

Hi, I am facing a similar problem when passing a datetime as a context column. This is the error message I get:

raise InvalidDataError(errors)
sdv.errors.InvalidDataError: The provided data does not match the metadata: Invalid values found for datetime column 'dateofpred': [1.3263264e+18, 1.3372128e+18, 1.3721184e+18, '+ 8 more'].

When I change the column's sdtype to numerical as suggested, I get numerical values instead of dates in the synthetic data.

npatki commented 8 months ago

Hi @Ng-ms, I don't think this problem is related to the overall topic of this issue (passing datetime columns into context during sampling). Would you mind filing a new issue so that we can take a look?

You can use this link. It would be helpful if you could fill out the requested information so that we can replicate the problem and find the root cause.

Thanks

npatki commented 5 months ago

Note that as of SDV 1.12.0, this issue is still present, but the error has changed. I'll update the title to clarify.

Pasting the new error and stack trace below (using the same code as in the first comment of the issue).

Error: This doesn't make sense, because the birthdate column contains datetime strings such as '1995-05-06', not numerical timestamps.

InvalidDataError: The provided data does not match the metadata:
Invalid values found for datetime column 'birthdate': [3.804192e+17, 7.997184e+17].

Stack trace (see below).

stack_trace.txt
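The two numbers in the new error message line up exactly with the two `birthdate` values expressed as nanoseconds since the Unix epoch. The standard-library check below converts both dates (assuming UTC midnight) and reproduces the values in the message, which hints, but does not prove, that the column has already been turned into numbers by the time the datetime validation runs. `to_epoch_nanoseconds` is a hypothetical helper written only for this check.

```python
from datetime import datetime, timezone

def to_epoch_nanoseconds(date_string, fmt='%Y-%m-%d'):
    """Convert a date string (interpreted as UTC midnight) to nanoseconds since the Unix epoch."""
    moment = datetime.strptime(date_string, fmt).replace(tzinfo=timezone.utc)
    # Midnight dates have whole-second timestamps, so int() loses nothing here
    return int(moment.timestamp()) * 1_000_000_000

for birthdate in ['1982-01-21', '1995-05-06']:
    print(birthdate, float(to_epoch_nanoseconds(birthdate)))
# 1982-01-21 3.804192e+17
# 1995-05-06 7.997184e+17
```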

Additional Context

Interestingly, both of the commands below pass without issue.

metadata.validate()
metadata.validate_data(data)