haosulab / ManiSkill2-Learn

Apache License 2.0

Sampling Full Trajectories #27

Closed: ErikKrauter closed this issue 4 days ago

ErikKrauter commented 6 months ago

I store trajectories generated by my RL agent in an h5 file, using the functionality from the evaluation.py script.

Now I want to train a network on those trajectories. It is crucial that the network receives full trajectories rather than transitions sampled randomly across all trajectories. Can I achieve this with the existing ManiSkill2-Learn code base, or do I need to add custom functionality?
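To make the distinction concrete, here is a minimal sketch in plain NumPy (not ManiSkill2-Learn code) of sampling whole trajectories from a flat transition buffer. It assumes a hypothetical layout where the buffer is a dict of equal-length arrays and `episode_dones[i]` is True on the last transition of each episode; the function names `split_into_trajectories` and `sample_full_trajectories` are made up for illustration.

```python
import numpy as np


def split_into_trajectories(episode_dones):
    """Return (start, end) index pairs, end exclusive, one per trajectory.

    Assumes episode_dones[i] is True on the final transition of each episode.
    Transitions after the last True flag (an unfinished episode) are dropped.
    """
    ends = np.flatnonzero(np.asarray(episode_dones, dtype=bool)) + 1
    starts = np.concatenate(([0], ends[:-1]))
    return list(zip(starts.tolist(), ends.tolist()))


def sample_full_trajectories(buffer, batch_size, rng):
    """Sample whole trajectories (not individual transitions) from `buffer`.

    `buffer` is a dict of equal-length arrays, e.g. {"obs": ..., "actions": ...,
    "episode_dones": ...}. Sampling is without replacement, so batch_size must
    not exceed the number of complete trajectories.
    """
    spans = split_into_trajectories(buffer["episode_dones"])
    picks = rng.choice(len(spans), size=batch_size, replace=False)
    # Slice every key with the same span so each sample is one full episode.
    return [{k: v[spans[i][0]:spans[i][1]] for k, v in buffer.items()} for i in picks]
```

Each returned element holds one complete episode, so the network sees the transitions in their original order instead of a random mix from different episodes.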

ErikKrauter commented 6 months ago

Okay, I believe I found a way. I use the ReplayMemory class with the following config to achieve the desired behavior. Setting the sampling strategy to TStepTransition and the horizon to -1 makes the replay buffer sample full trajectories. It also automatically pads trajectories that are shorter by appending the first transition of the trajectory to them. By specifying capacity=-1 and num_samples=-1, the replay buffer loads the entire h5 file into memory.

replay_cfg = dict(
    type="ReplayMemory",
    sampling_cfg=dict(type="TStepTransition", horizon=-1),
    capacity=-1,
    num_samples=-1,
    keys=["obs", "actions", "dones", "episode_dones", "infos"],
    buffer_filenames=[
        "Evaluation/EvalDebugTraj/test/trajectory.h5",
    ],
)
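The padding behavior described above can be pictured with a short NumPy sketch. This is an illustration under assumptions, not ManiSkill2-Learn's implementation: `pad_trajectories` is a hypothetical helper that stacks variable-length trajectories into one `[B, T_max, ...]` array, padding each shorter one at the end with copies of its first transition as the comment describes. The boolean mask is an extra assumption of this sketch, added so padded steps can be excluded from a loss; whether the library returns such a mask is not confirmed in this thread.

```python
import numpy as np


def pad_trajectories(trajs):
    """Stack variable-length trajectories into a [B, T_max, ...] array.

    Each shorter trajectory is padded at the end with repeated copies of its
    first transition. Also returns a [B, T_max] boolean mask that is True on
    real (non-padded) steps.
    """
    t_max = max(len(t) for t in trajs)
    mask = np.zeros((len(trajs), t_max), dtype=bool)
    padded = []
    for b, traj in enumerate(trajs):
        traj = np.asarray(traj)
        n = len(traj)
        # traj[:1] keeps the leading axis, so repeating it yields (t_max - n, ...).
        filler = np.repeat(traj[:1], t_max - n, axis=0)
        padded.append(np.concatenate([traj, filler], axis=0))
        mask[b, :n] = True
    return np.stack(padded), mask
```

For example, trajectories `[0, 1, 2]` and `[5, 6]` become `[[0, 1, 2], [5, 6, 5]]`, with the last step of the second row marked as padding in the mask.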

I have a follow up question: Does the logic still work when I specify dynamic_loading=True?

xuanlinli17 commented 6 months ago

Yes, you need TStepTransition and horizon=-1 to sample full trajectories.

For dynamic loading, you need to set capacity > 0. The same logic should still hold, though if I remember correctly, capacity is now the number of trajectories.

Actually I think this is briefly mentioned in the readme:

demo_replay_cfg = dict(
    type="ReplayMemory",
    capacity=int(2e4),  # must be > 0 when dynamic_loading=True
    num_samples=-1,
    cache_size=int(2e4),
    dynamic_loading=True,
    synchronized=False,
    keys=["obs", "actions", "dones", "episode_dones"],
    buffer_filenames=[
        "PATH_TO_DEMO.h5",
    ],
)
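The idea of capacity counting trajectories under dynamic loading can be sketched as a loader that keeps at most `capacity` whole trajectories in memory at a time. This is a hypothetical illustration, not how ManiSkill2-Learn implements `dynamic_loading`: `iterate_in_chunks` and the `load_fn(i)` callback (which would read trajectory `i` from the h5 file) are made-up names for this sketch.

```python
import numpy as np


def iterate_in_chunks(num_trajectories, capacity, load_fn, rng):
    """Yield lists of loaded trajectories, never holding more than `capacity`.

    `capacity` counts trajectories, not transitions. `load_fn(i)` is a
    hypothetical callback returning trajectory i (e.g. read from an h5 file).
    Trajectory order is shuffled once per call, i.e. once per epoch.
    """
    if capacity <= 0:
        raise ValueError("dynamic loading needs capacity > 0")
    order = rng.permutation(num_trajectories)
    for start in range(0, num_trajectories, capacity):
        chunk_ids = order[start:start + capacity]
        yield [load_fn(int(i)) for i in chunk_ids]
```

Because whole trajectories are loaded as units, each chunk can still be sampled trajectory by trajectory, which matches the expectation that the full-trajectory sampling logic carries over.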