AI4Finance-Foundation / FinRL-Tutorials

Tutorials. Please star.
https://ai4finance.org
MIT License

Error in FinRL_PortfolioAllocation_Explainable_DRL #39

Open Mahdiehdn opened 1 year ago

Mahdiehdn commented 1 year ago

Hello, I added a deep network (a custom CNN feature extractor) to the code, and now it throws an error. Can you help me? Thanks. Here is part of the code:

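```python
import gym
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class Net(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Box, stock_dim: int = 28):
        # stock_dim is passed to the base class as features_dim (size of the extracted features)
        super(Net, self).__init__(observation_space, stock_dim)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc3 = nn.Linear(672, stock_dim)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = nn.Flatten()(x)
        # x = x.view(x.size(0), -1)
        x = F.relu(self.fc3(x))
        return x


# use Net as the feature extractor in front of PPO's MLP policy/value heads
policy_kwargs = {
    'features_extractor_class': Net,
}
# env_train is created earlier in the notebook (not shown)
model = PPO("MlpPolicy", env=env_train, policy_kwargs=policy_kwargs)
model.learn(total_timesteps=100000)
```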

Error:

```
ValueError                                Traceback (most recent call last)
----> 1 model.learn(total_timesteps=100000)

3 frames
/usr/local/lib/python3.8/dist-packages/stable_baselines3/common/buffers.py in add(self, obs, action, reward, episode_start, value, log_prob)
    435 
    436         # Same reshape, for actions
--> 437         action = action.reshape((self.n_envs, self.action_dim))
    438 
    439         self.observations[self.pos] = np.array(obs).copy()

ValueError: cannot reshape array of size 1792 into shape (1,28)
```
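The sizes in the traceback can be reproduced with plain shape arithmetic. The sketch below is an illustration only, not FinRL or Stable-Baselines3 code: `conv_out` and `trace` are hypothetical helpers, and it assumes a single training environment and a 32 x 28 observation per step (for example 28 covariance rows plus 4 technical-indicator rows), which is consistent with the 672 hard-coded in `fc3`. The key point is that the observation batch reaches `forward` as a 3-D tensor `(1, 32, 28)` with no channel axis, and `nn.Conv2d` interprets a 3-D tensor as a single unbatched `(C, H, W)` image.

```python
# Shape trace for the Net feature extractor (pure Python, no torch needed).
# Assumptions (not from the issue): each observation is a 32 x 28 matrix and
# there is a single training environment (n_envs = 1).
from typing import Tuple


def conv_out(h: int, w: int, k: int = 3) -> Tuple[int, int]:
    """Height/width after an nn.Conv2d with kernel k, stride 1, no padding."""
    return h - k + 1, w - k + 1


def trace(add_channel_axis: bool, n_envs: int = 1, h: int = 32, w: int = 28,
          stock_dim: int = 28) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Return (shape after nn.Flatten(), shape after fc3) for one rollout step."""
    h2, w2 = conv_out(*conv_out(h, w))  # conv1 then conv2: (32, 28) -> (28, 24)
    if add_channel_axis:
        # Input reshaped to (n_envs, 1, h, w): Conv2d sees a real batch, and
        # Flatten(start_dim=1) merges the 64 channels with height and width.
        flat = (n_envs, 64 * h2 * w2)
    else:
        # Input stays (n_envs, h, w) = (1, 32, 28): Conv2d reads it as ONE
        # unbatched (C, H, W) image, so conv2 returns (64, h2, w2) and
        # Flatten(start_dim=1) turns the 64 channels into 64 rows.
        flat = (64, h2 * w2)
    out = (flat[0], stock_dim)  # fc3 only changes the last dimension
    return flat, out


flat, out = trace(add_channel_axis=False)
print(flat, out, out[0] * out[1])  # (64, 672) (64, 28) 1792

flat, out = trace(add_channel_axis=True)
print(flat, out)  # (1, 43008) (1, 28)
```

With the channel axis missing, the policy produces 64 x 28 = 1792 action values instead of 1 x 28, which is exactly the array the buffer's `add` method fails to reshape. Under these assumptions, one possible fix is to insert `x = x.unsqueeze(1)` at the start of `forward` and to size the linear layer for the full flattened vector, `nn.Linear(64 * 28 * 24, stock_dim)` (43008 inputs), instead of 672. To avoid hard-coding that number, the input size can be computed once in `__init__` with a dummy forward pass over `observation_space.sample()`, the approach taken by the custom feature extractor example in the Stable-Baselines3 documentation.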