kevslinger / DTQN

Deep Transformer Q-Networks for Partially Observable Reinforcement Learning

Matrix size mismatch error #12

Open ashok-arora opened 1 month ago

ashok-arora commented 1 month ago

Hey Kevin, I am facing the following error when running DQN on the heavenhell environment.

Error log:

$ python run.py --env POMDP-heavenhell_3-episodic-v0 --inembed 64 --model DQN --verbose --seed 1 --disable-wandb
Loading using gym.make
Loading using gym.make
[ July 12, 22:58:08 ] Creating DQN with 5132 parameters
Traceback (most recent call last):
  File "run.py", line 333, in <module>
    run_experiment(parser.parse_args())
  File "run.py", line 115, in run_experiment
    agent.train()
  File "/home/cse/Desktop/ashok/DTQN-paper/dtqn/agents/dqn.py", line 189, in train
    q_values = self.policy_network(obss).gather(1, actions).squeeze()
  File "/home/cse/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/cse/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cse/Desktop/ashok/DTQN-paper/dtqn/networks/dqn.py", line 47, in forward
    return self.ffn(self.obs_embed(x))
  File "/home/cse/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/cse/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cse/Desktop/ashok/DTQN-paper/dtqn/networks/representations.py", line 12, in forward
    return self.embedding(obs)
  File "/home/cse/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/cse/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cse/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/cse/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/cse/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/cse/.local/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x256 and 8x64)
kevslinger commented 1 month ago

Hi Ashok, thanks for pointing this out. It looks like obss got flattened incorrectly -- I believe obss should be 32x8 (batch-size x obs-embed), but instead it got flattened to 1x256. I believe it is a bug in https://github.com/kevslinger/DTQN/blob/bc46144bf96eaffba5ab6676ef1c3dde6cdf2d93/dtqn/networks/representations.py#L15, specifically the nn.Flatten(start_dim=-2) for DQN. If my memory serves, changing that to start_dim=-1 specifically for DQN will fix the issue, but I just left for holiday so I won't be able to test that myself for a bit.
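
For illustration, a minimal toy sketch (illustrative shapes only, not the repo's actual code) of what that flatten does to a single-observation batch:

import torch
import torch.nn as nn

# Toy shapes for illustration; the real sizes come from the environment and config.
obss = torch.randn(32, 8)                    # (batch_size, obs_dim)

print(nn.Flatten(start_dim=-2)(obss).shape)  # torch.Size([256])  -- batch folded into features
print(nn.Flatten(start_dim=-1)(obss).shape)  # torch.Size([32, 8]) -- batch preserved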

ashok-arora commented 1 month ago

I appreciate your reply and apologize for disturbing you during your holiday. I'll try using start_dim=-1 for DQN only and let you know how it goes.

kevslinger commented 1 month ago

No worries, sounds good 👍

ashok-arora commented 1 month ago

Hey, changing it to -1 gave a different traceback:

(dtqn) user@user-HD-SW200:~/Desktop/ashok/DTQN-paper$ python run.py --env gv_memory.5x5.yaml --inembed 128 --model DQN --verbose --seed 1  --device cuda:1 && python run.py --env gv_memory.5x5.yaml --inembed 128 --model DQN --verbose --seed 2  --device cuda:1 && python run.py --env gv_memory.5x5.yaml --inembed 128 --model DQN --verbose --seed 3  --device cuda:1 && python run.py --env gv_memory.5x5.yaml  --inembed 128 --model DRQN --verbose --seed 4  --device cuda:1 && python run.py --env gv_memory.5x5.yaml --inembed 128 --model DQN --verbose --seed 5  --device cuda:1
Loading using gym.make
Environment with id gv_memory.5x5.yaml not found.
Loading using YAML
Loading using gym.make
Environment with id gv_memory.5x5.yaml not found.
Loading using YAML
[ July 17, 19:55:00 ] Creating DQN with 23742 parameters
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 2
wandb: You chose 'Use an existing W&B account'
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
wandb: Appending key for api.wandb.ai to your netrc file: /home/user/.netrc
wandb: Tracking run with wandb version 0.17.4
wandb: Run data is saved locally in /home/user/Desktop/ashok/DTQN-paper/wandb/run-20240717_195531-1jz7xoux
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run super-vortex-117
wandb: ⭐️ View project at https://wandb.ai/arora8715/DTQN-Neurips2022
wandb: 🚀 View run at https://wandb.ai/arora8715/DTQN-Neurips2022/runs/1jz7xoux
Traceback (most recent call last):
  File "run.py", line 333, in <module>
    run_experiment(parser.parse_args())
  File "run.py", line 115, in run_experiment
    agent.train()
  File "/home/user/Desktop/ashok/DTQN-paper/dtqn/agents/dqn.py", line 189, in train
    q_values = self.policy_network(obss).gather(1, actions).squeeze()
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/Desktop/ashok/DTQN-paper/dtqn/networks/dqn.py", line 47, in forward
    return self.ffn(self.obs_embed(x))
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/Desktop/ashok/DTQN-paper/dtqn/networks/representations.py", line 12, in forward
    return self.embedding(obs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (192x8 and 48x128)
kevslinger commented 1 month ago

Did it work correctly for heavenhell? If so, then changing it back to nn.Flatten(start_dim=-2) should work for the gridverse environments. The difference is that heavenhell returns observations as an integer, whereas gridverse returns observations as a 2-D array, so we need to account for that extra dimension when we flatten.

As an aside, this complication (only DQN needs to be flattened at -1, and only on environments that return flat (integer) observations) is why I never tried to include that in the code.
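
For reference, one possible way to special-case it (a hypothetical helper, not code from the repo; the is_dqn and obs_is_flat flags are assumed to be available wherever the representation is built):

import torch.nn as nn

def make_obs_flatten(is_dqn: bool, obs_is_flat: bool) -> nn.Module:
    # DQN on a flat (integer) observation only has (batch, obs_dim), so the
    # flatten must start at the last dim (a no-op) to keep the batch intact.
    # Sequence models, and 2-D gridverse observations, have an extra dimension
    # that start_dim=-2 correctly merges into the feature vector.
    start_dim = -1 if (is_dqn and obs_is_flat) else -2
    return nn.Flatten(start_dim=start_dim)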

ashok-arora commented 1 month ago

Thank you for the response. Yes, changing it back to nn.Flatten(start_dim=-2) works for gridverse environments.

> As an aside, this complication (only DQN needs to be flattened at -1, and only on environments that return flat (integer) observations) is why I never tried to include that in the code.

Since the DQN also uses embeddings, why is it the only model with this issue?

kevslinger commented 1 month ago

It's because DQN is the only model that uses a single observation as input, rather than a sequence of observations.
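
As a toy comparison (illustrative shapes only, not the repo's actual dimensions), the same Flatten(start_dim=-2) keeps the batch dimension when there is a sequence dimension, but folds it into the features when there isn't one:

import torch
import torch.nn as nn

flatten = nn.Flatten(start_dim=-2)

seq_obss = torch.randn(32, 50, 8)  # sequence models: (batch, context_len, obs_dim)
dqn_obss = torch.randn(32, 8)      # DQN: (batch, obs_dim)

print(flatten(seq_obss).shape)     # torch.Size([32, 400]) -- batch dim survives
print(flatten(dqn_obss).shape)     # torch.Size([256])     -- batch dim folded in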