jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Bug in SMAPE when executing tutorial #1539

Open mayii2001 opened 4 months ago

mayii2001 commented 4 months ago

Expected behavior

I ran the code from the tutorial "Demand forecasting with the Temporal Fusion Transformer" and expected to get the SMAPE results. The error appears even though I didn't change any of the tutorial code. I set a breakpoint, but execution never stopped at it before the error was raised. Solving this problem is beyond me.

Code to reproduce the problem

# calculate metric by which to display
predictions = best_tft.predict(val_dataloader, return_y=True)
mean_losses = SMAPE(reduction="none")
mean_losses = mean_losses(predictions.output, predictions.y)  # raises the RuntimeError below
mean_losses = mean_losses.mean(1)
indices = mean_losses.argsort(descending=True)  # sort losses
for idx in range(10):  # plot 10 examples
    # raw_predictions comes from an earlier cell of the tutorial
    best_tft.plot_prediction(
        raw_predictions.x,
        raw_predictions.output,
        idx=indices[idx],
        add_loss_to_title=SMAPE(quantiles=best_tft.loss.quantiles),
    )
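For reference, the cell is meant to compute one SMAPE value per series and time step, average it over the prediction horizon with `.mean(1)`, and sort the series from worst to best. The sketch below does this on made-up tensors with plain PyTorch, assuming the element-wise definition 2 * |y_pred - target| / (|y_pred| + |target|) plus a small epsilon; `smape_none` is a hypothetical helper, not the pytorch-forecasting implementation.

```python
import torch


def smape_none(y_pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Element-wise SMAPE with no reduction (assumed definition, for illustration only)."""
    return 2 * (y_pred - target).abs() / (y_pred.abs() + target.abs() + eps)


# Hypothetical point forecasts and targets: 3 series x 4 time steps.
y_pred = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                       [2.0, 2.0, 2.0, 2.0],
                       [1.0, 1.0, 1.0, 1.0]])
target = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                       [1.0, 1.0, 1.0, 1.0],
                       [3.0, 3.0, 3.0, 3.0]])

losses = smape_none(y_pred, target)             # shape (3, 4): one value per series and step
mean_losses = losses.mean(1)                    # shape (3,): average over the horizon
indices = mean_losses.argsort(descending=True)  # worst-forecast series first
print(mean_losses, indices)
```

Here the third series has the largest error, so it comes first in `indices`, which is the order the tutorial's plotting loop relies on.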

Cell In[66], line 4
      2 predictions = best_tft.predict(val_dataloader,return_y=True)
      3 mean_losses = SMAPE(reduction="none")
----> 4 mean_losses =mean_losses(predictions.output, predictions.y)
      5 mean_losses =mean_losses.mean(1)
      6 indices = mean_losses.argsort(descending=True)  # sort losses

File ~.conda\envs\cgm\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~.conda\envs\cgm\lib\site-packages\torchmetrics\metric.py:303, in Metric.forward(self, *args, **kwargs)
    301     self._forward_cache = self._forward_full_state_update(*args, **kwargs)
    302 else:
--> 303     self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
    305 return self._forward_cache

File ~.conda\envs\cgm\lib\site-packages\torchmetrics\metric.py:378, in Metric._forward_reduce_state_update(self, *args, **kwargs)
    376 self._update_count = _update_count + 1
    377 with torch.no_grad():
--> 378     self._reduce_states(global_state)
    380 # restore context
    381 self._is_synced = False

File ~.conda\envs\cgm\lib\site-packages\torchmetrics\metric.py:413, in Metric._reduce_states(self, incoming_state)
    411 elif reduce_fn == dim_zero_cat:
    412     if isinstance(global_state, Tensor):
--> 413         reduced = torch.cat([global_state, local_state])
    414     else:
    415         reduced = global_state + local_state

RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated
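The last frame shows torchmetrics merging a stored state with a new one via `torch.cat`. A minimal sketch, assuming the stored state is a 0-d (scalar) tensor, reproduces the same RuntimeError with plain PyTorch. The variable names only mirror the traceback; this is not torchmetrics or pytorch-forecasting code, and whether the SMAPE state really starts as a scalar is an assumption not confirmed in this thread.

```python
import torch

global_state = torch.tensor(0.0)          # 0-d tensor, like a scalar default state
local_state = torch.tensor([0.5, 0.25])   # 1-d tensor produced by one update

message = ""
try:
    torch.cat([global_state, local_state])
except RuntimeError as err:
    message = str(err)
    print("RuntimeError:", message)

# Giving the scalar a dimension first makes the concatenation valid.
reduced = torch.cat([global_state.reshape(1), local_state])
print(reduced.shape)  # torch.Size([3])
```

This is consistent with the workarounds below: calling the metric's `loss` method directly avoids the stateful `forward` path shown in the traceback.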
MrGG14 commented 1 month ago

Same problem here. Did you solve it? Thanks!

TanJingV commented 1 month ago

Same problem here. There seems to be a problem with how the metrics are written in the library, which is not maintained by the authors.

Jiyeong303 commented 6 days ago

I got the same error and solved it by calling the metric's loss method directly:

mean_losses = SMAPE(reduction="none")
mean_losses = mean_losses.loss(predictions.output, predictions.y)

ci21041 commented 4 days ago

Same error here. I'm not sure why it works, but here is my code:

mean_losses = SMAPE(reduction="none")
mean_losses = mean_losses.loss(predictions.output, predictions.y[0]).mean(1)

It seems that predictions.y is a tuple, and predictions.y[0] has the same shape as predictions.output.
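To illustrate why indexing with `[0]` matters, here is a minimal sketch using a hypothetical namedtuple stand-in for the object returned by `predict(..., return_y=True)`. It assumes, following the comment above, that `predictions.y` holds the target tensor first; the second element being a weight that may be `None` is an additional assumption for illustration.

```python
from collections import namedtuple

import torch

# Hypothetical stand-in for the prediction object; assumes y is a
# (target, weight) tuple where weight may be None.
Prediction = namedtuple("Prediction", ["output", "y"])

n_samples, horizon = 5, 6
predictions = Prediction(
    output=torch.rand(n_samples, horizon),
    y=(torch.rand(n_samples, horizon), None),
)

target, weight = predictions.y                   # unpack instead of passing the whole tuple
print(type(predictions.y).__name__)              # tuple
print(target.shape == predictions.output.shape)  # True
```

Under these assumptions, only `target` (that is, `predictions.y[0]`) lines up element-wise with `predictions.output`, which is what an element-wise loss followed by `.mean(1)` needs.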