ashok-arora opened 1 month ago
Hi Ashok, thanks for pointing this out. It looks like `obss` got flattened incorrectly -- I believe `obss` should be 32x8 (batch-size x obs-embed), but instead it got flattened to 1x256. I believe it is a bug in https://github.com/kevslinger/DTQN/blob/bc46144bf96eaffba5ab6676ef1c3dde6cdf2d93/dtqn/networks/representations.py#L15, specifically for DQN (the `nn.Flatten(start_dim=-2)`). If my memory serves, changing that to `start_dim=-1` specifically for DQN will fix the issue, but I just left for holiday, so I won't be able to test that myself for a bit.
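A minimal sketch of the flatten behavior being described here, using assumed shapes (batch=32, obs-embed=8; these dims are illustrative, not the repo's actual config):

```python
import torch
import torch.nn as nn

obs = torch.zeros(32, 8)  # (batch-size, obs-embed), the DQN case

# start_dim=-2 folds the batch dimension into the features,
# collapsing (32, 8) into a single 256-long vector:
print(nn.Flatten(start_dim=-2)(obs).shape)  # torch.Size([256])

# start_dim=-1 only flattens the last dimension, leaving the
# batch dimension alone -- a no-op for an already-2d tensor:
print(nn.Flatten(start_dim=-1)(obs).shape)  # torch.Size([32, 8])
```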
I appreciate your reply and apologize for disturbing you during your holiday. I'll try using `start_dim=-1` for DQN only and let you know how it goes.
No worries, sounds good 👍
Hey, changing it to `-1` gave a different traceback:
```
(dtqn) user@user-HD-SW200:~/Desktop/ashok/DTQN-paper$ python run.py --env gv_memory.5x5.yaml --inembed 128 --model DQN --verbose --seed 1 --device cuda:1 && python run.py --env gv_memory.5x5.yaml --inembed 128 --model DQN --verbose --seed 2 --device cuda:1 && python run.py --env gv_memory.5x5.yaml --inembed 128 --model DQN --verbose --seed 3 --device cuda:1 && python run.py --env gv_memory.5x5.yaml --inembed 128 --model DRQN --verbose --seed 4 --device cuda:1 && python run.py --env gv_memory.5x5.yaml --inembed 128 --model DQN --verbose --seed 5 --device cuda:1
Loading using gym.make
Environment with id gv_memory.5x5.yaml not found.
Loading using YAML
Loading using gym.make
Environment with id gv_memory.5x5.yaml not found.
Loading using YAML
[ July 17, 19:55:00 ] Creating DQN with 23742 parameters
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 2
wandb: You chose 'Use an existing W&B account'
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:
wandb: Appending key for api.wandb.ai to your netrc file: /home/user/.netrc
wandb: Tracking run with wandb version 0.17.4
wandb: Run data is saved locally in /home/user/Desktop/ashok/DTQN-paper/wandb/run-20240717_195531-1jz7xoux
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run super-vortex-117
wandb: ⭐️ View project at https://wandb.ai/arora8715/DTQN-Neurips2022
wandb: 🚀 View run at https://wandb.ai/arora8715/DTQN-Neurips2022/runs/1jz7xoux
Traceback (most recent call last):
  File "run.py", line 333, in <module>
    run_experiment(parser.parse_args())
  File "run.py", line 115, in run_experiment
    agent.train()
  File "/home/user/Desktop/ashok/DTQN-paper/dtqn/agents/dqn.py", line 189, in train
    q_values = self.policy_network(obss).gather(1, actions).squeeze()
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/Desktop/ashok/DTQN-paper/dtqn/networks/dqn.py", line 47, in forward
    return self.ffn(self.obs_embed(x))
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/Desktop/ashok/DTQN-paper/dtqn/networks/representations.py", line 12, in forward
    return self.embedding(obs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/dtqn/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (192x8 and 48x128)
```
Did it work correctly for heavenhell? If so, then changing it back to `nn.Flatten(start_dim=-2)` should work for the gridverse environments. The difference is that heavenhell returns observations as an integer, whereas gridverse returns observations as a 2d array, so we need to account for that extra dimension when we flatten.

As an aside, this complication (only DQN needs to be flattened at `-1`, and only when it's on environments that return flat (integer) observations) is why I never tried to include that in the code.
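The shape difference described above can be sketched as follows. The dims are assumptions chosen to match the `192x8 and 48x128` mismatch in the traceback (batch=32, 6 integer features per gridverse observation, embedding dim 8), not the repo's actual config:

```python
import torch
import torch.nn as nn

embed = nn.Embedding(num_embeddings=10, embedding_dim=8)

# heavenhell-style: each observation is a single integer -> (batch,)
flat_obs = torch.zeros(32, dtype=torch.long)
e1 = embed(flat_obs)  # (32, 8) -- already (batch, embed), nothing extra to flatten
print(nn.Flatten(start_dim=-1)(e1).shape)  # torch.Size([32, 8])

# gridverse-style: each observation is an array of integer features -> (batch, 6)
grid_obs = torch.zeros(32, 6, dtype=torch.long)
e2 = embed(grid_obs)  # (32, 6, 8) -- one embedding per feature, an extra dim
print(nn.Flatten(start_dim=-2)(e2).shape)  # torch.Size([32, 48])
```

With `start_dim=-1`, the gridverse tensor would stay `(32, 6, 8)`, so a `Linear(48, 128)` layer sees it as 192 rows of 8 features -- exactly the `192x8 and 48x128` error above.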
Thank you for the response. Yes, changing it back to `nn.Flatten(start_dim=-2)` works for gridverse environments.

> As an aside, this complication (only DQN needs to be flattened at `-1`, and only when it's on environments that return flat (integer) observations) is why I never tried to include that in the code.
Since the DQN also uses embeddings, why is it the only model with this issue?
It's because DQN is the only model that takes a single observation as input instead of a sequence of observations.
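A minimal sketch of why this makes `start_dim=-2` safe for the sequence models but not for DQN (dims are illustrative assumptions: batch=32, context=50, embed=8):

```python
import torch
import torch.nn as nn

flatten = nn.Flatten(start_dim=-2)

# DTQN/DRQN pass a sequence of observations: (batch, context, embed).
# The second-to-last dim is the context, so the batch dim is untouched:
seq_batch = torch.zeros(32, 50, 8)
print(flatten(seq_batch).shape)  # torch.Size([32, 400])

# DQN passes one observation per sample: (batch, embed).
# Here the second-to-last dim IS the batch dim, so it gets swallowed:
single_batch = torch.zeros(32, 8)
print(flatten(single_batch).shape)  # torch.Size([256])
```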
Hey Kevin, I am facing the following error when running DQN on the heavenhell environment.
Error log: