facebookresearch / mbrl-lib

Library for Model Based RL
MIT License

Fixed a bug where a default argument to get_sequence_buffer_iterator causes a crash and added an assertion for clarity #106

Closed: jan1854 closed this 3 years ago

jan1854 commented 3 years ago

Motivation and Context / Related issue

The current implementation of get_sequence_buffer_iterator() crashes with the error TypeError: '>' not supported between instances of 'NoneType' and 'int' when the ensemble_size parameter is left at its default value (None). I changed the default value to 1 and the parameter's type from Optional[int] to int. This fixes the crash and makes the parameter consistent with the ensemble_size parameter of get_basic_buffer_iterators().
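A minimal, self-contained sketch of this failure mode and of the fix. The functions make_iterator_old and make_iterator_new are hypothetical stand-ins, not the actual mbrl-lib code; they assume, as the error message suggests, that the function compares ensemble_size against an int with `>`.

```python
from typing import Optional


def make_iterator_old(ensemble_size: Optional[int] = None) -> str:
    # Hypothetical pre-fix signature: the None default reaches an
    # integer comparison, and `None > 1` raises TypeError.
    if ensemble_size > 1:
        return "bootstrap"
    return "single"


def make_iterator_new(ensemble_size: int = 1) -> str:
    # Hypothetical post-fix signature: a default of 1 means "no ensemble"
    # and compares cleanly with an int.
    if ensemble_size > 1:
        return "bootstrap"
    return "single"


try:
    make_iterator_old()
    old_error = None
except TypeError as err:
    old_error = str(err)

print(f"old default fails with: {old_error}")
print(make_iterator_new())   # single
print(make_iterator_new(5))  # bootstrap
```

Using 1 instead of None as the default removes the need for a None check at every comparison site, since a single model is simply an ensemble of size 1.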

I also added an assertion to get_sequence_buffer_iterator() that checks that the replay buffer passed to it stores trajectory information (i.e., replay_buffer.stores_trajectories is True). Without this check, the function crashes with the unhelpful error TypeError: object of type 'NoneType' has no len() when the buffer does not store trajectories. The assertion makes the cause of the problem clear and easier for users to debug.
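A sketch of the fail-early pattern described above, using a hypothetical ToyBuffer rather than mbrl-lib's ReplayBuffer. The field name trajectory_indices and the assertion message are assumptions for illustration; only the stores_trajectories property name comes from this PR.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class ToyBuffer:
    # Hypothetical stand-in for a replay buffer: trajectory_indices stays
    # None when the buffer was created without trajectory tracking.
    trajectory_indices: Optional[List[Tuple[int, int]]] = None

    @property
    def stores_trajectories(self) -> bool:
        return self.trajectory_indices is not None


def num_trajectories(buffer: ToyBuffer) -> int:
    # Fail early with a readable message instead of the opaque
    # "object of type 'NoneType' has no len()" that len(None) would raise.
    assert buffer.stores_trajectories, (
        "The replay buffer must store trajectory information "
        "(illustrative message)."
    )
    return len(buffer.trajectory_indices)


print(num_trajectories(ToyBuffer([(0, 5), (5, 10)])))  # 2
```

One design note: Python removes assert statements when run with -O, so in that mode the opaque TypeError would come back; an explicit raise (e.g. ValueError) would keep the check in optimized runs.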

How Has This Been Tested (if it applies)

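from mbrl.util.replay_buffer import ReplayBuffer
from mbrl.util.common import get_sequence_buffer_iterator

# Buffer that stores trajectories (max_trajectory_length is set):
# capacity 10, obs shape (1,), action shape (1,).
replay_buffer = ReplayBuffer(10, (1,), (1,), max_trajectory_length=1)

# batch_size=1, val_ratio=0.0, sequence_length=1; ensemble_size is left at
# its default, which raised the TypeError above before this change.
get_sequence_buffer_iterator(replay_buffer, 1, 0.0, 1)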

luisenp commented 3 years ago

LGTM, approved and thanks!