gsyyysg / StockFormer

PyTorch implementation of the paper "StockFormer: Learning Hybrid Trading Machines with Predictive Coding".

Dimension mismatch between the loaded model checkpoint and the initialized actor/critic during backtesting #13

Open wuxiawei opened 2 months ago

wuxiawei commented 2 months ago

When running

    results = DRLAgent.DRL_prediction_load_from_file(model_name='maesac', environment=test_trade_gym, cwd=model_path)

I get the following error:

    RuntimeError: Error(s) in loading state_dict for SACPolicy:
        size mismatch for actor.mu.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
        size mismatch for actor.mu.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
        size mismatch for actor.log_std.weight: copying a param with shape torch.Size([88, 256]) from checkpoint, the shape in current model is torch.Size([31064, 256]).
        size mismatch for actor.log_std.bias: copying a param with shape torch.Size([88]) from checkpoint, the shape in current model is torch.Size([31064]).
        size mismatch for critic.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
        size mismatch for critic.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
        size mismatch for critic_target.qf0.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).
        size mismatch for critic_target.qf1.0.weight: copying a param with shape torch.Size([256, 11440]) from checkpoint, the shape in current model is torch.Size([256, 42416]).

What could be causing this?
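The shapes in the error can be decoded to see which dimension actually differs. This is a minimal sketch that assumes the policy follows the Stable-Baselines3 SAC layout that the parameter names suggest: actor.mu maps the last hidden layer to action_dim outputs, and the first critic layer takes the observation features concatenated with the action, so its in_features equals obs_dim + action_dim. Only the numbers are taken from the traceback; the helper name infer_dims is hypothetical.

```python
# (out_features, in_features) copied from the traceback above.
CHECKPOINT = {
    "actor.mu.weight": (88, 256),
    "critic.qf0.0.weight": (256, 11440),
}
CURRENT = {
    "actor.mu.weight": (31064, 256),
    "critic.qf0.0.weight": (256, 42416),
}


def infer_dims(shapes):
    """Recover (obs_dim, action_dim) from SAC layer shapes.

    Assumption: SB3-style SAC, where the critic's first layer sees
    observation features concatenated with the action vector.
    """
    action_dim = shapes["actor.mu.weight"][0]      # rows of the mu head
    critic_in = shapes["critic.qf0.0.weight"][1]   # obs_dim + action_dim
    return critic_in - action_dim, action_dim


ckpt_obs, ckpt_act = infer_dims(CHECKPOINT)
cur_obs, cur_act = infer_dims(CURRENT)
print(f"checkpoint: obs_dim={ckpt_obs}, action_dim={ckpt_act}")
print(f"current:    obs_dim={cur_obs}, action_dim={cur_act}")

# The observation size agrees; only the action size differs.
if cur_act % ckpt_act == 0:
    print(f"current action_dim = {ckpt_act} x {cur_act // ckpt_act}")
```

Under that assumption both models see an 11352-dimensional observation, and the mismatch is confined to the action space: 88 actions in the checkpoint versus 31064 = 88 × 353 in the freshly built test model. That points at how the test environment's action space is sized rather than at the checkpoint file itself.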

xbkaishui commented 2 months ago

Has this issue been resolved?

wuxiawei commented 2 months ago

Not yet. Have you solved it?

yo-yoo commented 1 month ago

Has this issue been resolved?

I ran into the same problem. Has anyone solved it? Could it be caused by running on a single GPU?
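The number of GPUs does not change parameter shapes (wrapping a model in torch.nn.DataParallel only prefixes state_dict keys with "module."), so a single-GPU run is an unlikely cause of a size mismatch. One hypothesis consistent with the action size in the error, 31064 = 88 × 353, is that the test environment's stock dimension is being computed from the number of rows in a long-format table (one row per day per ticker) instead of the number of distinct tickers. This has not been checked against the StockFormer code; the table below is made up and the variable names are hypothetical, chosen only to show how the two counts diverge.

```python
# Hypothetical long-format test table: one row per (day index, ticker).
NUM_DAYS, NUM_TICKERS = 353, 88
rows = [
    (day, f"TIC{t:02d}")
    for day in range(NUM_DAYS)
    for t in range(NUM_TICKERS)
]

# Suspected bug: counting rows gives days x tickers.
stock_dim_wrong = len(rows)

# Intended: count distinct tickers, which is what the action space should use.
stock_dim_right = len({tic for _, tic in rows})

print(stock_dim_wrong, stock_dim_right)
```

If the real test setup behaves like this, printing the stock dimension passed to the test environment before calling DRL_prediction_load_from_file should show 31064 instead of 88.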