Shmuma / ptan

PyTorch Agent Net: reinforcement learning toolkit for PyTorch
MIT License
531 stars, 165 forks

a2c.py and a2c_atari.py throw error with sample run files #6

Open · ghost opened this issue 6 years ago

ghost commented 6 years ago

For example, when I run `a2c.py -r "runs/a2c/a2c_cartpole.ini"`, tons of errors pop up.

Regardless, I like that you've implemented a lot of algorithms and collected them here. It's very useful for someone new to RL like me; I'm mainly reading through the code to figure out what is going on. It's just a shame the samples don't seem to work as intended. :(

swenner commented 6 years ago

Here is a fix for the first issue. It passes a `device` to `PolicyAgent` instead of the old `cuda` flag:

```diff
--- a/samples/a2c.py
+++ b/samples/a2c.py
@@ -93,10 +96,12 @@ if __name__ == "__main__":
 
     # model returns probability of actions
     model = Model(env.action_space.n, env.observation_space.shape[0])
+    device = "cpu"
     if cuda_enabled:
         model.cuda()
+        device = "cuda"
 
-    agent = ptan.agent.PolicyAgent(a3c_actor_wrapper(model), cuda=cuda_enabled)
+    agent = ptan.agent.PolicyAgent(a3c_actor_wrapper(model), device=device)
     exp_source = ptan.experience.ExperienceSource(env=env, agent=agent, steps_count=run.getint("learning", "n_steps"))
 
     optimizer = optim.RMSprop(model.parameters(), lr=run.getfloat("learning", "lr"))
```
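A slightly more compact way to choose the device is sketched below, assuming a PyTorch version that has `torch.device` and `Module.to()` (0.4 or later). `select_device` is a hypothetical helper, `cuda_enabled` mirrors the flag in a2c.py, and the `nn.Linear` model is only a stand-in for the sample's `Model` (CartPole has 4 observation values and 2 actions).

```python
import torch
import torch.nn as nn


def select_device(cuda_enabled: bool) -> torch.device:
    # Hypothetical helper: fall back to the CPU when CUDA was requested
    # but no GPU is actually available.
    use_cuda = cuda_enabled and torch.cuda.is_available()
    return torch.device("cuda" if use_cuda else "cpu")


# Stand-in for the sample's Model: 4 observation values in, 2 action logits out.
model = nn.Linear(4, 2)
device = select_device(cuda_enabled=False)
# .to() moves the parameters to the device and returns the module itself.
model = model.to(device)
print(device)  # cpu
```

The resulting `device` could then be handed to `PolicyAgent(..., device=...)` as in the diff; if that agent expects a plain string in your ptan version, `str(device)` gives `"cpu"` or `"cuda"`.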

But then I hit the next error:

```
python3 a2c.py -r runs/a2c/a2c_cartpole.ini
Traceback (most recent call last):
  File "a2c.py", line 172, in <module>
    for exp in exp_source:
  File "/home/simon/ptan-git/ptan/experience.py", line 82, in __iter__
    states_actions, new_agent_states = self.agent(states_input, agent_states)
  File "/home/simon/ptan-git/ptan/agent.py", line 129, in __call__
    probs_v = self.model(states)
  File "a2c.py", line 64, in _wrap
    x = model(x)
  File "/home/simon/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "a2c.py", line 53, in forward
    x = self.fc1(x)
  File "/home/simon/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/simon/.local/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 55, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/simon/.local/lib/python3.5/site-packages/torch/nn/functional.py", line 1024, in linear
    return torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of type torch.FloatTensor but found type torch.DoubleTensor for argument #4 'mat1'
```
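The error says the states reached `fc1` as a `DoubleTensor` (float64) while the layer's weights are `FloatTensor` (float32, PyTorch's default). A likely cause is that the environment's observations are float64 NumPy arrays and are converted to a tensor without changing the dtype. A minimal sketch of the usual fix, casting the batch to float32 before it reaches the model, is below; `float32_states` is a hypothetical helper, not part of ptan, and the observations are made-up CartPole-shaped values.

```python
import numpy as np
import torch
import torch.nn as nn


def float32_states(states) -> torch.Tensor:
    # Hypothetical helper: stack a batch of observations and force float32,
    # the dtype that nn.Linear weights use by default.
    return torch.from_numpy(np.asarray(states, dtype=np.float32))


# Two made-up CartPole-like observations; NumPy stores them as float64.
states = [np.array([0.01, -0.02, 0.03, 0.04]),
          np.array([0.05, 0.06, -0.07, 0.08])]
fc1 = nn.Linear(4, 2)

# A plain conversion keeps float64, which is what produces the DoubleTensor.
print(torch.tensor(np.array(states)).dtype)  # torch.float64

x = float32_states(states)
print(x.dtype, fc1(x).shape)  # torch.float32 torch.Size([2, 2])
```

In a2c.py the cast would belong wherever the agent turns states into a tensor. Later ptan releases appear to expose a `preprocessor` argument on `PolicyAgent` (with `ptan.agent.float32_preprocessor`) for exactly this, but treat that as an assumption and check the version you have installed.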