rll / rllab

rllab is a framework for developing and evaluating reinforcement learning algorithms, fully compatible with OpenAI Gym.

Type mismatch in np.random.choice(...) call #224

Open · riturajkaushik opened this issue 6 years ago


Running trpo_cartpole.py on the py2 branch of rllab raises the following error at line 146 of optimizers/conjugate_gradient_optimizer.py:

Traceback (most recent call last):
  File "trpo_cartpole.py", line 27, in <module>
    algo.train()
  File "/home/rkaushik/projects/cloned_libs/rllab/rllab/algos/batch_polopt.py", line 253, in train
    self.optimize_policy(itr, samples_data)
  File "/home/rkaushik/projects/cloned_libs/rllab/rllab/algos/npo.py", line 109, in optimize_policy
    self.optimizer.optimize(all_input_values)
  File "/home/rkaushik/projects/cloned_libs/rllab/rllab/optimizers/conjugate_gradient_optimizer.py", line 146, in optimize
    n_samples, (n_samples * self._subsample_factor), replace=False)
  File "mtrand.pyx", line 1176, in mtrand.RandomState.choice (numpy/random/mtrand/mtrand.c:18822)
TypeError: 'float' object cannot be interpreted as an index 

The product n_samples * self._subsample_factor is a float, but np.random.choice(...) needs an integer sample size. Casting the second argument to int fixes it:

# choice() needs an integer size; n_samples * subsample_factor is a float
inds = np.random.choice(
    n_samples, int(n_samples * self._subsample_factor), replace=False)
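To show the fix outside the optimizer class, here is a minimal sketch that wraps the corrected call in a small helper. `subsample_indices` is a hypothetical name, not part of rllab, and it assumes (like the cast above) that truncating with `int()` is the intended rounding.

```python
import numpy as np


def subsample_indices(n_samples, subsample_factor, rng=None):
    """Pick a random subset of sample indices without replacement.

    Hypothetical helper (not an rllab API) mirroring the fixed call in
    ConjugateGradientOptimizer.optimize: np.random.choice needs an
    integer size, but n_samples * subsample_factor is a float.
    """
    rng = np.random if rng is None else rng
    # int() truncates toward zero, e.g. 25 * 0.1 = 2.5 -> 2
    size = int(n_samples * subsample_factor)
    return rng.choice(n_samples, size, replace=False)


# Usage: a seeded RandomState keeps the draw reproducible.
inds = subsample_indices(25, 0.1, rng=np.random.RandomState(0))
print(inds.shape)  # (2,)
```

Because `int()` truncates, a factor small enough that `n_samples * subsample_factor < 1` yields an empty index array. Wrapping the size in `max(1, ...)` is one way to guard against that, but the cast shown in this issue does not do so.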

The master branch already includes this cast.