j96w / MimicPlay

"MimicPlay: Long-Horizon Imitation Learning by Watching Human Play" code repository
MIT License

High-Level Plan Training: Adapting Code for Human Play Data #8

Closed: kadambaribhujbal closed this issue 4 months ago

kadambaribhujbal commented 5 months ago

Hello, this is really impressive work. I have been able to replicate the simulation part, training the high- and low-level planners on play data. I am now focusing on training a high-level latent plan on human play data. To this end, I have generated an HDF5 file containing all the relevant information extracted from the provided human play videos. However, I have run into trouble adapting the existing high-level planner training code to use this file: the code appears to be tailored to simulation data, and I get errors when I plug in my dataset. Any guidance or direction you could offer would be immensely helpful. Thank you!

destroy314 commented 5 months ago

Hello, my goals and progress are the same as yours. I've found that learning from human data requires some code modifications. After recording dual-view videos of the task and running demo_mp4.py, the resulting h5 file still needs the trajectory of the end-effector (here, the hand) extracted from it. Before doing that, modify the target key names in scripts/dataset_extract_traj_plans.py and add a num_samples attribute (the 'done' and 'action' arrays are one element shorter than obs):

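# In scripts/dataset_extract_traj_plans.py; f (the open h5py file), DEMO_COUNT,
# and get_future_points are assumed to be defined earlier in that script
for i in range(0, DEMO_COUNT):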
    # Extract the hand location data
    hand_loc = f[f'data/demo_{i}/obs/hand_loc'][...]

    # Calculate the future trajectory for each data point
    future_traj_data = np.array([get_future_points(hand_loc[j:]) for j in range(len(hand_loc))])

    # Create the new dataset for future trajectory
    f.create_dataset(f'data/demo_{i}/obs/hand_loc_future_traj', data=future_traj_data)

    # Add the 'num_samples' attribute to the demo (specifically for human play data)
    f[f'data/demo_{i}'].attrs['num_samples'] = len(hand_loc) - 1
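The loop above relies on get_future_points from dataset_extract_traj_plans.py. To illustrate the shapes involved, here is a minimal sketch assuming each hand_loc frame is stored as (1, 4) and the target is 10 future points per timestep, matching the (10, 4) hand_loc_future_traj shape reported later in this thread. The names sample_future_points and build_future_traj, and the evenly spaced sampling strategy, are hypothetical and not necessarily what the repository's helper does.

```python
import numpy as np


def sample_future_points(traj, num_points=10):
    """Hypothetical stand-in for get_future_points.

    traj: remaining trajectory from the current frame, shape (T, 1, D) or (T, D).
    Returns num_points evenly spaced points, shape (num_points, D).
    The current frame and the last frame are always included; when fewer
    than num_points frames remain, some indices repeat.
    """
    traj = np.asarray(traj, dtype=np.float64)
    traj = traj.reshape(len(traj), -1)  # (T, 1, D) -> (T, D)
    idx = np.linspace(0, len(traj) - 1, num_points).round().astype(int)
    return traj[idx]


def build_future_traj(hand_loc, num_points=10):
    """Future trajectory for every timestep, shape (T, num_points, D)."""
    return np.stack([sample_future_points(hand_loc[j:], num_points)
                     for j in range(len(hand_loc))])


hand_loc = np.random.rand(50, 1, 4)  # synthetic data for illustration
future = build_future_traj(hand_loc)
print(future.shape)  # (50, 10, 4)
```

Even spacing over the remaining trajectory is just one simple choice; a fixed temporal stride with padding would produce the same output shape.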

Correspondingly, the observation key names in the high-level JSON configuration file also need to be changed:

"observation": {
    "modalities": {
        "obs": {
            "low_dim": [
                "hand_loc",
                "hand_loc_future_traj"
            ],
            "rgb": [
                "front_image_1",
                "front_image_2"
            ],
            "depth": [],
            "scan": []
        },
        "goal": {
            "low_dim": [],
            "rgb": [
                "front_image_1",
                "front_image_2"
            ],
            "depth": [],
            "scan": []
        }
    },
}
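Since mismatches between config keys and dataset keys are a common source of errors when switching datasets, a quick check before training can help. A minimal sketch, assuming the config has been loaded into a dict with json and that you have the list of keys found under one demo's obs group (for example via list(f['data/demo_0/obs'].keys()) with h5py); the function name find_missing_obs_keys is hypothetical. Goal keys are checked too, since goal images are taken from the same obs data.

```python
import json


def find_missing_obs_keys(config, dataset_obs_keys):
    """Return config observation/goal keys that are absent from the dataset."""
    available = set(dataset_obs_keys)
    missing = set()
    for group in ("obs", "goal"):
        modalities = config["observation"]["modalities"].get(group, {})
        for keys in modalities.values():  # low_dim, rgb, depth, scan lists
            missing.update(k for k in keys if k not in available)
    return sorted(missing)


# Toy config mirroring the fragment above
config = json.loads("""
{"observation": {"modalities": {
    "obs": {"low_dim": ["hand_loc", "hand_loc_future_traj"],
            "rgb": ["front_image_1", "front_image_2"]},
    "goal": {"rgb": ["front_image_1", "front_image_2"]}}}}
""")
print(find_missing_obs_keys(config, ["hand_loc", "front_image_1", "front_image_2"]))
# ['hand_loc_future_traj']
```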

In addition, the following lines in mimicplay/algo/mimicplay.py must be changed to use the new observation keys (hand_loc_future_traj, front_image_1, and front_image_2):

del self.obs_shapes['robot0_eef_pos_future_traj']

batch["goal_obs"]["agentview_image"] = batch["goal_obs"]["agentview_image"][:, 0]

log_probs = dists.log_prob(batch["obs"]["hand_loc_future_traj"])

The last line should end with ["hand_loc_future_traj"].reshape(-1, 40), so that each 10 × 4 future trajectory is flattened into a 40-dim vector; alternatively, I think the shape of the data written by dataset_extract_traj_plans.py could be changed instead. The dimension in the corresponding JSON file must also be updated to match:

"highlevel": {
    "ac_dim": 40,
},
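The 40 comes from flattening the future trajectory: 10 future points × 4 values each (hand_loc_future_traj has shape (10, 4), as shown later in this thread). A minimal numpy sketch of the shape bookkeeping; torch's Tensor.reshape behaves the same way here. Whether or not the batch still carries a length-1 sequence axis (seq_length=1 in the dataset printout below), reshape(-1, 40) yields one 40-dim target per sample. The constant names are illustrative only.

```python
import numpy as np

NUM_FUTURE_POINTS = 10  # points per future trajectory, from the (10, 4) shape
POINT_DIM = 4           # values per hand_loc point
AC_DIM = NUM_FUTURE_POINTS * POINT_DIM  # 40, the "ac_dim" in the config

batch_size = 32
with_seq = np.zeros((batch_size, 1, NUM_FUTURE_POINTS, POINT_DIM))
without_seq = np.zeros((batch_size, NUM_FUTURE_POINTS, POINT_DIM))

# Both layouts flatten to (batch, 40)
print(with_seq.reshape(-1, AC_DIM).shape)     # (32, 40)
print(without_seq.reshape(-1, AC_DIM).shape)  # (32, 40)
```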

Since the human play data has no train/validation split, disable validation:

"experiment": {
    "validate": false,
}
"train": {
    "hdf5_filter_key": null,
    "hdf5_validation_filter_key": null,
}
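If you would rather keep one base config and derive the human-data variant from it, the changes above can be applied programmatically. A minimal sketch using only the standard library; make_human_config and _find_block are hypothetical helpers, and because I am not certain where the "highlevel" block sits in the MimicPlay config, the sketch searches for it instead of assuming a fixed path.

```python
import copy
import json


def _find_block(cfg, name):
    """Depth-first search for the first dict stored under key `name`."""
    if isinstance(cfg, dict):
        if isinstance(cfg.get(name), dict):
            return cfg[name]
        for value in cfg.values():
            found = _find_block(value, name)
            if found is not None:
                return found
    return None


def make_human_config(base, ac_dim=40):
    """Return a copy of `base` with the human-play-data settings applied."""
    cfg = copy.deepcopy(base)
    cfg["experiment"]["validate"] = False           # no train/valid split
    cfg["train"]["hdf5_filter_key"] = None          # use every demo
    cfg["train"]["hdf5_validation_filter_key"] = None
    highlevel = _find_block(cfg, "highlevel")
    if highlevel is None:
        raise KeyError("no 'highlevel' block found in config")
    highlevel["ac_dim"] = ac_dim                    # 10 points x 4 values
    return cfg


# Toy base config for illustration only
base = {
    "experiment": {"validate": True},
    "train": {"hdf5_filter_key": "train", "hdf5_validation_filter_key": "valid"},
    "algo": {"highlevel": {"ac_dim": 30}},
}
human = make_human_config(base)
print(json.dumps(human["train"]))
# {"hdf5_filter_key": null, "hdf5_validation_filter_key": null}
```

Working on a deep copy leaves the base config untouched, so the simulation and human variants can be generated from the same file.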

This could also be addressed by modifying demo_mp4.py to create a split. Given the amount of code involved, renaming the keys in the human data to match those used in the code might be a better solution. In any case, the modifications above were sufficient for me to train a high-level model (though I have no way to verify its correctness), and I hope I haven't missed anything!
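For the split option mentioned here, a minimal sketch of demo-level splitting. It assumes robomimic's convention of storing filter keys as arrays of demo names under the file's mask/ group (robomimic's scripts/split_train_val.py, if your version includes it, writes mask/train and mask/valid this way). split_demos is hypothetical and only computes the lists; the commented h5py lines show how they might be written. Note that with a single long recording (num_demos=1 in the printout below), a demo-level split leaves nothing for validation, so the recording would first need to be cut into several demos.

```python
import random


def split_demos(demo_names, val_ratio=0.1, seed=0):
    """Shuffle demo names and split them into (train, valid) lists."""
    names = sorted(demo_names)
    random.Random(seed).shuffle(names)
    n_val = int(round(len(names) * val_ratio))
    if len(names) > 1:
        # Keep at least one demo on each side of the split.
        n_val = min(max(1, n_val), len(names) - 1)
    valid = sorted(names[:n_val])
    train = sorted(names[n_val:])
    return train, valid


# Writing the lists as robomimic-style filter keys (not executed here):
# import h5py, numpy as np
# with h5py.File(path, "a") as f:
#     f["mask/train"] = np.array(train, dtype="S")
#     f["mask/valid"] = np.array(valid, dtype="S")

train, valid = split_demos([f"demo_{i}" for i in range(10)], val_ratio=0.2)
print(len(train), len(valid))  # 8 2
```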

Compared with the play data provided for the simulation environment, training on my own collected data made the Train/Policy_Grad_Norms metric stop declining and start rising at around 20 epochs, far earlier than the ~400 epochs seen with the simulation data (where that turning point also coincided with the best validation results). It would be helpful if the authors could provide guidance on data collection or training.
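To compare such curves across datasets more objectively than by eye, one option is to locate the epoch where the smoothed metric reaches its minimum. A minimal sketch, assuming the per-epoch Train/Policy_Grad_Norms values have been exported to a Python list (for example from the training logs); the function name turning_point and the moving-average window are hypothetical choices.

```python
import numpy as np


def turning_point(values, window=5):
    """Return the 0-based epoch index where the moving average is lowest.

    values: per-epoch metric values, e.g. Train/Policy_Grad_Norms.
    window: moving-average length used to suppress epoch-to-epoch noise.
    """
    values = np.asarray(values, dtype=float)
    if len(values) < window:
        return int(np.argmin(values))
    kernel = np.ones(window) / window
    smoothed = np.convolve(values, kernel, mode="valid")
    # smoothed[i] averages values[i : i + window]; report the window centre.
    return int(np.argmin(smoothed)) + window // 2


# Synthetic V-shaped curve that bottoms out at epoch 20
curve = [abs(e - 20) for e in range(60)]
print(turning_point(curve))  # 20
```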

kadambaribhujbal commented 5 months ago

Hello, this was extremely helpful, thank you! I'm still a bit confused about how to replace the line batch["goal_obs"]["agentview_image"] = batch["goal_obs"]["agentview_image"][:, 0]. I am also encountering the error below. Any insights or assistance would be greatly appreciated. Thanks again!

PlaydataSequenceDataset (
    path=scripts/human_playdata_process/demo_hand_loc_1_new.hdf5
    obs_keys=('front_image_1', 'front_image_2', 'hand_loc', 'hand_loc_future_traj')
    seq_length=1
    filter_key=none
    frame_stack=1
    pad_seq_length=True
    pad_frame_stack=True
    goal_mode=nstep
    cache_mode=low_dim
    num_demos=1
    num_sequences=144
)
0%|          | 0/100 [00:00<?, ?it/s]
run failed with error:
Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [32, 1, 3, 76, 76] 
Traceback (most recent call last):
  File "scripts/train.py", line 376, in main
    train(config, device=device)
  File "scripts/train.py", line 199, in train
    step_log = TrainUtils.run_epoch(model=model, data_loader=train_loader, epoch=epoch, num_steps=train_num_steps)
  File "/home/interaction/kaddy/kaddy/robomimic/robomimic/utils/train_utils.py", line 559, in run_epoch
    info = model.train_on_batch(input_batch, epoch, validate=validate)
  File "/home/interaction/kaddy/kaddy/robomimic/robomimic/algo/bc.py", line 137, in train_on_batch
    predictions = self._forward_training(batch)
  File "/home/interaction/kaddy/kaddy/MimicPlay/mimicplay/algo/mimicplay.py", line 178, in _forward_training
    dists = self.nets["policy"].forward_train(
  File "/home/interaction/kaddy/kaddy/MimicPlay/mimicplay/models/policy_nets.py", line 228, in forward_train
    out = MIMO_MLP.forward(self, return_latent=return_latent, obs=obs_dict, goal=goal_dict)
  File "/home/interaction/kaddy/kaddy/MimicPlay/mimicplay/models/obs_nets.py", line 584, in forward
    enc_outputs = self.nets["encoder"](**inputs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/interaction/kaddy/kaddy/MimicPlay/mimicplay/models/obs_nets.py", line 445, in forward
    self.nets[obs_group].forward(inputs[obs_group])
  File "/home/interaction/kaddy/kaddy/MimicPlay/mimicplay/models/obs_nets.py", line 235, in forward
    x = self.obs_nets[k](x)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/interaction/kaddy/kaddy/robomimic/robomimic/models/obs_core.py", line 172, in forward
    return super(VisualCore, self).forward(inputs)
  File "/home/interaction/kaddy/kaddy/robomimic/robomimic/models/base_nets.py", line 482, in forward
    x = self.nets(inputs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/interaction/kaddy/kaddy/robomimic/robomimic/models/base_nets.py", line 482, in forward
    x = self.nets(inputs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/interaction/kaddy/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [32, 1, 3, 76, 76]

I also wanted to confirm whether any changes are needed in these sections of train.py:

shape_meta = FileUtils.get_shape_metadata_from_dataset(
    dataset_path=config.train.data,
    all_obs_keys=config.all_obs_keys,
    verbose=True
)

model = algo_factory(
    algo_name=config.algo_name,
    config=config,
    obs_key_shapes=shape_meta["all_shapes"],
    ac_dim=shape_meta["ac_dim"],
    device=device,
)

# load training data
trainset, validset = load_data_for_training(
    config, obs_keys=shape_meta["all_obs_keys"])
destroy314 commented 5 months ago
  1. I changed it to:

    batch["goal_obs"]["front_image_1"] = batch["goal_obs"]["front_image_1"][:, 0]
    batch["goal_obs"]["front_image_2"] = batch["goal_obs"]["front_image_2"][:, 0]
  2. I believe the error was caused by not making the change in 1: [:, 0] removes the length-1 sequence dimension, turning the 5D input [32, 1, 3, 76, 76] into the 4D [32, 3, 76, 76] that conv2d expects.

  3. I didn't modify the part related to shape_meta, and it seems to be working fine:

    ============= Loaded Environment Metadata =============
    obs key front_image_1 with shape (128, 128, 3)
    obs key front_image_2 with shape (128, 128, 3)
    obs key hand_loc with shape (1, 4)
    obs key hand_loc_future_traj with shape (10, 4)

    However, I did forget to mention that env_meta must be disabled:

    # env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=config.train.data)
    env_meta = None
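The shape fix from items 1 and 2 can be checked in isolation. A minimal numpy sketch (torch tensors index the same way) using the goal-image shape from the traceback above: indexing with [:, 0] drops the length-1 sequence axis, so the 5D tensor becomes the 4D (batch, channels, height, width) layout that conv2d expects. The dict below is a stand-in for the real batch, not MimicPlay code.

```python
import numpy as np

# Stand-in for batch["goal_obs"]: (batch, seq_len=1, channels, height, width),
# the 5D shape reported in the RuntimeError above.
goal_obs = {
    "front_image_1": np.zeros((32, 1, 3, 76, 76), dtype=np.float32),
    "front_image_2": np.zeros((32, 1, 3, 76, 76), dtype=np.float32),
}

# Same operation as item 1, applied to each goal camera key:
# keep only the first (and only) step of the sequence axis.
for key in ("front_image_1", "front_image_2"):
    goal_obs[key] = goal_obs[key][:, 0]

print(goal_obs["front_image_1"].shape)  # (32, 3, 76, 76)
```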
j96w commented 4 months ago

Thanks @destroy314 for the helpful guidance. I'm searching through the scripts in our original codebase and will update this repo as soon as I find it.

j96w commented 4 months ago

Hi @destroy314 @kadambaribhujbal, we've updated the script for building your own human dataset. It saves the data in HDF5 format (with env_args), which can be used directly to train a high-level policy:

python scripts/train.py --config configs/highlevel_human.json --dataset 'PATH_TO_HDF5_FILE'

Feel free to check it out.