worldbank / REaLTabFormer

A suite of auto-regressive and Seq2Seq (sequence-to-sequence) transformer models for tabular and relational synthetic data generation.
https://worldbank.github.io/REaLTabFormer/

Generated datetime value in the child table is invalid #44

Open liu305 opened 1 year ago

liu305 commented 1 year ago

It seems that with a smaller output_max_length parameter, the generated samples contain invalid datetime values. Instead of datetime strings, they are numbers whose meaning is unclear, and some of them are quite long. When running on a small training dataset with about 10k records and output_max_length=4096, I didn't encounter this issue, but with a large training dataset of 1M records and output_max_length=512 (or 1000), it occurs. Is it because the output_max_length truncation somehow corrupts the datetime tokens?

transactionDateTime 12944829 12863510 12103586 2293176 12628443 12244450294269574381 12538574 18949556 12353260 16880274 9405910 10618463 10250250 872232 2677221 12122979 120384286117778 16718277 11588905

avsolatorio commented 1 year ago

Hello @liu305! In the implementation, we convert datetime values into their equivalent timestamp values. This allows us to model dates as numeric data and generate "new" dates. Unfortunately, a timestamp is represented by a relatively long sequence of digits (~10), so depending on your data's dimensionality, it could get truncated when you use a shorter output_max_length.
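For intuition, here is a quick check (a minimal sketch, not library code) of how many digit tokens a single datetime consumes:

```python
import pandas as pd

# A typical recent datetime, converted to a Unix timestamp
# (seconds since 1970-01-01), is roughly a 10-digit integer.
ts = int(pd.Timestamp("2023-05-17 08:30:00").timestamp())
print(ts)            # 1684312200
print(len(str(ts)))  # 10 -- each digit eats into the sequence budget
```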

I wonder what your raw datetime values look like. You may also consider other ways of representing your data: for example, split the year, month, and day into separate columns. Then, if the year is constant, you can drop that column entirely. You will just need a post-processing step afterward to reassemble the dates, along the lines of the sketch below.
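As an illustration (the column names here are hypothetical, and this is not library code), the pre- and post-processing could look like this:

```python
import pandas as pd

df = pd.DataFrame(
    {"transactionDateTime": pd.to_datetime(["2023-01-05 10:00", "2023-03-17 22:45"])}
)

# Pre-processing: split the datetime into low-magnitude numeric parts.
dt = df["transactionDateTime"]
df["txn_year"] = dt.dt.year    # drop this column if the year is constant
df["txn_month"] = dt.dt.month
df["txn_day"] = dt.dt.day
df = df.drop(columns=["transactionDateTime"])

# Post-processing (applied to the generated samples): rebuild the datetime.
df["transactionDateTime"] = pd.to_datetime(
    df[["txn_year", "txn_month", "txn_day"]].rename(
        columns={"txn_year": "year", "txn_month": "month", "txn_day": "day"}
    )
)
```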

For reference, the following is the specific code for handling datetime values.


```python
import time
from typing import Dict, Tuple

import pandas as pd


# `process_numeric_data` (used below) is defined elsewhere in the library.
def process_datetime_data(
    series, transform_data: Dict = None
) -> Tuple[pd.Series, Dict]:
    # Get the max_len from the current time.
    # This will be ignored later if the actual max_len
    # is shorter.
    max_len = len(str(int(time.time())))

    # Convert the datetimes to
    # their equivalent timestamp values.

    # Make sure that we don't convert the NaT
    # to some integer.
    series = series.copy()
    series.loc[series.notnull()] = (series[series.notnull()].view(int) / 1e9).astype(
        int
    )
    series = series.fillna(pd.NA)

    # Take the mean value to re-align the data.
    # This will help reduce the scale of the numeric
    # data that will need to be generated. Let's just
    # add this offset back later before casting.
    mean_date = None

    if transform_data is None:
        mean_date = int(series.mean())
        series -= mean_date
    else:
        # The mean_date should have been
        # stored during fitting.
        series -= transform_data["mean_date"]

    # Then apply the numeric data processing
    # pipeline.
    series, transform_data = process_numeric_data(
        series,
        max_len=max_len,
        numeric_precision=0,
        transform_data=transform_data,
    )

    # Store the `mean_date` here because `process_numeric_data`
    # expects a None transform_data during fitting.
    if mean_date is not None:
        transform_data["mean_date"] = mean_date

    return series, transform_data
```
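To see why the mean-centering step above helps, here is a small illustration (again, not library code) of how subtracting the mean timestamp shrinks the magnitudes the model has to generate:

```python
import pandas as pd

dates = pd.to_datetime(["2023-01-01", "2023-06-15", "2023-12-31"])
ts = dates.astype("int64") // 10**9   # seconds since epoch, ~10 digits each
centered = ts - int(ts.mean())

print(ts.tolist())        # [1672531200, 1686787200, 1703980800]
print(centered.tolist())  # [-15235200, -979200, 16214400] -- fewer digits to emit
```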
liu305 commented 1 year ago

Hi @avsolatorio. Thank you very much for your timely response! I have some follow-up questions below and would appreciate your input.

  1. In my previous experiments I probably used some high-cardinality variables as-is, which made the vocabulary size huge. Do you think a smaller vocabulary size would help in this case (so that the output_max_length requirement can be relaxed)?
  2. In the same experiment, I also saw invalid values in some other columns of the generated data. For example, in the code column I even see name strings, which definitely belong to another column. Is this for the same reason as the invalid datetime values, i.e., the output_max_length truncation corrupting the output?
  3. Do you think column ordering matters? For example, the datetime column is currently the last column. If I move it to the first position, will it be less affected by output_max_length?

Best Regards,