IntelLabs / coach

Reinforcement Learning Coach by Intel AI Lab enables easy experimentation with state of the art Reinforcement Learning algorithms
https://intellabs.github.io/coach/
Apache License 2.0

IndexError: List Index out of range on custom csv file and space #437

Closed OGordon100 closed 4 years ago

OGordon100 commented 4 years ago

Following the Batch Reinforcement Learning tutorial at https://github.com/NervanaSystems/coach/blob/master/tutorials/4.%20Batch%20Reinforcement%20Learning.ipynb, I am trying to train on my own dataset. To simplify the issue, I am still working with the CartPole-v1 environment. I have also replaced the env_params line with:

from gym.wrappers import FrameStack
env_params = GymEnvironment(LevelSelection("CartPole-v1"), frame_skip=1, visualization_parameters=VisualizationParameters())
env_params.env = FrameStack(env_params.env, 2)

This trains successfully.

However, when I replace acrobat.csv with my custom CSV, I start to have issues. I modified the SpacesDefinition to match the new CSV:

spaces = SpacesDefinition(state=StateSpace({'observation': VectorObservationSpace(shape=17)}), goal=None, action=DiscreteActionSpace(24), reward=RewardSpace(1))

Training starts, but then crashes after reward_model_num_epochs epochs (10 in this example):

2020-02-03-11:09:10.128020 Training Batch RL Models - Epoch: 0 Reward Model Loss: 78.19110613567074
2020-02-03-11:09:10.297832 Training Batch RL Models - Epoch: 1 Reward Model Loss: 76.49519771704414
2020-02-03-11:09:10.470471 Training Batch RL Models - Epoch: 2 Reward Model Loss: 74.02651463596109
2020-02-03-11:09:10.644748 Training Batch RL Models - Epoch: 3 Reward Model Loss: 70.41832784371371
2020-02-03-11:09:10.806664 Training Batch RL Models - Epoch: 4 Reward Model Loss: 65.34890014790214
2020-02-03-11:09:10.974298 Training Batch RL Models - Epoch: 5 Reward Model Loss: 58.735991261977354
2020-02-03-11:09:11.129985 Training Batch RL Models - Epoch: 6 Reward Model Loss: 50.830035138102495
2020-02-03-11:09:11.291238 Training Batch RL Models - Epoch: 7 Reward Model Loss: 42.01591932981272
2020-02-03-11:09:11.458220 Training Batch RL Models - Epoch: 8 Reward Model Loss: 32.86649761133493
2020-02-03-11:09:11.632883 Training Batch RL Models - Epoch: 9 Reward Model Loss: 24.055893288146052
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/home/mltest1/.pycharm_helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
  File "/home/mltest1/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/mltest1/tmp/pycharm_project_510/test_rl_coach.py", line 96, in <module>
    graph_manager.improve()
  File "/home/mltest1/anaconda3/envs/Oli1/lib/python3.6/site-packages/rl_coach/graph_managers/batch_rl_graph_manager.py", line 215, in improve
    self.initialize_ope_models_and_stats()
  File "/home/mltest1/anaconda3/envs/Oli1/lib/python3.6/site-packages/rl_coach/graph_managers/batch_rl_graph_manager.py", line 260, in initialize_ope_models_and_stats
    agent.improve_reward_model(epochs=self.reward_model_num_epochs)
  File "/home/mltest1/anaconda3/envs/Oli1/lib/python3.6/site-packages/rl_coach/agents/ddqn_bcq_agent.py", line 203, in improve_reward_model
    state_embeddings = self.embedding([transition.state for transition in self.memory.transitions
  File "/home/mltest1/anaconda3/envs/Oli1/lib/python3.6/site-packages/rl_coach/agents/ddqn_bcq_agent.py", line 89, in to_embedding
    states = self.prepare_batch_for_inference(states, 'reward_model')
  File "/home/mltest1/anaconda3/envs/Oli1/lib/python3.6/site-packages/rl_coach/agents/agent.py", line 814, in prepare_batch_for_inference
    if key in states[0].keys():
IndexError: list index out of range

I am using Python 3.6.3 on Ubuntu 18.04 with TensorFlow 1.11 and the latest version of Coach.

OGordon100 commented 4 years ago

Figured it out: every action must appear at least once in the CSV file.

I will open a pull request to add an error message that explains this.
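For anyone hitting the same error, the check can be run on a dataset before training. This is only a minimal sketch, not part of Coach: the helper name find_missing_actions is hypothetical, and it assumes the CSV stores the action index in a column named "action" and that actions are numbered from 0 to num_actions - 1 (so 0 to 23 for DiscreteActionSpace(24)). Adjust the column name if your file differs.

```python
import csv
import io


def find_missing_actions(csv_file, num_actions, action_column="action"):
    """Return the sorted action indices in range(num_actions) that never occur.

    Hypothetical helper, not a Coach API. `csv_file` is any open text file
    (or file-like object) with a header row; `action_column` is an assumed
    column name.
    """
    reader = csv.DictReader(csv_file)
    seen = set()
    for row in reader:
        # Actions may be written as "3" or "3.0"; normalise both to int.
        seen.add(int(float(row[action_column])))
    return sorted(set(range(num_actions)) - seen)


# Tiny in-memory example: a 4-action space in which action 2 never occurs.
sample = io.StringIO("action,reward\n0,1.0\n1,0.5\n3,0.0\n0,1.0\n")
missing = find_missing_actions(sample, num_actions=4)
if missing:
    print("Actions missing from the dataset:", missing)
```

For a real file, pass open("my_dataset.csv", newline="") in place of the StringIO object (the file name here is a placeholder). An empty result means every action is represented at least once.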