jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Dataloader generated by TimeSeriesDataSet Not Normalizing Target #1339

Open YojoNick opened 1 year ago

YojoNick commented 1 year ago

This library is awesome, a big thank you for all of your hard work for creating something that's super useful! May I please have your help with the following?

tl;dr: the target returned by the dataloader isn't normalized, even though I specified the standard normalizer as the target_normalizer of the associated TimeSeriesDataSet.

Expected behavior

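I executed the code below to get a batch from the dataloader, expecting its target to be normalized by the standard normalizer.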

Actual behavior

However, the target in the returned batch still has its original, unnormalized values.

Code to reproduce the problem

# Imports and parameters
import pandas as pd
import numpy as np
from datetime import date
import datetime

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data.encoders import TorchNormalizer, \
    NaNLabelEncoder, EncoderNormalizer

numGroups = 2
numSamplesPerGroup = 12

# Step 1:  Create Pandas Dataframe containing generated sample data
columns = {'time_idx':pd.Series(dtype='int'),
           'GroupID':pd.Series(dtype='int'),
           'Date':pd.Series(dtype='int'),
           'DayOfWeek':pd.Series(dtype='int'),
           'DayOfMonth':pd.Series(dtype='int'),
           'Month':pd.Series(dtype='int'),
           'Week':pd.Series(dtype='int'),
           'RealVal1':pd.Series(dtype='float'),
           'RealVal2':pd.Series(dtype='float'),
           'Target':pd.Series(dtype='float')}

testDataGroup = pd.DataFrame(columns=columns)
testDataFrame = pd.DataFrame(columns=columns)

testDataGroup['time_idx'] = pd.Series(np.arange(1,numSamplesPerGroup + 1).astype(int), dtype='int')
testDataGroup['Date'] = pd.DataFrame(pd.date_range(date(2023, 1, 1), periods=numSamplesPerGroup, freq='D'))
testDataGroup['DayOfWeek'] = testDataGroup['Date'].dt.dayofweek.astype(str).astype("category")
testDataGroup['DayOfMonth'] = testDataGroup['Date'].dt.day.astype(str).astype("category")
testDataGroup['Month'] = testDataGroup['Date'].dt.month.astype(str).astype("category")
testDataGroup['Week'] = testDataGroup['Date'].dt.isocalendar().week.astype(str).astype("category")
testDataGroup['RealVal1'] = 2.0 * testDataGroup['time_idx']
testDataGroup['RealVal2'] = 3.0 * testDataGroup['time_idx']
testDataGroup['Target'] = 4.0 * testDataGroup['time_idx']

for grpItr in np.arange(numGroups):
    testDataGroup['GroupID'] = grpItr
    testDataFrame = pd.concat([testDataFrame, testDataGroup])

testDataFrame.index = np.arange(0,testDataFrame.shape[0])

testDataFrame['time_idx'] = pd.to_numeric(testDataFrame['time_idx'])
testDataFrame.index = pd.to_numeric(testDataFrame.index)

# Step 2: Create Time SeriesDataSet and dataloader and return single sample from dataloader
encoderLength = 4
predictionLength=1
timeVaryKnwnCats=['DayOfWeek','DayOfMonth','Month','Week']
timeVaryKnwnReals=['time_idx']
timeVaryUnknwn=['RealVal1','RealVal2','Target']

nanCatEnc = {}
for enc in timeVaryKnwnCats:
    nanCatEnc[enc] = NaNLabelEncoder(add_nan=True)

scalers = {}
for feat in testDataFrame.columns:
    if feat != 'Target':
        scalers[feat] = None

tsd = TimeSeriesDataSet(
                testDataFrame,
                time_idx="time_idx",
                target='Target',
                group_ids=["GroupID"],
                min_encoder_length=encoderLength,
                max_encoder_length=encoderLength,
                min_prediction_length=predictionLength, 
                max_prediction_length=predictionLength,
                static_categoricals=['GroupID'],
                time_varying_known_categoricals=timeVaryKnwnCats,
                time_varying_known_reals=timeVaryKnwnReals,
                time_varying_unknown_categoricals=[],
                time_varying_unknown_reals=timeVaryUnknwn,
                target_normalizer=TorchNormalizer(method='standard'),
                add_relative_time_idx=True,
                add_target_scales=True,
                add_encoder_length=True,
                allow_missing_timesteps=True,
                categorical_encoders=nanCatEnc,
                scalers=scalers,
                predict_mode=False)

dataloader = tsd.to_dataloader(batch_size=2)

x,y = next(iter(dataloader))

# y returned sample is NOT normalized even though target_normalizer was specified with standard normalizer
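For reference, here is a minimal sketch of what a standard-normalized target would look like for this repro. It assumes center = mean and scale = population standard deviation over one group (both groups are identical here); the library's exact convention (per-group statistics, biased vs. unbiased std) may differ, so treat the numbers as illustrative only.

```python
from statistics import mean, pstdev

# Target values of one group in the repro above: Target = 4.0 * time_idx, time_idx = 1..12
targets = [4.0 * t for t in range(1, 13)]

# Assumption: standard normalization uses center = mean and scale = population std
center = mean(targets)
scale = pstdev(targets)
normalized = [(y - center) / scale for y in targets]

print(center)           # 26.0
print(round(scale, 3))  # 13.808
```

If y were normalized, its values would be on this zero-mean, unit-variance scale rather than the raw 4.0 to 48.0 range.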
manitadayon commented 1 year ago

Can you change it to GroupNormalizer or something else to see if this is normalized?

YojoNick commented 1 year ago

Changing the target_normalizer parameter to GroupNormalizer(groups=['GroupID'], transformation="softplus") still results in the target value being equal to its original, unnormalized value.

manitadayon commented 1 year ago

It is strange, which suggests that maybe x,y = next(iter(dataloader)) is not the right way to check whether the features are normalized. I have experimented with different normalization configurations, and they certainly affect my final results and predictions. Try looking at tsd.dataset.data (I think it is something like this; I am not in front of a computer to try it). It should show all the information needed.

YojoNick commented 1 year ago

FYI: the values in tsd.data['target'] are also all unnormalized.

manitadayon commented 1 year ago

I think tsd.data['target'] returns the raw values.

badiaog commented 1 year ago

I have the same problem. I would like to use normalized predicted values and normalized target values when calculating losses, but by default both the predicted and the target values are output unnormalized. So I changed the source code, for example the forward method in temporal_fusion_transformer/__init__.py:

        if self.n_targets > 1:  # if to use multi-target architecture
            output = [output_layer(output) for output_layer in self.output_layer]
        else:
            output = self.output_layer(output)

        return self.to_network_output(
            scaled_prediction=output,  # raw output as scaled_prediction, before inverse normalization
            prediction=self.transform_output(output, target_scale=x["target_scale"]),
            encoder_attention=attn_output_weights[..., :max_encoder_length],
            decoder_attention=attn_output_weights[..., max_encoder_length:],
            static_variables=static_variable_selection,
            encoder_variables=encoder_sparse_weights,
            decoder_variables=decoder_sparse_weights,
            decoder_lengths=decoder_lengths,
            encoder_lengths=encoder_lengths,
        )

In addition, I haven't yet found a suitable way to output y_true in normalized form. I wrote a function that normalizes the true target values in each batch, but it is very time-consuming.

    def scale_y_true(batch_gpu, df_Pfizer, normalizer, device):
        my_target = batch_gpu['decoder_target']
        arr = []
        for one_true in my_target:
            sss = pd.Series(one_true.tolist())
            arr.append(normalizer.transform(sss, df_Pfizer.iloc[:max_prediction_length,:]))
        scaled_y_true = torch.tensor(np.vstack(arr)).to(device)
        return scaled_y_true
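The per-row loop above is slow because it builds a pandas Series and calls the normalizer once per sample. A vectorized alternative is sketched below. It assumes (not confirmed in this thread) that each sample's target_scale holds a (center, scale) pair in that order and that the normalizer uses no transformation (e.g. TorchNormalizer(method='standard')); with a softplus or log transformation you would have to apply the forward transform first. `scale_batch` is a hypothetical helper name.

```python
import numpy as np

def scale_batch(decoder_target, target_scale):
    """Hypothetical helper: standard-normalize a whole batch at once.

    decoder_target: shape (batch, prediction_length), raw target values
    target_scale:   shape (batch, 2), assumed per-sample (center, scale) pairs
    """
    center = target_scale[:, 0:1]  # slice keeps 2-D, so it broadcasts over the time axis
    scale = target_scale[:, 1:2]
    return (decoder_target - center) / scale

y = np.array([[10.0, 12.0], [100.0, 120.0]])
scales = np.array([[11.0, 1.0], [110.0, 10.0]])
print(scale_batch(y, scales))  # both rows become [-1.0, 1.0]
```

The same slicing and broadcasting work unchanged on torch tensors, so the batch can stay on the GPU.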
manitadayon commented 1 year ago

@jdb78 Can you please verify if normalization works as expected?

manitadayon commented 1 year ago

So can you check dataloader.dataset.data['reals']? This should be normalized.

badiaog commented 1 year ago


The inputs seem to be normalized, but the normalized y (target vector) is not present anywhere in the data.

# create sale dataset and dataloaders
training_sale = TimeSeriesDataSet(
    df_Pfizer[lambda x: x.time_idx < training_cutoff],
    group_ids=["group"],
    target='sale',
    time_idx="time_idx",
    min_encoder_length=max_encoder_length,  # keep encoder length long (as it is in the validation set)
    max_encoder_length=max_encoder_length,
    min_prediction_length=max_prediction_length,
    max_prediction_length=max_prediction_length,
    time_varying_unknown_reals=['sale']  + (sale_cols),
    time_varying_known_reals=["time_idx", "cos_doy", "sin_doy", "is_workday", "is_holiday"],
    # time_varying_known_reals=["time_idx"],
    time_varying_known_categoricals=[],
    target_normalizer=GroupNormalizer(
        groups=["group"], transformation="softplus"
    ),

    # use softplus and normalize by group; softplus is a smooth approximation to the ReLU function and can be used to constrain the output of a model to always be positive.
    add_relative_time_idx=False,
    add_target_scales=True,
    # add_encoder_length=True,
)

# create validation dataset using the same normalization techniques as for the training dataset
validation_sale = TimeSeriesDataSet.from_dataset(training_sale, df_Pfizer[
    lambda x: (x.time_idx > training_cutoff - max_encoder_length - max_prediction_length) & (
            x.time_idx < validation_cutoff)], stop_randomization=True)

# create test dataset using the same normalization techniques as for the training dataset
test_sale = TimeSeriesDataSet.from_dataset(training_sale, df_Pfizer[
    lambda x: x.time_idx > validation_cutoff - max_encoder_length - max_prediction_length],
                                           stop_randomization=True)
print(test_sale)
BATCH_SIZE = 64
# convert datasets to dataloaders for training
train_sale_dataloader = training_sale.to_dataloader(train=True, batch_size=BATCH_SIZE, shuffle=False)
val_sale_dataloader = validation_sale.to_dataloader(train=False, batch_size=BATCH_SIZE, shuffle=False)
test_sale_dataloader = test_sale.to_dataloader(train=False, batch_size=BATCH_SIZE, shuffle=False)
# print the first batch of the test dataloader to inspect it
for batch in test_sale_dataloader:
    print(batch)
    break


manitadayon commented 1 year ago

What is the dimension of your sale_cols?

badiaog commented 1 year ago


These are sliding-window statistics computed from the 'sale' column.

manitadayon commented 1 year ago

I think the reals should already contain the target as well, since you include it in time_varying_unknown_reals.

badiaog commented 1 year ago


In line 476 of timeseries.py, self._preprocess_data(data) returns a dataframe in which my target column (for example, 'leadtime') holds the normalized values. But this dataframe also has another column, '__target__leadtime', which seems to be the one returned in .dataset['target']. Then, after line 484, self._data_to_tensors(data) converts this dataframe to tensors. Looking at that method, you can see how the target is processed: what is returned is the column '__target__leadtime', i.e. data[f"__target__{self.target}"].

        # get target
        if isinstance(self.target_normalizer, NaNLabelEncoder):
            target = [
                check_for_nonfinite(
                    torch.tensor(data[f'__target__{self.target}'].to_numpy(dtype=np.int64), dtype=torch.long),
                    self.target,
                )
            ]
        else:
            if not isinstance(self.target, str):  # multi-target
                target = [
                    check_for_nonfinite(
                        torch.tensor(
                            data[f"__target__{name}"].to_numpy(
                                dtype=[np.float64, np.int64][data[name].dtype.kind in "bi"]
                            ),
                            dtype=[torch.float, torch.long][data[name].dtype.kind in "bi"],
                        ),
                        name,
                    )
                    for name in self.target_names
                ]
            else:
                target = [
                    check_for_nonfinite(
                        torch.tensor(data[f"__target__{self.target}"].to_numpy(dtype=np.float64), dtype=torch.float),
                        self.target,
                    )
                ]
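The two-column layout described above can be illustrated with a minimal sketch. `preprocess_target` is a hypothetical stand-in, not the library's `_preprocess_data`: it assumes simple standard normalization over the whole column and plain Python lists instead of a dataframe. It only shows that the raw copy lives under `__target__{name}` while `{name}` itself is overwritten with normalized values.

```python
from statistics import mean, pstdev

def preprocess_target(data, name):
    """Hypothetical sketch: keep a raw copy of the target and normalize it in place."""
    raw = list(data[name])
    center, scale = mean(raw), pstdev(raw)  # assumption: plain standard normalization
    out = dict(data)
    out[f"__target__{name}"] = raw  # raw copy: the column _data_to_tensors turns into "target"
    out[name] = [(v - center) / scale for v in raw]  # normalized values stay under the original name
    return out

data = {"leadtime": [2.0, 4.0, 6.0]}
processed = preprocess_target(data, "leadtime")
print(processed["__target__leadtime"])  # [2.0, 4.0, 6.0]
print(processed["leadtime"])            # zero-mean, unit-variance version of the same values
```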

What we need is the normalized target value, which is the data in the 'leadtime' column, i.e. data[f"{self.target}"]. We can copy this block and modify it so that it also returns the normalized target. The returned tensor dictionary would then contain 'normalized_target':

        # get target
        if isinstance(self.target_normalizer, NaNLabelEncoder):
            target = [
                check_for_nonfinite(
                    torch.tensor(data[f"__target__{self.target}"].to_numpy(dtype=np.int64), dtype=torch.long),
                    self.target,
                )
            ]
        else:
            if not isinstance(self.target, str):  # multi-target
                target = [
                    check_for_nonfinite(
                        torch.tensor(
                            data[f"__target__{name}"].to_numpy(
                                dtype=[np.float64, np.int64][data[name].dtype.kind in "bi"]
                            ),
                            dtype=[torch.float, torch.long][data[name].dtype.kind in "bi"],
                        ),
                        name,
                    )
                    for name in self.target_names
                ]
            else:
                target = [
                    check_for_nonfinite(
                        torch.tensor(data[f"__target__{self.target}"].to_numpy(dtype=np.float64), dtype=torch.float),
                        self.target,
                    )
                ]
        # get normalized target
        if isinstance(self.target_normalizer, NaNLabelEncoder):
            normalized_target = [
                check_for_nonfinite(
                    torch.tensor(data[f"{self.target}"].to_numpy(dtype=np.int64), dtype=torch.long),
                    self.target,
                )
            ]
        else:
            if not isinstance(self.target, str):  # multi-target
                normalized_target = [
                    check_for_nonfinite(
                        torch.tensor(
                            data[f"{name}"].to_numpy(
                                dtype=[np.float64, np.int64][data[name].dtype.kind in "bi"]
                            ),
                            dtype=[torch.float, torch.long][data[name].dtype.kind in "bi"],
                        ),
                        name,
                    )
                    for name in self.target_names
                ]
            else:
                normalized_target = [
                    check_for_nonfinite(
                        torch.tensor(data[f"{self.target}"].to_numpy(dtype=np.float64), dtype=torch.float),
                        self.target,
                    )
                ]
        # continuous covariates
        continuous = check_for_nonfinite(
            torch.tensor(data[self.reals].to_numpy(dtype=np.float64), dtype=torch.float), self.reals
        )

        tensors = dict(
            reals=continuous, categoricals=categorical, groups=index, target=target,
            normalized_target=normalized_target, weight=weight, time=time
        )

        return tensors

Then change the __getitem__ method in timeseries.py so that it also returns self.data["normalized_target"]:

        target = [d[index.index_start: index.index_end + 1].clone() for d in self.data["target"]]
        groups = self.data["groups"][index.index_start].clone()
        normalized_target = [d[index.index_start: index.index_end + 1].clone() for d in self.data["normalized_target"]]
        if self.multi_target:
            encoder_target = [t[:encoder_length] for t in target]
            target = [t[encoder_length:] for t in target]
            normalized_target = [t[encoder_length:] for t in normalized_target]
        else:
            encoder_target = target[0][:encoder_length]
            target = target[0][encoder_length:]
            normalized_target = normalized_target[0][encoder_length:]
            target_scale = target_scale[0]

        return (
            dict(
                x_cat=data_cat,
                x_cont=data_cont,
                encoder_length=encoder_length,
                decoder_length=decoder_length,
                encoder_target=encoder_target,
                encoder_time_idx_start=time[0],
                groups=groups,
                target_scale=target_scale,
            ),
            (target, weight, normalized_target),  # normalized_target sits at index 2, i.e. batch[1][2] in _collate_fn
        )
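The slicing in `__getitem__` splits each sample's window at `encoder_length`; because the raw and normalized targets are cut at the same index, the normalized decoder target lines up element by element with `target`. A minimal pure-Python sketch of that split, with made-up values and a hypothetical helper `split_window`:

```python
def split_window(window, encoder_length):
    """Hypothetical helper: split one sample's target window into encoder and decoder parts."""
    return window[:encoder_length], window[encoder_length:]

raw = [4.0, 8.0, 12.0, 16.0, 20.0]           # raw target window (illustrative values)
normalized = [-1.4, -0.7, 0.0, 0.7, 1.4]     # the same window after normalization (illustrative)

encoder_target, target = split_window(raw, 4)
_, normalized_target = split_window(normalized, 4)
print(encoder_target, target, normalized_target)
# [4.0, 8.0, 12.0, 16.0] [20.0] [1.4]
```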

Then, in the _collate_fn() method:

        # target and weight
        if isinstance(batches[0][1][0], (tuple, list)):
            target = [
                rnn.pad_sequence([batch[1][0][idx] for batch in batches], batch_first=True)
                for idx in range(len(batches[0][1][0]))
            ]
            normalized_target = [
                rnn.pad_sequence([batch[1][2][idx] for batch in batches], batch_first=True)
                for idx in range(len(batches[0][1][0]))
            ]
            encoder_target = [
                rnn.pad_sequence([batch[0]["encoder_target"][idx] for batch in batches], batch_first=True)
                for idx in range(len(batches[0][1][0]))
            ]
        else:
            target = rnn.pad_sequence([batch[1][0] for batch in batches], batch_first=True)
            normalized_target = rnn.pad_sequence([batch[1][2] for batch in batches], batch_first=True)
            encoder_target = rnn.pad_sequence([batch[0]["encoder_target"] for batch in batches], batch_first=True)

        if batches[0][1][1] is not None:
            weight = rnn.pad_sequence([batch[1][1] for batch in batches], batch_first=True)
        else:
            weight = None

        return (
            dict(
                encoder_cat=encoder_cat,
                encoder_cont=encoder_cont,
                encoder_target=encoder_target,
                encoder_lengths=encoder_lengths,
                decoder_cat=decoder_cat,
                decoder_cont=decoder_cont,
                decoder_target=target,
                decoder_normalized_target=normalized_target,  # we can get this by batch['decoder_normalized_target']
                decoder_lengths=decoder_lengths,
                decoder_time_idx=decoder_time_idx,
                groups=groups,
                target_scale=target_scale,
            ),
            (target, weight),  # alternatively, put normalized_target here
        )
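In `_collate_fn`, samples with different decoder lengths are combined with `rnn.pad_sequence(..., batch_first=True)`, which right-pads every sequence with zeros up to the longest one. The pure-Python stand-in below (a hypothetical re-implementation for illustration, not the torch function) shows what happens to the normalized targets: padded positions are filled with 0.0, so `decoder_lengths` is needed to tell real values from padding.

```python
def pad_sequence(sequences, padding_value=0.0):
    """Pure-Python stand-in for torch.nn.utils.rnn.pad_sequence(..., batch_first=True)."""
    max_len = max(len(s) for s in sequences)
    return [list(s) + [padding_value] * (max_len - len(s)) for s in sequences]

# normalized decoder targets of three samples with different decoder lengths (illustrative)
batch = [[0.5, -0.2], [1.1], [0.0, 0.3, -0.4]]
print(pad_sequence(batch))
# [[0.5, -0.2, 0.0], [1.1, 0.0, 0.0], [0.0, 0.3, -0.4]]
```

Note that a padded 0.0 is a plausible value in normalized space, so any loss on `decoder_normalized_target` should mask by `decoder_lengths`.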

After doing the above, we can use batch['decoder_normalized_target'] to get the normalized target. I think this should work, but if there is a more convenient way to get the normalized target directly, please let me know.

JustusMzB commented 11 months ago

After thorough debugging, I came to the conclusion that the target and all reals are in fact normalized before being fed into the various networks. However, before loss calculation, they are denormalized using the data["target_scale"] entry of the input dictionary. I found no way to disable this. I guess one way of obtaining the normalized target would be to use the data["target_scale"] values to re-normalize the target.

Example: Excerpt from TemporalFusionTransformer.forward:

        # ...
        return self.to_network_output(
            prediction=self.transform_output(output, target_scale=x["target_scale"]),
            encoder_attention=attn_output_weights[..., :max_encoder_length],
            decoder_attention=attn_output_weights[..., max_encoder_length:],
            static_variables=static_variable_selection,
            encoder_variables=encoder_sparse_weights,
            decoder_variables=decoder_sparse_weights,
            decoder_lengths=decoder_lengths,
            encoder_lengths=encoder_lengths,
        )

Note the use of self.transform_output. This is a method of `BaseModel`:

    def transform_output(
        self,
        prediction: Union[torch.Tensor, List[torch.Tensor]],
        target_scale: Union[torch.Tensor, List[torch.Tensor]],
        loss: Optional[Metric] = None,
    ) -> torch.Tensor:
        """
        Extract prediction from network output and rescale it to real space / de-normalize it.

        Args:
            prediction (Union[torch.Tensor, List[torch.Tensor]]): normalized prediction
            target_scale (Union[torch.Tensor, List[torch.Tensor]]): scale to rescale prediction
            loss (Optional[Metric]): metric to use for transform

        Returns:
            torch.Tensor: rescaled prediction
        """
# ...

Only after this is the loss calculated. That is also why y from the dataset contains the unnormalized target.
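Re-normalizing the target from target_scale, as suggested above, means inverting what transform_output does. The sketch below assumes (not verified against the library source) that GroupNormalizer(transformation="softplus") maps y to (softplus_inv(y) - center) / scale, and that transform_output reverses this as softplus(z * scale + center), with center and scale taken from target_scale. The function names are hypothetical.

```python
import math

def softplus(x):
    return math.log1p(math.exp(x))

def softplus_inv(y):
    # inverse of softplus for y > 0: log(exp(y) - 1), written as y + log(1 - exp(-y)) for stability
    return y + math.log(-math.expm1(-y))

def normalize(y, center, scale):
    # assumed forward path: transform first, then standardize
    return (softplus_inv(y) - center) / scale

def denormalize(z, center, scale):
    # assumed inverse path, analogous to transform_output
    return softplus(z * scale + center)

center, scale = 2.0, 0.5  # illustrative values standing in for one sample's target_scale
y = 7.5
z = normalize(y, center, scale)
print(round(denormalize(z, center, scale), 6))  # 7.5
```

Applying `normalize` to `decoder_target` would give targets on the same scale as the network's raw output, which is what a loss in normalized space needs.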

I don't think this makes much sense: normalization is also important for limiting the impact of outliers, especially when training on multiple time series. I have not found a solution to this yet.
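A small illustration of why the loss space matters when training on several series: with two hypothetical series that have the same 10% relative error but very different magnitudes, a raw-space MSE is dominated by the large series, while an MSE computed after dividing each series by its own scale weights them equally. The per-series mean is used here as a simple stand-in for whatever scale the normalizer actually computes.

```python
def mse(pred, true):
    return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true)

def scaled(xs, s):
    return [x / s for x in xs]

# two hypothetical series with the same relative error (10%) but very different magnitudes
small_true, small_pred = [1.0, 2.0], [1.1, 2.2]
large_true, large_pred = [1000.0, 2000.0], [1100.0, 2200.0]

raw_small = mse(small_pred, small_true)
raw_large = mse(large_pred, large_true)

# normalize each series by its own scale (here: the series mean) before computing the loss
norm_small = mse(scaled(small_pred, 1.5), scaled(small_true, 1.5))
norm_large = mse(scaled(large_pred, 1500.0), scaled(large_true, 1500.0))

print(raw_large / raw_small)               # roughly 1e6: the large series dominates
print(round(norm_large / norm_small, 6))   # 1.0: both series contribute equally
```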