TradeMaster-NTU / TradeMaster

TradeMaster is an open-source platform for quantitative trading empowered by reinforcement learning :fire: :zap: :rainbow:
Apache License 2.0
1.35k stars 273 forks source link

Shape error of the actor forward function in DeepTrader task #158

Closed: Gikiman closed this issue 1 year ago

Gikiman commented 1 year ago
import os
import sys
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")
# ROOT is the parent of the current working directory, so the default config
# path below only resolves when this is run from a direct subdirectory of the
# repository root (e.g. `tutorial/`). Presumably "." is used instead of
# `__file__` so the same code also works inside a notebook.
ROOT = os.path.dirname(os.path.abspath("."))
# Make the repository's `trademaster` package importable without installing it.
sys.path.append(ROOT)
import torch
import argparse
import os.path as osp
from mmcv import Config
from trademaster.utils import replace_cfg_vals
from trademaster.nets.builder import build_net
from trademaster.environments.builder import build_environment
from trademaster.datasets.builder import build_dataset
from trademaster.agents.builder import build_agent
from trademaster.optimizers.builder import build_optimizer
from trademaster.losses.builder import build_loss
from trademaster.trainers.builder import build_trainer
from trademaster.transition.builder import build_transition
from trademaster.utils import plot
from trademaster.utils import set_seed

# Fixed seed so that repeated runs are comparable.
set_seed(2023)

# --- Command line / config --------------------------------------------------
parser = argparse.ArgumentParser(
    description="Train and validate a DeepTrader portfolio-management agent on dj30")
parser.add_argument(
    "--config",
    default=osp.join(ROOT, "configs", "portfolio_management",
                     "portfolio_management_dj30_deeptrader_deeptrader_adam_mse.py"),
    help="path to the DeepTrader config file",
)
# NOTE(review): --task_name is parsed but not used below; this script always
# runs train_and_valid() regardless of its value.
parser.add_argument("--task_name", type=str, default="train")

# parse_known_args ignores unrecognized arguments (such as extra arguments a
# notebook kernel is launched with) instead of exiting with an error.
args, _ = parser.parse_known_args()
cfg = Config.fromfile(args.config)
task_name = args.task_name
cfg = replace_cfg_vals(cfg)

# --- Data and environments --------------------------------------------------
dataset = build_dataset(cfg)

train_environment = build_environment(cfg, default_args=dict(dataset=dataset, task="train"))
valid_environment = build_environment(cfg, default_args=dict(dataset=dataset, task="valid"))
test_environment = build_environment(cfg, default_args=dict(dataset=dataset, task="test"))

action_dim = train_environment.action_dim  # 29 with the dj30 config
state_dim = train_environment.state_dim  # 16 with the dj30 config

# --- Networks, optimizers, loss, agent --------------------------------------
act = build_net(cfg.act_net)
cri = build_net(cfg.cri_net)
market = build_net(cfg.market_net)
act_optimizer = build_optimizer(cfg, default_args=dict(params=act.parameters()))
cri_optimizer = build_optimizer(cfg, default_args=dict(params=cri.parameters()))
market_optimizer = build_optimizer(cfg, default_args=dict(params=market.parameters()))
criterion = build_loss(cfg)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = build_agent(
    cfg,
    default_args=dict(
        action_dim=action_dim,
        state_dim=state_dim,
        act=act,
        cri=cri,
        market=market,
        act_optimizer=act_optimizer,
        cri_optimizer=cri_optimizer,
        criterion=criterion,
        market_optimizer=market_optimizer,
        device=device,
    ),
)

# --- Trainer ------------------------------------------------------------------
trainer = build_trainer(
    cfg,
    default_args=dict(
        train_environment=train_environment,
        valid_environment=valid_environment,
        test_environment=test_environment,
        agent=agent,
        device=device,
    ),
)

# Save a copy of the config into the trainer's work directory.
work_dir = os.path.join(ROOT, cfg.trainer.work_dir)
# exist_ok=True avoids the check-then-create race of a separate exists() test.
os.makedirs(work_dir, exist_ok=True)
cfg.dump(osp.join(work_dir, osp.basename(args.config)))

# NOTE: with this config, train_and_valid() has been reported (TradeMaster
# issue #158) to fail inside the DeepTrader actor with "Expected 2D (unbatched)
# or 3D (batched) input to conv1d" for an input of size [1, 29, 16, 10]; the
# maintainers advised not to use DeepTrader until it is fixed.
trainer.train_and_valid()

And I got the following error:

Traceback (most recent call last):
  File "C:\Users\HuZetian\Desktop\TradeMaster\tutorial\Tutorial_DeepTrader.py", line 72, in <module>
    trainer.train_and_valid()
  File "C:\Users\HuZetian\Desktop\TradeMaster\trademaster\trainers\portfolio_management\deeptrader_trainer.py", line 173, in train_and_valid
    buffer_items = self.agent.explore_env(self.train_environment, self.horizon_len)
  File "C:\Users\HuZetian\Desktop\TradeMaster\trademaster\agents\portfolio_management\deeptrader.py", line 216, in explore_env    
    action = get_action(state,market_state,corr_matrix)
  File "C:\Users\HuZetian\Desktop\TradeMaster\trademaster\agents\portfolio_management\deeptrader.py", line 173, in get_action 
    asset_scores = self.act(state, corr_matrix)
  File "C:\ProgramData\Anaconda3\envs\TradeMaster\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl     
    return forward_call(*input, **kwargs)
  File "C:\Users\HuZetian\Desktop\TradeMaster\trademaster\nets\deeptrader.py", line 221, in forward
    H_L = self.TCN(x)
  File "C:\ProgramData\Anaconda3\envs\TradeMaster\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl     
    return forward_call(*input, **kwargs)
  File "C:\Users\HuZetian\Desktop\TradeMaster\trademaster\nets\deeptrader.py", line 135, in forward
    return self.network(x)
  File "C:\ProgramData\Anaconda3\envs\TradeMaster\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl     
    return forward_call(*input, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\TradeMaster\lib\site-packages\torch\nn\modules\container.py", line 139, in forward      
    input = module(input)
  File "C:\ProgramData\Anaconda3\envs\TradeMaster\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl     
    return forward_call(*input, **kwargs)
  File "C:\Users\HuZetian\Desktop\TradeMaster\trademaster\nets\deeptrader.py", line 89, in forward
    out = self.net(x)
  File "C:\ProgramData\Anaconda3\envs\TradeMaster\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl     
    return forward_call(*input, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\TradeMaster\lib\site-packages\torch\nn\modules\container.py", line 139, in forward      
    input = module(input)
  File "C:\ProgramData\Anaconda3\envs\TradeMaster\lib\site-packages\torch\nn\modules\module.py", line 1148, in _call_impl     
    result = forward_call(*input, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\TradeMaster\lib\site-packages\torch\nn\modules\conv.py", line 307, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "C:\ProgramData\Anaconda3\envs\TradeMaster\lib\site-packages\torch\nn\modules\conv.py", line 303, in _conv_forward     
    return F.conv1d(input, weight, bias, self.stride,
RuntimeError: Expected 2D (unbatched) or 3D (batched) input to conv1d, but got input of size: [1, 29, 16, 10]

Can you help me with it?

DVampire commented 1 year ago

There are some problems in DeepTrader that we haven't finished dealing with yet. Please don't use it for now; we will fix it as soon as possible. Thanks.

Gikiman commented 1 year ago

Sure, thanks for your effort.