gsyyysg / StockFormer

PyTorch implementation of the paper "StockFormer: Learning Hybrid Trading Machines with Predictive Coding".

enc_in is wrong in the test script train_mae.sh #14

Closed: 15101051 closed this 2 months ago

15101051 commented 2 months ago

The resulting error is as follows:

Traceback (most recent call last):
  File "/home/wenjh/StockFormer/Transformer/main.py", line 108, in <module>
    exp.train(setting)
  File "/home/wenjh/StockFormer/Transformer/exp/exp_mae.py", line 164, in train
    _,_, output = self.model(enc_inp, enc_inp)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wenjh/StockFormer/Transformer/models/transformer.py", line 66, in forward
    enc_out = self.enc_embedding(x_enc)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wenjh/StockFormer/Transformer/models/embed.py", line 54, in forward
    a = self.value_embedding(x)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wenjh/StockFormer/Transformer/models/embed.py", line 40, in forward
    x = self.tokenConv(x.permute(0, 2, 1)).transpose(1,2)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 302, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/wenjh/miniconda3/envs/AP_core_code/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 295, in _conv_forward
    return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
RuntimeError: Given groups=1, weight of size [128, 96, 3], expected input[32, 10, 4] to have 96 channels, but got 10 channels instead

Should enc_in in train_mae.sh be changed from 96 to 10?
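The error says the embedding's Conv1d was built with 96 input channels (its weight is [128, 96, 3], i.e. [out_channels, in_channels, kernel_size]) but received a batch with only 10 channels. Assuming, as the traceback suggests, that x_enc arrives as (batch, seq_len, features) and is permuted to (batch, features, seq_len) before the convolution, enc_in has to equal the feature dimension of the data. The helper below (`check_enc_in` is a hypothetical name, not part of StockFormer) is a minimal pre-flight sketch of that rule. It only detects the mismatch and cannot tell whether the script or the data preparation is the side that needs fixing.

```python
from typing import Optional, Sequence


def check_enc_in(enc_in: int, x_enc_shape: Sequence[int]) -> Optional[str]:
    """Return an error message if enc_in cannot accept x_enc, else None.

    Hypothetical helper. Assumes x_enc is (batch, seq_len, n_features) and
    that the token embedding permutes it to (batch, n_features, seq_len)
    before a Conv1d whose in_channels equals enc_in.
    """
    if len(x_enc_shape) != 3:
        return f"x_enc should be 3-D (batch, seq_len, features), got {tuple(x_enc_shape)}"
    batch, seq_len, n_features = x_enc_shape
    # Conv1d compares dim 1 of its input (n_features after the permute)
    # against its in_channels, which is set from enc_in.
    if n_features != enc_in:
        return (f"enc_in={enc_in} but x_enc has {n_features} features "
                f"(conv input would be [{batch}, {n_features}, {seq_len}])")
    return None


# Shapes taken from the traceback: conv input [32, 10, 4] means x_enc was [32, 4, 10].
print(check_enc_in(96, (32, 4, 10)))
# enc_in=96 but x_enc has 10 features (conv input would be [32, 10, 4])
print(check_enc_in(10, (32, 4, 10)))
# None
```

Running a check like this on one batch before `exp.train(setting)` would surface the mismatch with both numbers side by side, rather than deep inside `F.conv1d`.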

15101051 commented 2 months ago

To add: every setting in that script whose value is 96 probably needs to be changed as well.

15101051 commented 2 months ago

Looks like I was the one who got it wrong. Embarrassing.