gsyyysg / StockFormer

PyTorch implementation of the paper "StockFormer: Learning Hybrid Trading Machines with Predictive Coding".

Error in train_model when running train_rl.py #5

Open fffanrrr opened 6 months ago

fffanrrr commented 6 months ago

Traceback (most recent call last):
  File "D:\downloads\StockFormer-main\StockFormer-main\code\train_rl.py", line 210, in <module>
    trained_sac = agent.train_model(model=model_sac,
  File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\models\DRLAgent.py", line 151, in train_model
    model = model.learn(
  File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\MAE_SAC.py", line 364, in learn
    return super(SAC, self).learn(
  File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\off_policy_algorithm.py", line 352, in learn
    rollout = self.collect_rollouts(
  File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\off_policy_algorithm.py", line 584, in collect_rollouts
    if callback.on_step() is False:
  File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 88, in on_step
    return self._on_step()
  File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 192, in _on_step
    continue_training = callback.on_step() and continue_training
  File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 88, in on_step
    return self._on_step()
  File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 379, in _on_step
    episode_rewards, episode_lengths = evaluate_policy(
  File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\evaluation.py", line 86, in evaluate_policy
    observations, rewards, dones, infos = env.step(actions)
  File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\base_vec_env.py", line 163, in step
    return self.step_wait()
  File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\vec_monitor.py", line 76, in step_wait
    obs, rewards, dones, infos = self.venv.step_wait()
  File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\dummy_vec_env.py", line 43, in step_wait
    obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
  File "D:\downloads\StockFormer-main\StockFormer-main\code\envs\env_stocktrading_hybrid_control.py", line 279, in step
    plt.savefig(
  File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\pyplot.py", line 1119, in savefig
    res = fig.savefig(*args, **kwargs)  # type: ignore[func-returns-value]
  File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\figure.py", line 3390, in savefig
    self.canvas.print_figure(fname, **kwargs)
  File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\backend_bases.py", line 2193, in print_figure
    result = print_method(
  File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\backend_bases.py", line 2043, in <lambda>
    print_method = functools.wraps(meth)(lambda *args, **kwargs: meth(
TypeError: print_png() got an unexpected keyword argument 'index'

fffanrrr commented 6 months ago

Which version of matplotlib are you using?
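Because this failure depends on the installed library versions, it helps to report them exactly. Below is a minimal sketch using the standard-library `importlib.metadata` (Python 3.8+). The helper name `package_versions` is made up for illustration. `stable_baselines3` is left out because the tracebacks show it being loaded from the repo's own `code/stable_baselines3` directory rather than from site-packages.

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names):
    """Return {distribution name: version string, or None if not installed}."""
    report = {}
    for name in names:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in the current environment.
            report[name] = None
    return report


if __name__ == "__main__":
    # Libraries that appear in the tracebacks in this thread.
    for pkg, ver in package_versions(["matplotlib", "torch", "pandas"]).items():
        print(f"{pkg}: {ver or 'not installed'}")
```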

hugo2046 commented 3 months ago

My fix was to comment out the index=False arguments on lines 279 and 286 of code\envs\env_stocktrading_hybrid_control.py. After that, it runs without problems.
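The traceback shows `plt.savefig(...)` forwarding its keyword arguments through `Figure.savefig` and `print_figure` down to `print_png`, which rejects `index`. `index=False` is not a `savefig` option; it is the familiar `pandas.DataFrame.to_csv` keyword. Dropping it should therefore not change the saved image. The function below is a hypothetical stand-in for the plotting code around line 279. The thread does not show that code's real arguments, so only the removed keyword mirrors the original.

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")  # headless backend, so no display is needed
import matplotlib.pyplot as plt


def save_plot(values, path):
    """Hypothetical stand-in for the env's plotting code (not the repo's code)."""
    plt.figure()
    plt.plot(values, color="r")
    # Fixed call: index=False removed, since savefig passes unknown keywords
    # on to print_png, which raises TypeError for 'index'.
    plt.savefig(path)
    plt.close()


if __name__ == "__main__":
    out = os.path.join(tempfile.gettempdir(), "account_value_demo.png")
    save_plot([1.0, 1.02, 0.99, 1.05], out)
    print("saved", out)
```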

elven2016 commented 3 months ago

> My fix was to comment out the index=False arguments on lines 279 and 286 of code\envs\env_stocktrading_hybrid_control.py. After that, it runs without problems.

I solved it the same way.

trialbox commented 2 months ago

> My fix was to comment out the index=False arguments on lines 279 and 286 of code\envs\env_stocktrading_hybrid_control.py. After that, it runs without problems.

Did you get the whole pipeline running? I'm stuck in mae_sac.py: temporal_feature_short and temporal_feature_long have the wrong dimensions, and then self.query_projection(queries) raises an error.
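The failing line in the traceback later in this thread is `queries = self.query_projection(queries).view(B, L, H, -1)` in Transformer/models/attn.py. That is the pattern used in Informer-style `AttentionLayer` code. Assuming StockFormer's layer follows it (not confirmed here), queries must arrive shaped (B, L, d_model), and d_model must equal the projection's input size. The class below is a hypothetical minimal sketch of that contract with an explicit shape check. `d_model=128` and `n_heads=8` are assumed values; 128 matches the 128x128 weight in the later error.

```python
import torch
import torch.nn as nn


class QueryProjectionSketch(nn.Module):
    """Hypothetical Informer-style query projection (not StockFormer's code).

    Input:  queries of shape (B, L, d_model)
    Output: tensor of shape (B, L, n_heads, d_model // n_heads)
    """

    def __init__(self, d_model=128, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.query_projection = nn.Linear(d_model, d_model)

    def forward(self, queries):
        expected = self.query_projection.in_features
        # Fail early with a readable message instead of the matmul error.
        if queries.dim() != 3 or queries.shape[-1] != expected:
            raise ValueError(
                f"expected queries of shape (B, L, {expected}), got {tuple(queries.shape)}"
            )
        B, L, _ = queries.shape
        H = self.n_heads
        # Project, then split the last dimension across H heads.
        return self.query_projection(queries).view(B, L, H, -1)
```

With a (2, 5, 128) input this returns shape (2, 5, 8, 16). A (2, 37) input fails the check with a message naming both shapes.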

fffanrrr commented 2 months ago

I got it running. Could you show me a screenshot of your error?

---- Original message ---- | From | @.> | Date | 23 May 2024, 17:00 | To | @.> | Cc | @.>@.> | Subject | Re: [gsyyysg/StockFormer] Error in train_model when running train_rl.py (Issue #5) |


trialbox commented 2 months ago


Logging to tensorboard_log/mysac/StockFormer/_1
Traceback (most recent call last):
  File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/train_rl.py", line 213, in <module>
    trained_sac = agent.train_model(model=model_sac,
  File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/models/DRLAgent.py", line 153, in train_model
    model = model.learn(
  File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/MAE_SAC.py", line 379, in learn
    return super(SAC, self).learn(
  File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/off_policy_algorithm.py", line 354, in learn
    rollout = self.collect_rollouts(
  File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/off_policy_algorithm.py", line 572, in collect_rollouts
    action, buffer_action = self._sample_action(learning_starts, action_noise)
  File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/off_policy_algorithm.py", line 412, in _sample_action
    unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
  File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/MAE_SAC.py", line 405, in predict
    state_tensor = self.actor_transformer(obs_tensor, temporal_short, temporal_long, holding)
  File "/home/liruidev/App/anaconda3/envs/StockFormer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/policy_transformer.py", line 35, in forward
    temporal_hybrid_feature, attn = self.attention(
  File "/home/liruidev/App/anaconda3/envs/StockFormer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/Transformer/models/attn.py", line 64, in forward
    queries = self.query_projection(queries).view(B, L, H, -1)
  File "/home/liruidev/App/anaconda3/envs/StockFormer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liruidev/App/anaconda3/envs/StockFormer/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x37 and 128x128)
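Reading the final error: `F.linear` multiplies the input (mat1, here 2x37) by the transposed weight of `query_projection` (mat2, 128x128). The input's last dimension must therefore equal the layer's `in_features`. The tensor reaching the projection has a last dimension of 37 where 128 is expected; the thread does not establish where the 37 comes from. Below is a minimal reproduction that uses only the shapes printed in the error. The 128-wide `nn.Linear` is an assumed stand-in for the real `query_projection`.

```python
import torch
import torch.nn as nn

# Stand-in for query_projection: a Linear layer with a 128x128 weight,
# matching the shapes in the traceback (d_model = 128 is assumed).
query_projection = nn.Linear(128, 128)

# Queries whose last dimension is 37 instead of 128 trigger the same failure.
bad_queries = torch.randn(2, 37)

message = ""
try:
    query_projection(bad_queries)
except RuntimeError as err:
    # Recent PyTorch versions word this as
    # "mat1 and mat2 shapes cannot be multiplied (2x37 and 128x128)".
    message = str(err)
print(message)

# When the last dimension equals in_features, the projection works.
good = query_projection(torch.randn(2, 128))
print(good.shape)  # torch.Size([2, 128])
```

To find where the width diverges from 128, print `queries.shape` just before `self.query_projection(queries)` in attn.py, and print the shapes of `temporal_short` and `temporal_long` in `MAE_SAC.predict`.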