Khrylx / PyTorch-RL

PyTorch implementation of Deep Reinforcement Learning: Policy Gradient methods (TRPO, PPO, A2C) and Generative Adversarial Imitation Learning (GAIL). Fast Fisher vector product TRPO.
MIT License

Inconsistent action shape when running CartPole-v1 #13

Closed truongthanh96 closed 5 years ago

truongthanh96 commented 5 years ago

Gym env: CartPole-v1
Affected code: gail/gail_gym.py


    """update discriminator"""
    for _ in range(1):
        expert_state_actions = torch.from_numpy(expert_traj).to(dtype).to(device)

        g_o = discrim_net(torch.cat([states, actions], 1))
        e_o = discrim_net(expert_state_actions)
        optimizer_discrim.zero_grad()
        discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \
            discrim_criterion(e_o, zeros((expert_traj.shape[0], 1), device=device))
        discrim_loss.backward()
        optimizer_discrim.step()

Error

 g_o = discrim_net(torch.cat([states, actions], 1))
 RuntimeError: invalid argument 0: Tensors must have same number of dimensions: got 2 and 1 at /opt/conda/conda-bld/pytorch-cpu_1544218188686/work/aten/src/TH/generic/THTensorMoreMath.cpp:1324

To reproduce this error:

    python ./gail/save_expert_traj.py --model-path assets/learned_models/CartPole-v1_trpo.p --env-name CartPole-v1 --save-model-interval 100
    python ./gail/gail_gym.py --env-name CartPole-v1 --expert-traj-path assets/expert_traj/CartPole-v1_expert_traj.p

This happens because CartPole-v1's action space is discrete, so states is a 2-D tensor while actions is 1-D:

    state = [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4]]
    action = [1, 0]

torch.cat then throws this error because the two tensors have different numbers of dimensions.
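A minimal runnable sketch of the mismatch and the unsqueeze fix (the tensor values are illustrative, not taken from the repo):

    # Reproduce the shape mismatch with toy data.
    import torch

    states = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                           [0.1, 0.2, 0.3, 0.4]])   # shape [2, 4]
    actions = torch.tensor([1, 0])                   # shape [2], 1-D for a discrete env

    # torch.cat([states, actions], 1) -> RuntimeError: tensors must have the
    # same number of dimensions.

    # Adding a trailing dimension (and matching dtype) makes the shapes compatible:
    state_actions = torch.cat([states, actions.unsqueeze(1).to(states.dtype)], 1)
    print(state_actions.shape)  # torch.Size([2, 5])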
Fix suggestion

  """update discriminator"""
    for _ in range(1):
        expert_state_actions = torch.from_numpy(expert_traj).to(dtype).to(device)

        if len(actions.shape) == 1:
            g_o = discrim_net(torch.cat([states, actions.unsqueeze(1)], 1))
        else:
            g_o = discrim_net(torch.cat([states, actions], 1))
        e_o = discrim_net(expert_state_actions)
        optimizer_discrim.zero_grad()
        discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \
            discrim_criterion(e_o, zeros((expert_traj.shape[0], 1), device=device))
        discrim_loss.backward()
        optimizer_discrim.step()

save_expert_traj.py doesn't hit this error because np.hstack flattens its inputs before joining them, so the state vector and the scalar action combine cleanly: expert_traj.append(np.hstack([state, action]))
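For comparison, a small sketch of the np.hstack behavior (toy values, not repo data):

    # np.hstack promotes the scalar action to a 1-element array and joins it
    # with the state, so each expert row already has shape (5,).
    import numpy as np

    state = np.array([0.1, 0.2, 0.3, 0.4])
    action = 1  # discrete action stored as a scalar

    row = np.hstack([state, action])
    print(row.shape)  # (5,)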

Khrylx commented 5 years ago

Hi,

Thanks for pointing this out. I wasn't intending to implement GAIL for discrete environments. Extending GAN/GAIL to discrete action spaces is not so straightforward, although something like Gumbel-Softmax could work.
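A rough sketch of that idea (the shapes and names below are assumptions, not code from this repo):

    import torch
    import torch.nn.functional as F

    batch_size, state_dim, action_dim = 8, 4, 2    # CartPole-v1 sizes
    logits = torch.randn(batch_size, action_dim)   # policy logits (hypothetical)

    # Differentiable "soft" one-hot samples; gradients can flow back to the policy.
    soft_actions = F.gumbel_softmax(logits, tau=1.0, hard=True)  # shape [8, 2]

    states = torch.randn(batch_size, state_dim)
    # The discriminator now sees a 2-D action input, matching the continuous case.
    discrim_input = torch.cat([states, soft_actions], dim=1)     # shape [8, 6]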

Ye