worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License

Bug in process_datetime_data() when using pandas 2.2.1 #71

Closed · efstathios-chatzikyriakidis closed this issue 4 months ago

efstathios-chatzikyriakidis commented 4 months ago

Hi @avsolatorio,

I hope you are well.

Unfortunately, I am blocked: I need to upgrade pandas to the latest version (2.2.1), but I can't because REaLTabFormer does not work with it. The function process_datetime_data() fails: pandas has deprecated Series.view(), and we get both a warning and an error from the line that uses it:

https://github.com/worldbank/REaLTabFormer/blob/4d14472f181f26b68b528630c372c6d828d7aa1e/src/realtabformer/data_utils.py#L259

The good thing is that this is the only place in the code where Series.view() is used, so it should be easy to fix.
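For reference, here is roughly the kind of replacement the warning points to, using astype instead of view (a quick sketch on a made-up toy column, not tested against the full training pipeline):

import pandas as pd

# Toy stand-in for the datetime column that trips up process_datetime_data();
# the values here are made up.
series = pd.Series(pd.to_datetime(["2024-01-01", None, "2024-03-15"]))

mask = series.notnull()

# astype("int64") is the alternative the FutureWarning suggests; it yields
# nanoseconds since the epoch, and dividing by 1e9 converts to whole seconds.
seconds = (series[mask].astype("int64") / 1e9).astype("int64")

# Build an object-dtype result so the NaT positions can hold pd.NA alongside
# integers, instead of writing ints back into the datetime64[ns] series.
result = pd.Series(pd.NA, index=series.index, dtype="object")
result.loc[mask] = seconds
print(result)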

Can you help me with this? I will also need a new PyPI version (1.0.6). Thanks in advance.

WARNING:

C:\Users\me\.conda\envs\test\lib\site-packages\realtabformer\data_utils.py:259: FutureWarning: Series.view is deprecated and will be removed in a future version. Use ``astype`` as an alternative to change the dtype.
   series.loc[series.notnull()] = (series[series.notnull()].view(int) / 1e9).astype(int)

ERROR:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[61], line 1
----> 1 trainer = parent_model.fit(df=training_source_data_df.reset_index(drop = True),
      2                            n_critic=0,
      3                            device=_get_device())

File ~\.conda\envs\test\lib\site-packages\realtabformer\realtabformer.py:455, in REaLTabFormer.fit(self, df, in_df, join_on, resume_from_checkpoint, device, num_bootstrap, frac, frac_max_data, qt_max, qt_max_default, qt_interval, qt_interval_unique, distance, quantile, n_critic, n_critic_stop, gen_rounds, sensitivity_max_col_nums, use_ks, full_sensitivity, sensitivity_orig_frac_multiple, orig_samples_rounds, load_from_best_mean_sensitivity, target_col)
    453 if self.model_type == ModelType.tabular:
    454     if n_critic <= 0:
--> 455         trainer = self._fit_tabular(df, device=device)
    456         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    457     else:

File ~\.conda\envs\test\lib\site-packages\realtabformer\realtabformer.py:1046, in REaLTabFormer._fit_tabular(self, df, device, num_train_epochs, target_epochs)
   1038 def _fit_tabular(
   1039     self,
   1040     df: pd.DataFrame,
   (...)
   1043     target_epochs: int = None,
   1044 ) -> Trainer:
   1045     self._extract_column_info(df)
-> 1046     df, self.col_transform_data = process_data(
   1047         df,
   1048         numeric_max_len=self.numeric_max_len,
   1049         numeric_precision=self.numeric_precision,
   1050         numeric_nparts=self.numeric_nparts,
   1051         target_col=self.target_col,
   1052     )
   1053     self.processed_columns = df.columns.to_list()
   1054     self.vocab = self._generate_vocab(df)

File ~\.conda\envs\test\lib\site-packages\realtabformer\data_utils.py:486, in process_data(df, numeric_max_len, numeric_precision, numeric_nparts, first_col_type, col_transform_data, target_col)
    483 col_name = encode_processed_column(col_idx[c], ColDataType.DATETIME, c)
    485 _col_transform_data = col_transform_data.get(c)
--> 486 series, transform_data = process_datetime_data(
    487     df[c],
    488     transform_data=_col_transform_data,
    489 )
    490 if _col_transform_data is None:
    491     # This means that no transform data is available
    492     # before the processing.
    493     col_transform_data[c] = transform_data

File ~\.conda\envs\test\lib\site-packages\realtabformer\data_utils.py:259, in process_datetime_data(series, transform_data)
    253 # Convert the datetimes to
    254 # their equivalent timestamp values.
    255 
    256 # Make sure that we don't convert the NaT
    257 # to some integer.
    258 series = series.copy()
--> 259 series.loc[series.notnull()] = (series[series.notnull()].view(int) / 1e9).astype(
    260     int
    261 )
    262 series = series.fillna(pd.NA)
    264 # Take the mean value to re-align the data.
    265 # This will help reduce the scale of the numeric
    266 # data that will need to be generated. Let's just
    267 # add this offset back later before casting.

File ~\.conda\envs\test\lib\site-packages\pandas\core\series.py:965, in Series.view(self, dtype)
    962 # self.array instead of self._values so we piggyback on NumpyExtensionArray
    963 #  implementation
    964 res_values = self.array.view(dtype)
--> 965 res_ser = self._constructor(res_values, index=self.index, copy=False)
    966 if isinstance(res_ser._mgr, SingleBlockManager):
    967     blk = res_ser._mgr._block

File ~\.conda\envs\test\lib\site-packages\pandas\core\series.py:575, in Series.__init__(self, data, index, dtype, name, copy, fastpath)
    573     index = default_index(len(data))
    574 elif is_list_like(data):
--> 575     com.require_length_match(data, index)
    577 # create/copy the manager
    578 if isinstance(data, (SingleBlockManager, SingleArrayManager)):

File ~\.conda\envs\test\lib\site-packages\pandas\core\common.py:573, in require_length_match(data, index)
    569 """
    570 Check the length of data matches the length of the index.
    571 """
    572 if len(data) != len(index):
--> 573     raise ValueError(
    574         "Length of values "
    575         f"({len(data)}) "
    576         "does not match length of index "
    577         f"({len(index)})"
    578     )

ValueError: Length of values (22318) does not match length of index (11159)

Here is the list of Python packages I am using:

numpy==1.22.4
pandas==2.2.1
multiprocess==0.70.14
dill==0.3.6
transformers==4.27.4
REaLTabFormer==0.1.5
psycopg2==2.9.6
SQLAlchemy==2.0.12
pydantic==1.10.7
jsonschema==4.17.3
efstathios-chatzikyriakidis commented 4 months ago

Hi @avsolatorio!

Unfortunately, I still have a problem with the fix. I think we need to use .astype('int64') instead of .astype(int). It is safer to convert datetime64[ns] to int64 explicitly, because on some systems a bare int maps to int32 (for example, NumPy 1.x on 64-bit Windows). This is likely also why the original .view(int) call failed above: viewing 8-byte datetime64[ns] values as 4-byte int32 doubles the length, and 22318 is exactly 2 × 11159.

The line that needs to change, from:

series = (series.astype(int) / 1e9)

To:

series = (series.astype('int64') / 1e9)

This fix will allow it to run everywhere with the latest pandas.
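A quick illustration of the difference (toy column; the platform default integer dtype depends on the OS and NumPy version):

import numpy as np
import pandas as pd

# With NumPy 1.x on 64-bit Windows, the platform default integer is int32,
# which is why a bare `int` is not a safe target dtype here.
print(np.dtype(int))  # int32 on Windows (NumPy 1.x), int64 on Linux/macOS

series = pd.Series(pd.to_datetime(["2024-01-01", "2024-03-15"]))

# astype("int64") is unambiguous: 64-bit nanosecond timestamps on every
# platform, so dividing by 1e9 gives seconds since the epoch everywhere.
print(series.astype("int64") / 1e9)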