eivistr opened this issue 3 years ago
Although it is a pragmatic and obviously bad solution, changing line 283 in models.deepar.__init__ from
x[:, 0, target_pos] = lagged_targets[-1]
to
x[:, 0, target_pos] = lagged_targets[-1].to(torch.float32)
fixes the problem temporarily. It would still be worth finding the root cause, i.e. why the tensor has dtype Double in the first place.
Same issue here on an unrelated dataset. It trains, but after validation (in predict_mode) it errors out as above.
Since the first time, I have hit the error on two other datasets as well, so it appears not to be isolated to the Traffic set.
I think I found (and fixed!) the root cause of this issue: adding the small eps value that avoids division by zero when the standard deviation is zero ended up tricking NumPy into promoting the 32-bit float to a 64-bit float. That in turn causes the TorchTransformer to use 64-bit floats in its calculations, so rescaling a 32-bit input produces a 64-bit output and raises the error above.
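To make the promotion described above concrete, here is a minimal NumPy sketch. It does not reproduce the actual normalizer code in pytorch-forecasting; the array values, the `1e-8` eps and all variable names are hypothetical. It only shows that combining float32 statistics with a float64 eps yields float64, that casting the eps to the data's dtype keeps everything in float32, and that rescaling a float32 input with the promoted scale gives a float64 result. Whether a plain Python float eps triggers the promotion depends on the NumPy version (NumPy 2 treats Python scalars as "weak"), so the sketch uses an explicit float64 array instead.

```python
import numpy as np

# Hypothetical per-series target values stored as float32, the dtype
# PyTorch models use by default. The second series is constant, so its
# standard deviation is exactly 0.
values = np.array([[0.1, 0.2, 0.3],
                   [0.5, 0.5, 0.5]], dtype=np.float32)

center = values.mean(axis=1, keepdims=True)  # float32, shape (2, 1)
std = values.std(axis=1)                     # float32, shape (2,); second entry is 0.0

# Hypothetical eps created without an explicit dtype, so it is float64.
eps_f64 = np.array([1e-8])

# Mixing float32 and float64 arrays promotes the result to float64.
scale_promoted = std + eps_f64
print(scale_promoted.dtype)

# Casting eps to the data's dtype first keeps the scale in float32.
scale_fixed = std + eps_f64.astype(std.dtype)
print(scale_fixed.dtype)

# Rescaling a float32 input with the promoted scale yields a float64
# output: the kind of Float/Double mismatch reported in this issue.
normalized = (values - center) / scale_promoted[:, None]
print(normalized.dtype)
```

The same idea applies on the PyTorch side: casting to the destination's dtype (rather than hard-coding one) avoids the mismatch, but fixing the dtype where the eps is introduced removes the need for downstream casts.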
Here's the fix: https://github.com/jdb78/pytorch-forecasting/pull/795/commits/a010ef9788c6dce4d4cc829ce09bac1d34230639
I have the exact same issue on a different dataset: training and validation with DeepAR work fine, but predicting yields the type-mismatch error mentioned above. Are there plans to integrate the fix above (or something similar) into the master branch?
Expected behavior
I am trying to run the Electricity and Traffic experiments in roughly the same way as they are run in the DeepAR paper. Training the model and obtaining forecasts work fine on the Electricity set; on Traffic, however, I get a type error about a mismatch between Float and Double values when getting predictions on the test set. One note on the Traffic dataset: its values are floats between 0.0 and 1.0, so perhaps something in the scaling step converts them to Double?
I have tested the Traffic set with the Temporal Fusion Transformer using exactly the same configuration and an otherwise identical setup, and did not get this error; it only occurs when using DeepAR on Traffic.
The problem persists on both CPU and GPU.
Code to reproduce the problem
TimeSeriesDataSet:
DeepAR model definition:
Traceback: