Types of changes
[x] Bug fix (non-breaking change which fixes an issue)
[x] New feature (non-breaking change which adds functionality)
[ ] Breaking change (fix or feature that would cause existing functionality to change)
Motivation and Context / Related issue
The current implementation of `get_sequence_buffer_iterator()` crashes with the error `TypeError: '>' not supported between instances of 'NoneType' and 'int'` when the default value (`None`) is used for the `ensemble_size` parameter. I changed the default value to `1` and the parameter's type from `Optional[int]` to `int`. This fixes the problem and is consistent with the `ensemble_size` parameter of `get_basic_buffer_iterators()`.
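A minimal, hypothetical sketch of the failure mode (not the actual mbrl-lib code): it assumes only that the function compares `ensemble_size` against an integer somewhere, as the `'>'` in the error message implies. `buggy_check`, `fixed_check` and `raises_type_error` are illustrative names.

```python
from typing import Callable, Optional


def buggy_check(ensemble_size: Optional[int] = None) -> bool:
    # Hypothetical stand-in for a comparison such as `ensemble_size > 1`;
    # with the old default (None) this raises TypeError in Python 3.
    return ensemble_size > 1


def fixed_check(ensemble_size: int = 1) -> bool:
    # With the new default (1) the same comparison is well defined.
    return ensemble_size > 1


def raises_type_error(fn: Callable[[], bool]) -> bool:
    # Returns True if calling fn() with its default arguments raises TypeError.
    try:
        fn()
    except TypeError:
        return True
    return False


print(raises_type_error(buggy_check))  # True
print(raises_type_error(fixed_check))  # False
```

Switching the annotation to a plain `int` with a default of `1` also lets a type checker flag any caller that still passes `None`.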
Furthermore, I added an assertion to `get_sequence_buffer_iterator()` that checks that the replay buffer passed to the function stores trajectory information (i.e., `replay_buffer.stores_trajectories` is `True`). Without this assertion, passing a buffer that does not store trajectories makes the function crash with the less informative error `TypeError: object of type 'NoneType' has no len()`. The assertion makes the cause of the problem clearer and easier for users to debug.
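A sketch of the fail-fast check, under stated assumptions: `StubReplayBuffer` is a hypothetical stand-in that models only `stores_trajectories` and a `trajectory_indices` list (assumed to be `None` when trajectories are not tracked, which would explain the `len()` error), and `count_trajectories` is an illustrative helper, not part of mbrl-lib.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class StubReplayBuffer:
    # Hypothetical stand-in for mbrl.util.replay_buffer.ReplayBuffer; only the
    # attributes the check needs are modelled here (assumption).
    trajectory_indices: Optional[List[Tuple[int, int]]] = None

    @property
    def stores_trajectories(self) -> bool:
        return self.trajectory_indices is not None


def count_trajectories(replay_buffer: StubReplayBuffer) -> int:
    # Fail early with a readable message instead of letting len(None)
    # raise "TypeError: object of type 'NoneType' has no len()".
    assert replay_buffer.stores_trajectories, (
        "The replay buffer must store trajectory information "
        "(replay_buffer.stores_trajectories must be True)."
    )
    return len(replay_buffer.trajectory_indices)


print(count_trajectories(StubReplayBuffer([(0, 3), (3, 5)])))  # 2
```

One design note: `assert` statements are skipped when Python runs with `-O`, so an explicit `ValueError` would be the alternative if the check must always run.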
How Has This Been Tested (if it applies)
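```python
from mbrl.util.replay_buffer import ReplayBuffer
from mbrl.util.common import get_sequence_buffer_iterator

# Buffer that stores trajectories (max_trajectory_length is set).
replay_buffer = ReplayBuffer(10, (1,), (1,), max_trajectory_length=1)
# Relies on the default ensemble_size, which previously was None and crashed.
get_sequence_buffer_iterator(replay_buffer, 1, 0.0, 1)
```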
Checklist
[x] The documentation is up-to-date with the changes I made.
[x] I have read the CONTRIBUTING document and completed the CLA (see CONTRIBUTING).
[x] All tests passed, and additional code has been covered with new tests.