nok-halfspace / Transformer-Time-Series-Forecasting

Getting an error when trying to run #12

aarshcode opened this issue 1 year ago (status: Open)

aarshcode commented 1 year ago
RuntimeError: Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: float and  query.dtype: double instead.

Here's the complete call stack:

Traceback (most recent call last):
  File "/Users/aarsh/Code/MajorProject/Transformer-Time-Series-Forecasting/main.py", line 48, in <module>
    main(
  File "/Users/aarsh/Code/MajorProject/Transformer-Time-Series-Forecasting/main.py", line 33, in main
    best_model = transformer(train_dataloader, epoch, k, frequency, path_to_save_model, path_to_save_loss, path_to_save_predictions, device)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aarsh/Code/MajorProject/Transformer-Time-Series-Forecasting/train_with_sampling.py", line 50, in transformer
    prediction = model(sampled_src, device) # torch.Size([1xw, 1, 1])
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aarsh/Code/JupyterLab/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aarsh/Code/MajorProject/Transformer-Time-Series-Forecasting/model.py", line 33, in forward
    output = self.transformer_encoder(src,mask)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aarsh/Code/JupyterLab/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aarsh/Code/JupyterLab/venv/lib/python3.11/site-packages/torch/nn/modules/transformer.py", line 306, in forward
    output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aarsh/Code/JupyterLab/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aarsh/Code/JupyterLab/venv/lib/python3.11/site-packages/torch/nn/modules/transformer.py", line 573, in forward
    x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aarsh/Code/JupyterLab/venv/lib/python3.11/site-packages/torch/nn/modules/transformer.py", line 581, in _sa_block
    x = self.self_attn(x, x, x,
        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aarsh/Code/JupyterLab/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aarsh/Code/JupyterLab/venv/lib/python3.11/site-packages/torch/nn/modules/activation.py", line 1189, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/aarsh/Code/JupyterLab/venv/lib/python3.11/site-packages/torch/nn/functional.py", line 5334, in multi_head_attention_forward
    attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: float and  query.dtype: double instead.
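The error is raised inside `scaled_dot_product_attention`. The query tensor is float64 (`double`), presumably because the model and its inputs were converted to double precision. The causal mask passed in `self.transformer_encoder(src,mask)` in `model.py` is float32, PyTorch's default. As the message says, a float mask must match the query's dtype, while a boolean mask is accepted for any query dtype. The sketch below shows both fixes on a stand-alone `nn.TransformerEncoder`. It assumes the default `(seq, batch, feature)` layout with illustrative sizes. `causal_float_mask` and `causal_bool_mask` are hypothetical helper names, not functions from this repository.

```python
import torch
import torch.nn as nn


def causal_float_mask(sz, dtype, device=None):
    # Additive causal mask: 0.0 on and below the diagonal, -inf above it.
    # Building it directly in the query's dtype (float64 here) avoids the mismatch.
    return torch.triu(
        torch.full((sz, sz), float("-inf"), dtype=dtype, device=device), diagonal=1
    )


def causal_bool_mask(sz, device=None):
    # Boolean causal mask: True marks positions a query may NOT attend to.
    # PyTorch converts a bool mask to the query's dtype itself.
    return torch.triu(torch.ones(sz, sz, dtype=torch.bool, device=device), diagonal=1)


# Illustrative sizes (assumption): nhead must divide d_model; batch_first=False.
seq_len, batch, d_model = 10, 1, 7
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=7, dropout=0.0)
encoder = nn.TransformerEncoder(layer, num_layers=3).double()  # float64 weights

torch.manual_seed(0)
src = torch.randn(seq_len, batch, d_model, dtype=torch.float64)

# Fix A: create the float mask in src.dtype instead of the default float32.
out_a = encoder(src, mask=causal_float_mask(seq_len, dtype=src.dtype))
# Fix B: pass a boolean mask, which is valid for any query dtype.
out_b = encoder(src, mask=causal_bool_mask(seq_len))
print(out_a.dtype, tuple(out_a.shape))  # torch.float64 (10, 1, 7)
```

In `model.py`, Fix A would amount to casting the mask where it is created, e.g. `mask.to(device=device, dtype=src.dtype)`. The mask-building line itself is not shown in the traceback, so its exact form is an assumption.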
hxy-tsy commented 1 year ago

RuntimeError: Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: float and query.dtype: double instead.

hxy-tsy commented 1 year ago

Have you solved this problem?
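Instead of matching the mask to float64, you can keep the whole pipeline in float32, PyTorch's default. This sketch assumes, hypothetically, that the training batches come from float64 NumPy/pandas arrays and that the model was converted with `.double()`. The random data and sizes stand in for the repository's dataloader.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical data: NumPy arrays default to float64, and torch.from_numpy keeps that dtype.
values = np.random.default_rng(0).standard_normal((10, 1, 7))
src64 = torch.from_numpy(values)
print(src64.dtype)  # torch.float64

# Cast each batch to float32 instead of converting the model with .double().
src = src64.float()

layer = nn.TransformerEncoderLayer(d_model=7, nhead=7, dropout=0.0)
encoder = nn.TransformerEncoder(layer, num_layers=3)  # default float32 weights

# An additive causal mask built with default settings is float32, matching the query.
seq_len = src.size(0)
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
out = encoder(src, mask=mask)
```

In training code, this would mean dropping `.double()` from the model and calling `.float()` on each batch before the forward pass. float32 also halves memory use compared with float64, and float64 arithmetic is much slower on most GPUs.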