sktime / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

fitting a single series causing torch to throw 'float' object has no attribute 'shape' #1220

Open wenshutang opened 1 year ago

wenshutang commented 1 year ago

Expected behavior

I'm evaluating DeepAR on a single series, following the example, and expect the model to train without errors.

Actual behavior

I get AttributeError: 'float' object has no attribute 'shape' when I try to fit a simple time series.

Code to reproduce the problem

My time series pandas DataFrame:

      value  timestamp  series
1700  801.0       1700       0
1701  650.0       1701       0
...

value        float64
timestamp      int64
series        object
dtype: object

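from pytorch_forecasting import TimeSeriesDataSet

# df_trips, data, train_cutoff, max_encoder_length, max_prediction_length
# and batch_size are defined earlier in the notebook (not shown here)
training = TimeSeriesDataSet(
    df_trips[lambda x: x.index <= train_cutoff],
    time_idx="timestamp",
    target="value",
    # categorical_encoders={"series": NaNLabelEncoder(add_nan=True).fit(df_trips.series)},
    group_ids=['series'],
    # static_categoricals=[
    #     "series"
    # ],
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    time_varying_unknown_reals=["value"],
    allow_missing_timesteps=True
)

train_dataloader = training.to_dataloader(
    train=True, batch_size=batch_size, num_workers=0, batch_sampler="synchronized"
)

validation = TimeSeriesDataSet.from_dataset(
    training,
    data[1700:],
    predict=True,
    min_encoder_length=max_encoder_length // 2
)

# Validation loader passed to trainer.fit() below
val_dataloader = validation.to_dataloader(
    train=False, batch_size=batch_size, num_workers=0, batch_sampler="synchronized"
)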

Following the example, here is how the training step is defined:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_forecasting import DeepAR
from pytorch_forecasting.metrics import MultivariateNormalDistributionLoss

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
trainer = pl.Trainer(
    max_epochs=30,
    gpus=0,  # CPU only; deprecated since Lightning 1.7, see the warning in the log below
    enable_model_summary=True,
    gradient_clip_val=0.1,
    callbacks=[early_stop_callback],
    limit_train_batches=10,
    enable_checkpointing=True,
)

net = DeepAR.from_dataset(
    training,
    learning_rate=0.1,
    log_interval=10,
    log_val_interval=1,
    hidden_size=30,
    rnn_layers=2,
    loss=MultivariateNormalDistributionLoss(rank=30),
)

trainer.fit(
    net,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader
)

I get this error, omitting some of the trace frames:

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:441: LightningDeprecationWarning: Setting `Trainer(gpus=0)` is deprecated in v1.7 and will be removed in v2.0. Please use `Trainer(accelerator='gpu', devices=0)` instead.
  rank_zero_deprecation(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
  rank_zero_warn(
/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
  rank_zero_warn(
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                   | Type                               | Params
------------------------------------------------------------------------------
0 | loss                   | MultivariateNormalDistributionLoss | 0     
1 | logging_metrics        | ModuleList                         | 0     
2 | embeddings             | MultiEmbedding                     | 1     
3 | rnn                    | LSTM                               | 11.6 K
4 | distribution_projector | Linear                             | 992   
------------------------------------------------------------------------------
12.6 K    Trainable params
0         Non-trainable params
12.6 K    Total params
0.051     Total estimated model params size (MB)
Sanity Checking DataLoader 0: 0% 0/1 [10:39<?, ?it/s]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-44-9c3d1dad2a4c> in <module>
     20 )
     21 
---> 22 trainer.fit(
     23     net,
     24     train_dataloaders=train_dataloader,

/usr/local/lib/python3.8/dist-packages/torch/distributions/transformed_distribution.py in log_prob(self, value)
    144             x = transform.inv(y)
    145             event_dim += transform.domain.event_dim - transform.codomain.event_dim
--> 146             log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
    147                                                  event_dim - transform.domain.event_dim)
    148             y = x

/usr/local/lib/python3.8/dist-packages/torch/distributions/utils.py in _sum_rightmost(value, dim)
     59     if dim == 0:
     60         return value
---> 61     required_shape = value.shape[:-dim] + (-1,)
     62     return value.reshape(required_shape).sum(-1)

gunewar commented 1 year ago

I have the same issue, and there has been no response yet.

wenshutang commented 1 year ago

@gunewar I believe explicitly specifying the target normalizer when creating the TimeSeriesDataSet fixed it for me:

from pytorch_forecasting.data.encoders import TorchNormalizer

training = TimeSeriesDataSet(
    ...
    target_normalizer=TorchNormalizer(method='identity', center=True, transformation=None, method_kwargs={}),
    add_target_scales=True,
    ...
)
haison19952013 commented 1 year ago

This fixed the error for me too, though I still don't understand why.

haison19952013 commented 1 year ago

I guess the error happens when we only have one series...

flixpar commented 6 months ago

I think I found the root of the issue: there's a bug in the log_abs_det_jacobian function (which is used by MultivariateDistributionLoss) for the ReLU target normalizer. I changed line 127 (and line 108) of encoders.py, linked below, from "return 0.0" to "return torch.tensor(0.0).to(x)" so that it returns a tensor instead of a float, and that seemed to solve the problem.

https://github.com/jdb78/pytorch-forecasting/blob/4045054377f71bcf606852c46520c5d7fdf4d0d2/pytorch_forecasting/data/encoders.py#L111-L127
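The failure described above can be reproduced with plain torch.distributions, without pytorch-forecasting. This is a minimal sketch under assumptions: ReLuLikeTransform and try_log_prob are hypothetical names, and the transform only mirrors the pattern described in the comment (forward ReLU, identity inverse, and a log_abs_det_jacobian that returns a plain 0.0); it is not the class from encoders.py. As the traceback shows, TransformedDistribution.log_prob passes the Jacobian term to _sum_rightmost, which returns it untouched when dim == 0 (a univariate base) but calls value.shape when dim > 0 (a multivariate base), raising the AttributeError.

```python
import torch
from torch.distributions import MultivariateNormal, Normal, TransformedDistribution, constraints
from torch.distributions.transforms import Transform


class ReLuLikeTransform(Transform):
    """Hypothetical transform mirroring the pattern described above:
    y = relu(x), identity on the inverse, and log|det J| = 0."""

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = 1

    def __init__(self, return_tensor):
        super().__init__()
        self.return_tensor = return_tensor

    def _call(self, x):
        return torch.relu(x)

    def _inverse(self, y):
        return y

    def log_abs_det_jacobian(self, x, y):
        if self.return_tensor:
            # Suggested fix: a 0-dim tensor with x's dtype and device
            return torch.tensor(0.0).to(x)
        # Failing pattern: a plain Python float
        return 0.0


def try_log_prob(base, transform, value):
    """Return log_prob as a tensor, or the AttributeError message as a str."""
    dist = TransformedDistribution(base, [transform])
    try:
        return dist.log_prob(value)
    except AttributeError as err:
        return str(err)


# Positive values, so the identity inverse of relu is exact
value = torch.tensor([0.5, 1.0, 2.0])

# Univariate base (event dim 0): _sum_rightmost returns the float unchanged
uni = Normal(torch.zeros(3), torch.ones(3))
print(try_log_prob(uni, ReLuLikeTransform(return_tensor=False), value))

# Multivariate base (event dim 1): _sum_rightmost calls .shape on the float
mvn = MultivariateNormal(torch.zeros(3), torch.eye(3))
print(try_log_prob(mvn, ReLuLikeTransform(return_tensor=False), value))
# prints: 'float' object has no attribute 'shape'

# Returning a tensor lets the multivariate case go through
print(try_log_prob(mvn, ReLuLikeTransform(return_tensor=True), value))
```

Using .to(x) rather than a bare torch.tensor(0.0) keeps the zero on the same device and dtype as the input, so the same fix also holds for GPU or float64 tensors.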