Open erl61 opened 1 year ago
Hi @erl61, have you figured out any solution?
@0xleowang No, I use accelerator='cpu'
having the same issue
Did you fix that?
Hi, I noticed that timeseries.py is using torch.int64 for some tensors. I was able to avoid this issue by using torch.int32. Hope this helps somehow.
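For anyone wanting to try the dtype workaround above without editing the library, here is a minimal sketch of the idea. The helper name index_add_without_int64 is hypothetical and is not part of pytorch-forecasting or PyTorch; it only assumes that the failing call is an index_add on int64 tensors, as the error message suggests. It casts to torch.int32 for the operation and casts back afterwards, so it is only safe when the values fit in 32 bits.

```python
import torch


def index_add_without_int64(target, dim, index, source):
    """Hypothetical helper: run index_add in int32 when the inputs are int64.

    Assumes values fit in int32; larger values would overflow.
    """
    orig_dtype = target.dtype
    if orig_dtype == torch.int64:
        # Cast both operands so their dtypes still match, then restore the dtype.
        out = target.to(torch.int32).index_add(dim, index, source.to(torch.int32))
        return out.to(orig_dtype)
    return target.index_add(dim, index, source)


# Small CPU example: count how often each position appears in `index`.
counts = torch.zeros(3, dtype=torch.int64)
index = torch.tensor([0, 2, 2])
ones = torch.ones(3, dtype=torch.int64)
result = index_add_without_int64(counts, 0, index, ones)
```

The example runs on CPU; whether the same cast is enough on 'mps' depends on the PyTorch version, so treat it as a starting point rather than a confirmed fix.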
Expected behavior
I am trying to run the "Demand forecasting with the Temporal Fusion Transformer" tutorial using accelerator='mps'.
Actual behavior
However, I get "RuntimeError: index_add(): Expected non int64 dtype for source." when searching for the optimal learning rate. The same error occurs when I try to train the TFT model. To use 'mps', I set "%env PYTORCH_ENABLE_MPS_FALLBACK=1" as the first line of the code.
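Outside a notebook, the setup described here could be sketched roughly as follows. The variable generally needs to be set before torch is imported, which is why it goes at the very top. The choose_accelerator helper is hypothetical and is shown only to illustrate falling back to 'cpu', the workaround mentioned in the comments.

```python
import os

# Set the fallback flag before any `import torch` so PyTorch sees it at load time.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


def choose_accelerator(mps_available: bool, prefer_mps: bool = True) -> str:
    """Hypothetical helper: return 'mps' when usable and wanted, else 'cpu'."""
    return "mps" if (mps_available and prefer_mps) else "cpu"


# In real code, mps_available would come from torch.backends.mps.is_available().
accelerator = choose_accelerator(mps_available=False)
```

The resulting string can then be passed as the accelerator argument of the trainer. Passing prefer_mps=False forces 'cpu' even on a machine where MPS is available.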
Code to reproduce the problem
Traceback: