microsoft / qlib

Qlib is an AI-oriented quantitative investment platform that aims to realize the potential, empower research, and create value using AI technologies in quantitative investment, from exploring ideas to implementing productions. Qlib supports diverse machine learning modeling paradigms, including supervised learning, market dynamics modeling, and RL.
https://qlib.readthedocs.io/en/latest/
MIT License
15.11k stars 2.59k forks source link

Reinforcement learning example gives errors. #1611

Open quant2008 opened 1 year ago

quant2008 commented 1 year ago

I ran the reinforcement learning example examples\rl\simple_example.ipynb and got an error in the Training workflow section. Here is the error. How can I resolve it?

AssertionError Traceback (most recent call last) Cell In[7], line 24 18 vessel_kwargs = { 19 "update_kwargs": {"batch_size": 16, "repeat": 5}, 20 "episode_per_iter": 50, 21 } 23 print("Training started") ---> 24 train( 25 simulator_fn=lambda position: SimpleSimulator(position, NSTEPS), 26 state_interpreter=state_interpreter, 27 action_interpreter=action_interpreter, 28 policy=policy, 29 reward=reward, 30 initial_states=cast(List[float], SimpleDataset([10.0, 50.0, 100.0])), 31 trainer_kwargs=trainer_kwargs, 32 vessel_kwargs=vessel_kwargs, 33 ) 34 print("Training finished")

File e:\anaconda3\envs\qlib230510\lib\site-packages\qlib\rl\trainer\api.py:63, in train(simulator_fn, state_interpreter, action_interpreter, initial_states, policy, reward, vessel_kwargs, trainer_kwargs) 53 vessel = TrainingVessel( 54 simulator_fn=simulator_fn, 55 state_interpreter=state_interpreter, (...) 60 vessel_kwargs, 61 ) 62 trainer = Trainer(trainer_kwargs) ---> 63 trainer.fit(vessel)

File e:\anaconda3\envs\qlib230510\lib\site-packages\qlib\rl\trainer\trainer.py:224, in Trainer.fit(self, vessel, ckpt_path) 222 with _wrap_context(vessel.train_seed_iterator()) as iterator: 223 vector_env = self.venv_from_iterator(iterator) --> 224 self.vessel.train(vector_env) 225 del vector_env # FIXME: Explicitly delete this object to avoid memory leak. 227 self._call_callback_hooks("on_train_end")

File e:\anaconda3\envs\qlib230510\lib\site-packages\qlib\rl\trainer\vessel.py:171, in TrainingVessel.train(self, vector_env) 168 self.policy.train() 170 with vector_env.collector_guard(): --> 171 collector = Collector(self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env))) 173 # Number of episodes collected in each training iteration can be overridden by fast dev run. 174 if self.trainer.fast_dev_run is not None:

File e:\anaconda3\envs\qlib230510\lib\site-packages\tianshou\data\collector.py:79, in Collector.init(self, policy, env, buffer, preprocess_fn, exploration_noise) 77 self._action_space = self.env.action_space 78 # avoid creating attribute outside init ---> 79 self.reset(False)

File e:\anaconda3\envs\qlib230510\lib\site-packages\tianshou\data\collector.py:130, in Collector.reset(self, reset_buffer, gym_reset_kwargs) 117 # use empty Batch for "state" so that self.data supports slicing 118 # convert empty Batch to None when passing data to policy 119 self.data = Batch( 120 obs={}, 121 act={}, (...) 128 policy={} 129 ) --> 130 self.reset_env(gym_reset_kwargs) 131 if reset_buffer: 132 self.reset_buffer()

File e:\anaconda3\envs\qlib230510\lib\site-packages\tianshou\data\collector.py:146, in Collector.reset_env(self, gym_reset_kwargs) 144 """Reset all of the environments.""" 145 gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} --> 146 obs, info = self.env.reset(**gym_reset_kwargs) 147 if self.preprocess_fn: 148 processed_data = self.preprocess_fn( 149 obs=obs, info=info, env_id=np.arange(self.env_num) 150 )

File e:\anaconda3\envs\qlib230510\lib\site-packages\qlib\rl\utils\finite_env.py:233, in FiniteVectorEnv.reset(self, id) 231 id2idx = {i: k for k, i in enumerate(wrapped_id)} 232 if request_id: --> 233 for i, o in zip(request_id, super().reset(request_id)): 234 obs[id2idx[i]] = self._postproc_env_obs(o) 236 for i, o in zip(wrapped_id, obs):

File e:\anaconda3\envs\qlib230510\lib\site-packages\tianshou\env\venvs.py:280, in BaseVectorEnv.reset(self, id, kwargs) 277 self.workers[i].send(None, kwargs) 278 ret_list = [self.workers[i].recv() for i in id] --> 280 assert ( 281 isinstance(ret_list[0], (tuple, list)) and len(ret_list[0]) == 2 282 and isinstance(ret_list[0][1], dict) 283 ) 285 obs_list = [r[0] for r in ret_list] 287 if isinstance(obs_list[0], tuple): # type: ignore

AssertionError:

quant2008 commented 1 year ago

Solved. You need to downgrade tianshou.