sdv-dev / SDMetrics

Metrics to evaluate quality and efficacy of synthetic datasets.
MIT License

Datetime columns set to Object pandas dtype breaks LSTMDetection #584

Status: Closed (srinify closed this issue 1 month ago)

srinify commented 3 months ago

Environment Details

SDMetrics version: 0.14.0 (a user also reported this on 0.11)

Error Description

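If your pandas DataFrame stores datetime values in a column with the object dtype (for example, as strings like '1961-05-27') instead of a datetime dtype, LSTMDetection fails. Object and datetime columns are transformed differently. The stack trace below shows that HyperTransformer.transform treats any column whose dtype kind is 'O' as categorical and one-hot encodes it. The real data transforms successfully, but transforming the synthetic data fails because the synthetic dates include values the fitted encoder never saw. As a result, OneHotEncoder.transform raises "ValueError: Found unknown categories".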
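The failure mode can be reproduced without SDMetrics by using scikit-learn's OneHotEncoder directly. This is a minimal sketch with made-up values. It does not call SDMetrics' own HyperTransformer, and it assumes, as the trace suggests, that the encoder runs with handle_unknown='error' (scikit-learn's default).

```python
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

# Made-up example: dates stored as strings in object-dtype columns.
real = pd.DataFrame({'field': ['1961-05-27', '1909-11-03', '1967-11-28']}, dtype=object)
synthetic = pd.DataFrame({'field': ['1961-05-27', '1952-01-24']}, dtype=object)

# dtype kind 'O' is what routes a column to the categorical (one-hot) branch in the trace.
print(real['field'].dtype.kind)

# Fit on the real values only; handle_unknown='error' is scikit-learn's default.
encoder = OneHotEncoder(handle_unknown='error')
encoder.fit(real)
encoder.transform(real)  # succeeds: every real date is a known category

message = None
try:
    encoder.transform(synthetic)  # '1952-01-24' was never seen during fit
except ValueError as error:
    message = str(error)
print(message)
```

Every distinct date string becomes its own category, so almost any synthetic date that does not exactly match a real one triggers the error.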

Originally raised and


For now, manually cast your datetime columns to a datetime dtype in both the real and the synthetic DataFrames before calling LSTMDetection. One quick way is to use pandas.to_datetime():

import pandas as pd

df['date_col_1'] = pd.to_datetime(df['date_col_1'])
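The cast works because it changes the column's dtype kind from 'O' (object) to 'M' (datetime64), so the column no longer matches the `kind == 'O'` categorical branch seen in the trace. The sketch below uses a made-up frame to show the change.

```python
import pandas as pd

# Made-up frame mirroring the issue: dates held as strings in an object column.
df = pd.DataFrame({'date_col_1': ['1961-05-27', '1909-11-03']}, dtype=object)
before = df['date_col_1'].dtype.kind
print(before)  # O

# Convert the strings to datetime64 values.
df['date_col_1'] = pd.to_datetime(df['date_col_1'])
after = df['date_col_1'].dtype.kind
print(after)  # M
```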

Steps to reproduce

GitHub Gist Internal Colab Notebook

Ideal Solution

If the user-provided metadata marks columns as datetime (e.g. "sdtype": "datetime"), we should convert those columns to a datetime dtype before transforming the data.
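One way the proposed conversion could look is sketched below. This is not SDMetrics code. `cast_datetime_columns` is a hypothetical helper, and the metadata shape (a 'columns' dict whose entries carry 'sdtype' and an optional 'datetime_format') is an assumption modeled on the "sdtype": "datetime" example above.

```python
import pandas as pd


def cast_datetime_columns(data, metadata):
    """Hypothetical helper (not part of SDMetrics).

    Returns a copy of `data` in which every column whose metadata sdtype is
    'datetime' is converted to datetime64. Assumes metadata looks like
    {'columns': {name: {'sdtype': ..., 'datetime_format': ...}}}.
    """
    data = data.copy()
    for column, column_meta in metadata.get('columns', {}).items():
        # Skip non-datetime columns and columns absent from this table.
        if column_meta.get('sdtype') != 'datetime' or column not in data.columns:
            continue
        # Use the declared format if there is one; otherwise pandas infers it.
        data[column] = pd.to_datetime(
            data[column], format=column_meta.get('datetime_format'))
    return data


# Usage with made-up data and metadata.
metadata = {
    'columns': {
        's_key': {'sdtype': 'id'},
        'date_col_1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
        'amount': {'sdtype': 'numerical'},
    }
}
real = pd.DataFrame({
    's_key': [0, 0, 1],
    'date_col_1': ['1961-05-27', '1909-11-03', '1967-11-28'],
    'amount': [1.0, 2.0, 3.0],
}).astype({'date_col_1': object})

real_fixed = cast_datetime_columns(real, metadata)
print(real_fixed.dtypes)
```

The same helper would be applied to both the real and the synthetic data, so both tables reach the transformer with matching dtypes. Returning a copy leaves the caller's DataFrames untouched.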

Full Stack Trace

ValueError                                Traceback (most recent call last)
Cell In[23], line 3
      1 from sdmetrics.timeseries import LSTMDetection
----> 3 LSTMDetection.compute(
      4     real_data=df1,
      5     synthetic_data=synth_df1,
      6     metadata=metadata1,
      7     sequence_key=['s_key']
      9 )

File ~/.pyenv/versions/sdv_latest/lib/python3.10/site-packages/sdmetrics/timeseries/, in TimeSeriesDetectionMetric.compute(cls, real_data, synthetic_data, metadata, sequence_key)
     81, axis=1))
     83 real_x = cls._build_x(real_data, ht, sequence_key)
---> 84 synt_x = cls._build_x(synthetic_data, ht, sequence_key)
     86 X = pd.concat([real_x, synt_x])
     87 y = pd.Series(np.array([0] * len(real_x) + [1] * len(synt_x)))

File ~/.pyenv/versions/sdv_latest/lib/python3.10/site-packages/sdmetrics/timeseries/, in TimeSeriesDetectionMetric._build_x(data, hypertransformer, sequence_key)
     40 for entity_id, entity_data in data.groupby(sequence_key):
     41     entity_data = entity_data.drop(sequence_key, axis=1)
---> 42     entity_data = hypertransformer.transform(entity_data)
     43     entity_data = pd.Series({
     44         column: entity_data[column].to_numpy()
     45         for column in entity_data.columns
     46     }, name=entity_id)
     48     X = pd.concat([X, pd.DataFrame(entity_data).T], ignore_index=True)

File ~/.pyenv/versions/sdv_latest/lib/python3.10/site-packages/sdmetrics/, in HyperTransformer.transform(self, data)
    197 elif kind == 'O':
    198     # Categorical column.
    199     col_data = pd.DataFrame({'field': data[field]})
--> 200     out = transform_info['one_hot_encoder'].transform(col_data).toarray()
    201     transformed = pd.DataFrame(
    202         out, columns=[f'value{i}' for i in range(np.shape(out)[1])])
    203     data = data.drop(columns=[field])

File ~/.pyenv/versions/sdv_latest/lib/python3.10/site-packages/sklearn/utils/, in _wrap_method_output.<locals>.wrapped(self, X, *args, **kwargs)
    293 @wraps(f)
    294 def wrapped(self, X, *args, **kwargs):
--> 295     data_to_wrap = f(self, X, *args, **kwargs)
    296     if isinstance(data_to_wrap, tuple):
    297         # only wrap the first output for cross decomposition
    298         return_tuple = (
    299             _wrap_data_with_container(method, data_to_wrap[0], X, self),
    300             *data_to_wrap[1:],
    301         )

File ~/.pyenv/versions/sdv_latest/lib/python3.10/site-packages/sklearn/preprocessing/, in OneHotEncoder.transform(self, X)
   1018 # validation of X happens in _check_X called by _transform
   1019 warn_on_unknown = self.drop is not None and self.handle_unknown in {
   1020     "ignore",
   1021     "infrequent_if_exist",
   1022 }
-> 1023 X_int, X_mask = self._transform(
   1024     X,
   1025     handle_unknown=self.handle_unknown,
   1026     force_all_finite="allow-nan",
   1027     warn_on_unknown=warn_on_unknown,
   1028 )
   1030 n_samples, n_features = X_int.shape
   1032 if self._drop_idx_after_grouping is not None:

File ~/.pyenv/versions/sdv_latest/lib/python3.10/site-packages/sklearn/preprocessing/, in _BaseEncoder._transform(self, X, handle_unknown, force_all_finite, warn_on_unknown, ignore_category_indices)
    208 if handle_unknown == "error":
    209     msg = (
    210         "Found unknown categories {0} in column {1}"
    211         " during transform".format(diff, i)
    212     )
--> 213     raise ValueError(msg)
    214 else:
    215     if warn_on_unknown:

ValueError: Found unknown categories ['1961-05-27', '1909-11-03', '1967-11-28', '1969-08-08', '1918-11-02', '1952-01-24', '1947-12-26', '1981-06-01', '1954-03-04', '1936-11-13'] in column 0 during transform