siddhanthaldar / BAKU

Code for BAKU: An Efficient Transformer for Multi-Task Policy Learning
https://baku-robot.github.io/

RuntimeError: size mismatch when evaluating BAKU on Meta-World #12

Open 4evertutelary opened 2 weeks ago

4evertutelary commented 2 weeks ago

Dear author, thanks a lot for your great contribution to multi-task policy learning.

Any suggestions for debugging the following issue? When I run the command below, I get a runtime error: RuntimeError: Error(s) in loading state_dict for Actor: size mismatch for _policy.transformer.wpe.weight: copying a param with shape torch.Size([51, 256]) from checkpoint, the shape in current model is torch.Size([65, 256]).

Command:

python eval.py agent=baku suite=metaworld dataloader=metaworld suite/task=metaworld suite.hidden_dim=256 bc_weight=/home/zzz/code/BAKU/baku-weights/weights/metaworld/baku.pt use_proprio=false

Detailed log output:

/home/zzz/miniforge3/envs/baku/bin/python /home/zzz/code/BAKU/baku/eval.py agent=baku suite=metaworld dataloader=metaworld suite.hidden_dim=256 bc_weight=/home/zzz/code/BAKU/baku-weights/weights/metaworld/baku.pt use_proprio=false
workspace: /home/zzz/code/BAKU/baku/exp_local/eval/2024.11.04_eval/deterministic/154147_hidden_dim_256
Loading /home/zzz/code/BAKU/expert_demos/metaworld/basketball.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/bin_picking.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/button_press.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/button_press_topdown.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/button_press_topdown_wall.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/button_press_wall.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/coffee_button.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/coffee_pull.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/coffee_push.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/dial_turn.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/disassemble.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/door_lock.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/door_open.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/door_unlock.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/drawer_close.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/drawer_open.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/faucet_close.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/faucet_open.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/hammer.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/handle_press.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/handle_press_side.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/handle_pull.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/handle_pull_side.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/shelf_place.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/soccer.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/stick_push.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/sweep.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/sweep_into.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/window_close.pkl
Loading /home/zzz/code/BAKU/expert_demos/metaworld/window_open.pkl

[2024-11-04 15:41:49,172][sentence_transformers.SentenceTransformer][INFO] - Use pytorch device_name: cuda
[2024-11-04 15:41:49,172][sentence_transformers.SentenceTransformer][INFO] - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/gymnasium/spaces/box.py:235: UserWarning: WARN: Box low's precision lowered by casting to float32, current low.dtype=float64
  gym.logger.warn(
/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/gymnasium/spaces/box.py:305: UserWarning: WARN: Box high's precision lowered by casting to float32, current high.dtype=float64
  gym.logger.warn(
Batches: 100%|██████████| 1/1 [00:00<00:00, 3.22it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 155.64it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 465.31it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 433.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 445.07it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 453.24it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 476.03it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 456.10it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 473.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 476.19it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 476.52it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 469.48it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 467.59it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 475.11it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 469.95it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 480.67it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 482.49it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 474.90it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 488.39it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 469.95it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 472.49it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 464.64it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 440.53it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 471.16it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 472.70it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 443.23it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 476.25it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 471.54it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 452.90it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 466.86it/s]
/home/zzz/code/BAKU/baku/agent/networks/utils/vqbet/vector_quantize_pytorch/vector_quantize_pytorch.py:501: FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
  @autocast(enabled=False)
/home/zzz/code/BAKU/baku/agent/networks/utils/vqbet/vector_quantize_pytorch/vector_quantize_pytorch.py:703: FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
  @autocast(enabled=False)
/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=None.
  warnings.warn(msg)
/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /opt/conda/conda-bld/pytorch_1729647380992/work/aten/src/ATen/native/TensorShape.cpp:3595.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
number of parameters: 6.53M
loading bc weight: /home/zzz/code/BAKU/baku-weights/weights/metaworld/baku.pt
/home/zzz/code/BAKU/baku/eval.py:179: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  payload = torch.load(f)
Error executing job with overrides: ['agent=baku', 'suite=metaworld', 'dataloader=metaworld', 'suite.hidden_dim=256', 'bc_weight=/home/zzz/code/BAKU/baku-weights/weights/metaworld/baku.pt', 'use_proprio=false']
Traceback (most recent call last):
  File "/home/zzz/code/BAKU/baku/eval.py", line 212, in <module>
    main()
  File "/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/hydra/main.py", line 49, in decorated_main
    _run_hydra(
  File "/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/hydra/_internal/utils.py", line 367, in _run_hydra
    run_and_report(
  File "/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/hydra/_internal/utils.py", line 214, in run_and_report
    raise ex
  File "/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/hydra/_internal/utils.py", line 211, in run_and_report
    return func()
  File "/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/hydra/_internal/utils.py", line 368, in <lambda>
    lambda: hydra.run(
  File "/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/hydra/_internal/hydra.py", line 110, in run
    _ = ret.return_value
  File "/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/hydra/core/utils.py", line 233, in return_value
    raise self._return_value
  File "/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/hydra/core/utils.py", line 160, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/zzz/code/BAKU/baku/eval.py", line 206, in main
    workspace.load_snapshot(snapshots)
  File "/home/zzz/code/BAKU/baku/eval.py", line 188, in load_snapshot
    self.agent.load_snapshot(agent_payload, eval=True)
  File "/home/zzz/code/BAKU/baku/agent/baku.py", line 915, in load_snapshot
    self.__dict__[k].load_state_dict(payload[k])
  File "/home/zzz/miniforge3/envs/baku/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2584, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for Actor:
    size mismatch for _policy.transformer.wpe.weight: copying a param with shape torch.Size([51, 256]) from checkpoint, the shape in current model is torch.Size([65, 256]).
    size mismatch for _policy.transformer.h.0.attn.bias: copying a param with shape torch.Size([1, 1, 51, 51]) from checkpoint, the shape in current model is torch.Size([1, 1, 65, 65]).
    size mismatch for _policy.transformer.h.1.attn.bias: copying a param with shape torch.Size([1, 1, 51, 51]) from checkpoint, the shape in current model is torch.Size([1, 1, 65, 65]).
    size mismatch for _policy.transformer.h.2.attn.bias: copying a param with shape torch.Size([1, 1, 51, 51]) from checkpoint, the shape in current model is torch.Size([1, 1, 65, 65]).
    size mismatch for _policy.transformer.h.3.attn.bias: copying a param with shape torch.Size([1, 1, 51, 51]) from checkpoint, the shape in current model is torch.Size([1, 1, 65, 65]).
    size mismatch for _policy.transformer.h.4.attn.bias: copying a param with shape torch.Size([1, 1, 51, 51]) from checkpoint, the shape in current model is torch.Size([1, 1, 65, 65]).
    size mismatch for _policy.transformer.h.5.attn.bias: copying a param with shape torch.Size([1, 1, 51, 51]) from checkpoint, the shape in current model is torch.Size([1, 1, 65, 65]).
    size mismatch for _policy.transformer.h.6.attn.bias: copying a param with shape torch.Size([1, 1, 51, 51]) from checkpoint, the shape in current model is torch.Size([1, 1, 65, 65]).
    size mismatch for _policy.transformer.h.7.attn.bias: copying a param with shape torch.Size([1, 1, 51, 51]) from checkpoint, the shape in current model is torch.Size([1, 1, 65, 65]).
[2024-11-04 15:42:11,233][OpenGL.platform.baseplatform][INFO] - Platform import failed (likely during shutdown)
......
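When a checkpoint and a freshly built model disagree in several places, it can help to list every shape mismatch up front instead of reading them out of the RuntimeError. The sketch below is illustrative and not part of BAKU: it works on plain dicts mapping parameter names to shape tuples, which one might build from PyTorch state dicts with {k: tuple(v.shape) for k, v in sd.items()}. The function name and the example dicts are hypothetical; the example shapes are two of the ones reported in the log above.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    checkpoint_shapes: Dict[str, Shape], model_shapes: Dict[str, Shape]
) -> List[str]:
    """Describe every parameter present in both dicts whose shape differs."""
    mismatches = []
    # Only names present on both sides are compared; missing or unexpected
    # keys are a separate problem that load_state_dict reports on its own.
    for name in sorted(checkpoint_shapes.keys() & model_shapes.keys()):
        ckpt_shape, model_shape = checkpoint_shapes[name], model_shapes[name]
        if ckpt_shape != model_shape:
            mismatches.append(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
    return mismatches


# Hypothetical example using two of the shapes from the error log.
CKPT_EXAMPLE = {
    "_policy.transformer.wpe.weight": (51, 256),
    "_policy.transformer.h.0.attn.bias": (1, 1, 51, 51),
}
MODEL_EXAMPLE = {
    "_policy.transformer.wpe.weight": (65, 256),
    "_policy.transformer.h.0.attn.bias": (1, 1, 65, 65),
}

for line in find_shape_mismatches(CKPT_EXAMPLE, MODEL_EXAMPLE):
    print(line)
```

This does not make loading succeed; it only makes it easier to see that every reported mismatch involves the same dimension (51 vs 65).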

4evertutelary commented 2 weeks ago

I just used the pretrained weights from https://osf.io/3x8v5/?view_only=fb8285f025e84d23a41a0eef683a7e6d

4evertutelary commented 2 weeks ago

The following change resolves the issue. In /BAKU/baku/agent/baku.py, in the GPT model branch,

    if policy_type == "gpt":
        self._policy = GPT(
            GPTConfig(

replace

                block_size=65,

with

                block_size=51,  #zzz

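Why does changing one number clear all nine errors in the log? Assuming BAKU's GPT follows the nanoGPT/minGPT convention (which the parameter names wpe and h.N.attn.bias suggest, but which this thread does not confirm), block_size sets the number of rows in the learned position-embedding table and the side length of the causal-mask buffer registered in every attention block. The sketch below only enumerates those shapes; the helper name is hypothetical, and n_embd=256 and n_layer=8 are read off the log (suite.hidden_dim=256, blocks h.0 to h.7).

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def block_size_dependent_shapes(
    block_size: int, n_embd: int, n_layer: int, prefix: str = "_policy.transformer"
) -> Dict[str, Shape]:
    """Shapes of the tensors whose size depends on block_size, ASSUMING a
    nanoGPT-style GPT: a position embedding `wpe` with one row per position
    and a (1, 1, T, T) causal-mask buffer `attn.bias` in every block."""
    shapes: Dict[str, Shape] = {f"{prefix}.wpe.weight": (block_size, n_embd)}
    for i in range(n_layer):
        shapes[f"{prefix}.h.{i}.attn.bias"] = (1, 1, block_size, block_size)
    return shapes


# block_size=51 reproduces the checkpoint shapes, 65 the current model's.
ckpt_shapes = block_size_dependent_shapes(51, n_embd=256, n_layer=8)
model_shapes = block_size_dependent_shapes(65, n_embd=256, n_layer=8)
print(ckpt_shapes["_policy.transformer.wpe.weight"])      # (51, 256)
print(model_shapes["_policy.transformer.h.7.attn.bias"])  # (1, 1, 65, 65)
```

Because these are the only tensors tied to block_size under this assumption, every other weight loads unchanged and the mismatch list collapses to the single block_size setting.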
4evertutelary commented 2 weeks ago

@siddhanthaldar Could you have a look when you have free time?

siddhanthaldar commented 2 weeks ago

Hi,

Yes, that change should fix the issue. The issue arises because the transformer uses a fixed block length (block_size), and I had to use different lengths for different environment suites. When training your own model, you can set block_size large enough for your context length that it does not need to change at later stages, such as evaluation.
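Besides picking a large block_size at training time, another option at evaluation time is to read block_size from the checkpoint itself instead of hardcoding it. A minimal sketch, assuming the Actor's state_dict keys look like the ones in the error log; the helper and constant names are hypothetical, and which entry of the loaded payload holds the Actor's state_dict is not shown in this thread.

```python
import re
from typing import Dict, Tuple

Shape = Tuple[int, ...]

# Key names as they appear in the error log of this issue (an assumption
# about BAKU's Actor, not a documented interface).
WPE_KEY = "_policy.transformer.wpe.weight"
MASK_KEY = re.compile(r"_policy\.transformer\.h\.\d+\.attn\.bias$")


def infer_block_size(state_shapes: Dict[str, Shape]) -> int:
    """Return block_size as the number of rows of the position embedding,
    after checking that every causal-mask buffer agrees with it."""
    if WPE_KEY not in state_shapes:
        raise KeyError(f"{WPE_KEY} not found; is this the Actor state_dict?")
    block_size = state_shapes[WPE_KEY][0]
    for name, shape in state_shapes.items():
        if MASK_KEY.search(name) and tuple(shape[-2:]) != (block_size, block_size):
            raise ValueError(f"{name} has shape {shape}, expected (..., {block_size}, {block_size})")
    return block_size


# Shapes of the released Meta-World checkpoint, as reported in the log.
released = {WPE_KEY: (51, 256)}
released.update({f"_policy.transformer.h.{i}.attn.bias": (1, 1, 51, 51) for i in range(8)})
print(infer_block_size(released))  # 51
```

With PyTorch, the input dict could be built as {k: tuple(v.shape) for k, v in actor_state_dict.items()}; using the result would still require passing it through to where baku.py constructs GPTConfig.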

Thanks for pointing this out. Let me know if you have any other questions.

4evertutelary commented 1 week ago

It seems that for the Meta-World basketball task, we should pass camera_name to MuJoCo?

diff --git a/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py b/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py
index a0e9c1f..d21e1ac 100644
--- a/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py
+++ b/metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py
@@ -33,6 +33,7 @@ class SawyerMocapBase(mjenv_gym):
             frame_skip=frame_skip,
             observation_space=self.sawyer_observation_space,
             render_mode=render_mode,