worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/
MIT License

Possible mix-up of token columns #47

Closed: liu305 closed this issue 6 months ago

liu305 commented 1 year ago

Hello. During parent model training, in the validation phase where the model generates synthetic data to compare against the real data, training is sometimes terminated partway through (usually around epoch 45 to 50) by a ValueError caused by an invalid string representation of a float. For example, I see errors like "ValueError: could not convert string to float: '-0.-0.2'". I checked the model vocabulary and everything seems fine: column_00 has tokens "-0.", "01", "-3.", and column_01 has tokens "123", "532", "324". If the model only generates tokens specific to each processed column, why would we see '-0.-0.2'? It looks as if tokens from column_00 are being sampled at the position of column_01. Do you have any idea where in the source code things could go wrong to cause this?
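To make the suspected failure mode concrete, here is a toy sketch. It assumes, purely for illustration, that a numeric value is split into per-position sub-columns (column_00 holding the leading piece, column_01 the trailing digits) and that the generated pieces are concatenated before conversion to float. VOCAB, decode_value, and is_in_vocab are hypothetical names, not REaLTabFormer's internals.

```python
# Toy illustration only: hypothetical per-column vocabularies modeled on the
# tokens quoted in this issue, not REaLTabFormer's actual data structures.
VOCAB = {
    "column_00": {"-0.", "01", "-3."},   # leading piece (sign, integer part, point)
    "column_01": {"123", "532", "324"},  # trailing digits
}


def decode_value(pieces: dict[str, str]) -> str:
    """Join the generated pieces in column order into one numeric string."""
    return "".join(pieces[col] for col in sorted(pieces))


def is_in_vocab(pieces: dict[str, str]) -> bool:
    """True only if every piece came from its own column's vocabulary."""
    return all(tok in VOCAB[col] for col, tok in pieces.items())


good = {"column_00": "-0.", "column_01": "123"}
leaked = {"column_00": "-0.", "column_01": "-0."}  # column_00 token in column_01's slot

print(decode_value(good), float(decode_value(good)))  # -0.123 -0.123
print(is_in_vocab(leaked))                            # False
try:
    float(decode_value(leaked))
except ValueError as err:
    print(err)  # could not convert string to float: '-0.-0.'
```

A per-column vocabulary check like is_in_vocab would flag the leaked piece before the joined string ever reaches float().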

avsolatorio commented 1 year ago

Hello @liu305, thanks for sharing this! Could you please share a traceback of the error so I can see where in the code this is happening?

liu305 commented 1 year ago

Hello. I did some debugging myself. After uncommenting line 441 of rtf_sampler.py, I no longer see the issue. So it seems that some invalid samples leaked through to the final stage and produced invalid string representations of float numbers.
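The thread does not quote what line 441 of rtf_sampler.py contains, so the following is only a hypothetical sketch of the kind of guard that would keep such samples out of the validation step: drop any generated row whose numeric columns do not parse as floats. The row format (a list of dicts of strings) and the names parses_as_float and drop_invalid_rows are assumptions.

```python
# Hypothetical post-sampling guard; NOT the actual code at line 441 of rtf_sampler.py.
def parses_as_float(text: str) -> bool:
    """Return True if float() accepts the string."""
    try:
        float(text)
    except ValueError:
        return False
    return True


def drop_invalid_rows(
    rows: list[dict[str, str]], numeric_cols: list[str]
) -> list[dict[str, str]]:
    """Keep only rows whose numeric columns all parse as floats."""
    return [row for row in rows if all(parses_as_float(row[c]) for c in numeric_cols)]


samples = [
    {"x": "-0.123", "y": "5"},
    {"x": "-0.-0.2", "y": "7"},  # malformed, like the value in the report
    {"x": "2.0.20", "y": "1"},   # malformed, like the value in the traceback
]
kept = drop_invalid_rows(samples, numeric_cols=["x", "y"])
print(kept)  # [{'x': '-0.123', 'y': '5'}]
```

Filtering like this removes the symptom but does not explain why a token from one column's vocabulary appears in another column's position; it is a safety net rather than a root-cause fix.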

FYI, the traceback of the original error looks like this:

Traceback (most recent call last):
  File "/iter3/debugging/02_parent_model_training.py", line 17, in <module>
    parent_model.fit(
  File "/iter3/debugging/realtabformer_local/realtabformer.py", line 457, in fit
    trainer = self._train_with_sensitivity(
  File "/iter3/debugging/realtabformer_local/realtabformer.py", line 751, in _train_with_sensitivity
    SyntheticDataBench.compute_sensitivity_metric(
  File "/iter3/debugging/realtabformer_local/rtf_analyze.py", line 526, in compute_sensitivity_metric
    processed = SyntheticDataBench.preprocess_data(
  File "/iter3/debugging/realtabformer_local/rtf_analyze.py", line 254, in preprocess_data
    _other = preprocessor.transform(_other)
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/sklearn/utils/_set_output.py", line 140, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/sklearn/compose/_column_transformer.py", line 816, in transform
    Xs = self._fit_transform(
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/sklearn/compose/_column_transformer.py", line 670, in _fit_transform
    return Parallel(n_jobs=self.n_jobs)(
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/sklearn/utils/parallel.py", line 65, in __call__
    return super().__call__(iterable_with_config)
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/joblib/parallel.py", line 1855, in __call__
    return output if self.return_generator else list(output)
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/joblib/parallel.py", line 1784, in _get_sequential_output
    res = func(*args, **kwargs)
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/sklearn/utils/parallel.py", line 127, in __call__
    return self.function(*args, **kwargs)
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/sklearn/pipeline.py", line 933, in _transform_one
    res = transformer.transform(X)
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/sklearn/utils/_set_output.py", line 140, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/sklearn/preprocessing/_data.py", line 1004, in transform
    X = self._validate_data(
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/sklearn/base.py", line 604, in _validate_data
    out = check_array(X, input_name="X", **check_params)
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/sklearn/utils/validation.py", line 917, in check_array
    array = _asarray_with_order(array, order=order, dtype=dtype, xp=xp)
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/sklearn/utils/_array_api.py", line 380, in _asarray_with_order
    array = numpy.asarray(array, order=order, dtype=dtype)
  File "/miniconda3/user-envs/liu305/realtabformer-env/lib/python3.9/site-packages/pandas/core/generic.py", line 2070, in __array__
    return np.asarray(self._values, dtype=dtype)
ValueError: could not convert string to float: '2.0.20'
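The traceback shows that scikit-learn reports only the offending string ('2.0.20'), not which column or row it came from. As a hypothetical debugging aid (the column-dict layout, the sample data, and the name locate_bad_values are assumptions), one could scan the synthetic values before preprocessing and report the location of every value that float() rejects:

```python
# Hypothetical debugging helper; not part of REaLTabFormer or scikit-learn.
def locate_bad_values(columns: dict[str, list[str]]) -> list[tuple[str, int, str]]:
    """Return (column, row_index, value) for every value float() rejects."""
    hits = []
    for name, values in columns.items():
        for i, value in enumerate(values):
            try:
                float(value)
            except (TypeError, ValueError):
                hits.append((name, i, value))
    return hits


# Illustrative synthetic columns; the malformed value mirrors the traceback.
synthetic = {
    "column_00": ["1.5", "2.0.20", "-3.25"],
    "column_01": ["0.1", "0.2", "0.3"],
}
print(locate_bad_values(synthetic))  # [('column_00', 1, '2.0.20')]
```

Knowing the column and row would help confirm whether the malformed values cluster in particular token columns, which is what the original report suspects.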