HaoyiZhu / PointCloudMatters

[NeurIPS 2024 D&B] Point Cloud Matters: Rethinking the Impact of Different Observation Spaces on Robot Learning
https://haoyizhu.github.io/pcm/
MIT License
45 stars, 1 fork

Cannot run pretrained multimae on RLBench #3

Closed: Fisher-Wang closed this issue 2 months ago.

Fisher-Wang commented 2 months ago

Thanks for your great work!

When I run

python src/train.py exp_rlbench_diffusion_policy=base rlbench_task=turn_tap exp_rlbench_diffusion_policy/rlbench_model@rlbench_model=pretrained_multimae_rgbd seed=0

I get this error:

Traceback (most recent call last):
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 92, in _call_target
    return _target_(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/cod/PointCloudMatters/src/models/components/diffusion_policy/diffusion_unet_image_policy.py", line 49, in __init__
    obs_feature_dim = obs_encoder.output_shape()[0]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/cod/PointCloudMatters/src/models/components/diffusion_policy/vision/multi_image_obs_encoder.py", line 242, in output_shape
    example_output = self.forward(example_obs_dict)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/cod/PointCloudMatters/src/models/components/diffusion_policy/vision/multi_image_obs_encoder.py", line 187, in forward
    feature = feature.reshape(-1, batch_size, *feature.shape[1:])
              ^^^^^^^^^^^^^^^
AttributeError: 'dict' object has no attribute 'reshape'

Also, when I run

python src/train.py exp_rlbench_act_policy=base rlbench_task=turn_tap exp_rlbench_act_policy/rlbench_model@rlbench_model=pretrained_multimae_rgbd seed=0

I get this error:

Traceback (most recent call last):
  File "/home/xxx/cod/PointCloudMatters/src/utils/utils.py", line 70, in wrap
    metric_dict, object_dict = task_func(cfg=cfg)
                               ^^^^^^^^^^^^^^^^^^
  File "/home/xxx/cod/PointCloudMatters/src/train.py", line 93, in train
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1023, in _run_stage
    self._run_sanity_check()
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1052, in _run_sanity_check
    val_loop.run()
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py", line 178, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 411, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/cod/PointCloudMatters/src/models/rlbench_act_bc_module.py", line 91, in validation_step
    loss_dict = self.model_step(batch)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/cod/PointCloudMatters/src/models/rlbench_act_bc_module.py", line 63, in model_step
    return self.policy(batch)
           ^^^^^^^^^^^^^^^^^^
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/miniforge3/envs/pcm/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/cod/PointCloudMatters/src/models/components/act/act.py", line 299, in forward
    data_dict = self.forward_obs_embed(data_dict)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xxx/cod/PointCloudMatters/src/models/components/act/act.py", line 218, in forward_obs_embed
    print(features.shape)
          ^^^^^^^^^^^^^^
AttributeError: 'dict' object has no attribute 'shape'

Could you help me run these experiments correctly?
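Both tracebacks above end the same way: code that expects a single feature tensor (calling `.reshape` or `.shape` on it) receives a `dict` instead. Below is a minimal, stdlib-only sketch of that failure mode. It assumes, for illustration, that a pre-training-style model returns per-modality outputs keyed by name, while a downstream encoder returns one feature object. `FakeTensor`, `finetune_style_model`, `pretrain_style_model`, and `encode` are hypothetical stand-ins, not code from the repository. The `isinstance` guard shows one way to turn the bare `AttributeError` into a message that points at the likely misconfiguration.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor: a flat list of values plus a shape tuple."""

    def __init__(self, data, shape):
        self.data = list(data)
        self.shape = tuple(shape)

    def reshape(self, *shape):
        # Resolve a single -1 entry the way tensor libraries do.
        known = 1
        for dim in shape:
            if dim != -1:
                known *= dim
        resolved = tuple(len(self.data) // known if dim == -1 else dim for dim in shape)
        return FakeTensor(self.data, resolved)


def finetune_style_model(batch_size, dim):
    # Hypothetical: returns one feature object of shape (batch_size, dim).
    return FakeTensor(range(batch_size * dim), (batch_size, dim))


def pretrain_style_model(batch_size, dim):
    # Hypothetical: returns a dict of per-modality outputs instead of one feature.
    return {
        "rgb": FakeTensor(range(batch_size * dim), (batch_size, dim)),
        "depth": FakeTensor(range(batch_size * dim), (batch_size, dim)),
    }


def encode(model, batch_size=2, dim=4):
    feature = model(batch_size, dim)
    if isinstance(feature, dict):
        # Fail early with a message that names the likely cause.
        raise TypeError(
            "encoder returned a dict with keys %s; expected a single feature "
            "tensor (is a pre-training model configured as the encoder?)"
            % sorted(feature)
        )
    # Same reshape pattern as in the first traceback: (-1, batch_size, *rest).
    return feature.reshape(-1, batch_size, *feature.shape[1:])
```

With these stand-ins, `encode(finetune_style_model).shape` is `(1, 2, 4)`, while `encode(pretrain_style_model)` raises the `TypeError` instead of an `AttributeError` deep inside the policy.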

HaoyiZhu commented 2 months ago

Hi, I have fixed it. The model class should be MultiViT instead of MultiMAE (which is used for pre-training).

Thanks for pointing this out.
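The fix amounts to changing which class the config asks for. The first traceback shows the encoder being built through Hydra's instantiate machinery, which imports the class named under the `_target_` key and calls it with the remaining config entries. The sketch below imitates only that import-and-call step with `importlib`; it is a simplified illustration, not Hydra's implementation. The standard-library targets are stand-ins chosen so that the "wrong" one returns a `dict` (like the pre-training model) and the "right" one does not; the repository's actual module paths for `MultiMAE` and `MultiViT` are not reproduced here.

```python
import importlib


def instantiate(cfg):
    """Simplified stand-in for Hydra's instantiate: import cfg['_target_'] and call it."""
    module_path, _, attr_name = cfg["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    # Every key except _target_ is passed through as a keyword argument.
    kwargs = {key: value for key, value in cfg.items() if key != "_target_"}
    return target(**kwargs)


# Changing only the _target_ string changes the kind of object that gets built.
wrong_cfg = {"_target_": "collections.OrderedDict", "a": 1}
right_cfg = {"_target_": "types.SimpleNamespace", "a": 1}
```

Here `instantiate(wrong_cfg)` yields a `dict` subclass, while `instantiate(right_cfg)` yields an object with attribute `a == 1`. In the repository's YAML, the corresponding edit would be the class path on the `_target_` line (assuming that is the key the config uses).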

Fisher-Wang commented 2 months ago

Thanks for the clarification!

Since MultiMAE is used for pre-training, should this line in scratch_multivit_rgbd.yaml be MultiMAEModel instead? https://github.com/HaoyiZhu/PointCloudMatters/blob/3e751d30f34515242eb25f7f3da767c3ec643f14/configs/exp_rlbench_act_policy/rlbench_model/scratch_multivit_rgbd.yaml#L20

Also, should the filenames be changed accordingly?

xiaoxiao0406 commented 2 months ago

This line should always be MultiViTModel. For the scratch setting, just set ckpt_path to None.
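The scratch and pretrained variants can share one model class because only the weight-loading step differs. A minimal sketch of that pattern follows, where `build_multivit` and the `load_state` callback are hypothetical helpers (not functions from the repository) and plain dicts stand in for state dicts. Copying only keys the model already has mimics a non-strict load, so checkpoint entries the fine-tuning model does not define (the `decoder.weight` key in the usage below is purely illustrative) are skipped.

```python
from typing import Callable, Dict, Optional

State = Dict[str, float]


def build_multivit(
    ckpt_path: Optional[str],
    load_state: Callable[[str], State],
) -> State:
    # Hypothetical: every run starts from the same placeholder initial weights.
    weights: State = {"patch_embed.weight": 0.0, "blocks.0.attn.weight": 0.0}
    if ckpt_path is None:
        # Scratch setting: keep the initial weights untouched.
        return weights
    # Pretrained setting: overwrite only the entries the model actually has.
    checkpoint = load_state(ckpt_path)
    for name, value in checkpoint.items():
        if name in weights:
            weights[name] = value
    return weights


def fake_loader(path: str) -> State:
    # Hypothetical checkpoint that also holds an entry the model does not use.
    return {"patch_embed.weight": 1.5, "decoder.weight": 9.0}
```

Calling `build_multivit(None, fake_loader)` returns the initial weights unchanged, while `build_multivit("pretrained.ckpt", fake_loader)` sets `patch_embed.weight` to `1.5` and ignores `decoder.weight`.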

Fisher-Wang commented 2 months ago

I see. Thank you!