DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

[Bug] `DQN` typing wrongly states that `batch_size` can be `None` #789

Closed by qgallouedec 2 years ago

qgallouedec commented 2 years ago

🐛 Bug

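The type annotation of `batch_size` in `DQN` suggests that it can be `None`. It cannot: passing `None` is accepted by the constructor but raises a `TypeError` once training starts sampling from the replay buffer.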

https://github.com/DLR-RM/stable-baselines3/blob/13fcb124711be4de86199a628acaedd215931756/stable_baselines3/dqn/dqn.py#L69

To Reproduce

import gym
from stable_baselines3 import DQN

model = DQN('MlpPolicy', 'MountainCar-v0', batch_size=None)
model.learn(60000)

Traceback (most recent call last):
  File "/Users/quentingallouedec/stable-baselines3/ex.py", line 6, in <module>
    model.learn(60000)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/dqn/dqn.py", line 258, in learn
    return super(DQN, self).learn(
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/off_policy_algorithm.py", line 373, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/dqn/dqn.py", line 180, in train
    replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/buffers.py", line 278, in sample
    return super().sample(batch_size=batch_size, env=env)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/buffers.py", line 110, in sample
    return self._get_samples(batch_inds, env=env)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/buffers.py", line 289, in _get_samples
    env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
TypeError: object of type 'int' has no len()
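
The last frame of the traceback can be reproduced with NumPy alone. The sketch below is a simplified stand-in, not the library's code: `sample_indices`, `upper_bound`, and the way `batch_inds` is drawn are assumptions made for illustration. Only the `env_indices` line is taken from the traceback. It shows that `np.random.randint` returns a single integer rather than an array when `size=None`, so the later `len(batch_inds)` call fails.

```python
import numpy as np


def sample_indices(batch_size, upper_bound=100, n_envs=1):
    """Simplified, hypothetical version of the sampling steps in the traceback."""
    # Assumption: the batch indices are drawn with size=batch_size.
    # With size=None, NumPy returns one integer instead of an array.
    batch_inds = np.random.randint(0, upper_bound, size=batch_size)
    # This line mirrors the failing frame: len() works on an array,
    # but raises TypeError on a scalar integer.
    env_indices = np.random.randint(0, high=n_envs, size=(len(batch_inds),))
    return batch_inds, env_indices


# An integer batch size produces two arrays of matching length.
batch_inds, env_indices = sample_indices(batch_size=32)
print(batch_inds.shape, env_indices.shape)  # (32,) (32,)

# None propagates silently until len() is called on the scalar.
try:
    sample_indices(batch_size=None)
except TypeError as err:
    print(f"TypeError: {err}")
```

This is why the error surfaces deep inside `buffers.py` instead of at construction time: `None` is a valid `size` argument for NumPy, so nothing fails until the result is treated as a sequence.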

I'm opening a PR.

araffin commented 2 years ago

Hello, thanks for reporting the bug. `batch_size` must indeed be an int ;)
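
For user code that forwards a `batch_size` from a config file or command line, one possible way to fail early is a small check before constructing the model. This is only a sketch: `check_batch_size` is a hypothetical helper, not part of stable-baselines3, and it does not describe how the library itself handles the argument.

```python
from typing import Any


def check_batch_size(batch_size: Any) -> int:
    """Hypothetical helper: return batch_size if it is a positive int, else raise."""
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(batch_size, bool) or not isinstance(batch_size, int):
        raise TypeError(
            f"batch_size must be an int, got {type(batch_size).__name__}"
        )
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    return batch_size


print(check_batch_size(32))  # 32

# None is rejected immediately with a message that names the argument.
try:
    check_batch_size(None)
except TypeError as err:
    print(err)
```

With such a check, `DQN('MlpPolicy', 'MountainCar-v0', batch_size=check_batch_size(value))` would raise before training starts, rather than inside the replay buffer.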