opendilab / DI-drive

Decision Intelligence Platform for Autonomous Driving simulation.
https://opendilab.github.io/DI-drive/
Apache License 2.0

Collect less frames in each episode #6

Closed weijielyu closed 3 years ago

weijielyu commented 3 years ago

Hi, I am collecting data for the Implicit Affordance model. I noticed that each episode usually collects thousands of frames. I wonder if there's a way to downsample the number of frames collected in each episode, e.g. to only about a hundred frames per episode? Thank you!

RobinC94 commented 3 years ago

Hi! I'm not sure exactly what you mean by 'downsampling' the number of frames. The simplest option is to use a small suite (e.g. 'TurnTown01') to collect data; the number of frames will be small because the routes are short. If you want to reduce the number of saved frames in the collected data, there are generally two ways to do so. The simple one is to keep fewer frames at equal intervals, for example:

data = collector.collect(n_episode)
for i in range(len(data)):
    # keep every 10th frame of each collected episode
    data[i]['data'] = data[i]['data'][::10]
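To make the slicing behavior concrete, here is a minimal, self-contained sketch of the same interval downsampling on dummy data. The list-of-dicts structure with a 'data' key mirrors the collector output above; the integer frames are hypothetical stand-ins for real frame dicts.

```python
# One fake episode of 1000 "frames" (integers stand in for frame dicts).
episodes = [{'data': list(range(1000))}]

for ep in episodes:
    # extended slice: keep every 10th frame, starting from frame 0
    ep['data'] = ep['data'][::10]

print(len(episodes[0]['data']))   # 100
print(episodes[0]['data'][:3])    # [0, 10, 20]
```

Note that `[::10]` always keeps the first frame and then every 10th one, so an episode of N frames shrinks to roughly N/10 saved frames.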

The more involved way is to skip frames inside the collector. You need to modify the main loop of the collector's collect method so that it only stores a transition every N-th frame, like this:

# record frame count for all envs
frame_count = {env_id: 0 for env_id in range(self._env_num)}

while True:
    obs = self._env_manager.ready_obs
    policy_output = self._policy.forward(obs, **policy_kwargs)
    actions = {env_id: output['action'] for env_id, output in policy_output.items()}
    actions = to_ndarray(actions)
    for env_id in actions:
        self._obs_cache[env_id] = obs[env_id]
        self._actions_cache[env_id] = actions[env_id]
    timesteps = self._env_manager.step(actions)
    for env_id, timestep in timesteps.items():
        if timestep.info.get('abnormal', False):
            ...
        #########
        # store frame only if count % 10 == 0
        frame_count[env_id] += 1
        if frame_count[env_id] % 10 == 0:
            transition = self._policy.process_transition(
                self._obs_cache[env_id], self._actions_cache[env_id], timestep
            )
            self._traj_cache[env_id].append(transition)
        ##########

        if timestep.done:
            ...
    if self._env_manager.done:
        break
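The counting logic above can be sketched in isolation, with the env manager and policy replaced by plain loops so it runs anywhere. Names like SKIP, env_num, and the integer "transitions" are assumptions for illustration, not DI-drive API.

```python
# Standalone sketch of modulo-based frame skipping across several envs.
SKIP = 10      # store one transition every SKIP frames (assumed value)
env_num = 2    # stand-in for self._env_num

frame_count = {env_id: 0 for env_id in range(env_num)}
traj_cache = {env_id: [] for env_id in range(env_num)}

for step in range(100):            # stand-in for the collect loop
    for env_id in range(env_num):  # stand-in for iterating timesteps
        frame_count[env_id] += 1
        if frame_count[env_id] % SKIP == 0:
            # in the real collector this would be the processed transition
            traj_cache[env_id].append(step)

print(len(traj_cache[0]))  # 10: only every 10th of 100 frames is stored
```

Because the counter is per-env, each environment is downsampled independently even when episodes start and end at different times; resetting `frame_count[env_id]` on `timestep.done` would restart the interval at each new episode.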
weijielyu commented 3 years ago

It's a really detailed explanation! Thank you!