alex-petrenko / sample-factory

High throughput synchronous and asynchronous reinforcement learning
https://samplefactory.dev
MIT License

Rare exception calculating max KL with 100% invalid samples #251

Closed: edbeeching closed this issue 1 year ago

edbeeching commented 1 year ago

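A rare exception occurs when a batch contains 100% invalid samples and the learner tries to calculate the max KL. The traceback in the log below shows kl_old.max() in learner.py failing because the tensor is empty. I have only seen this when using PBT.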

[2022-12-14 16:51:46,226][535381] Updated weights for policy 0, policy_version 75540 (0.0012)
[2022-12-14 16:51:46,228][535383] Updated weights for policy 2, policy_version 71329 (0.0015)
[2022-12-14 16:51:46,662][535296] Learner 3 replacing cfg parameter 'value_loss_coeff' with new value 0.38879753899815916
[2022-12-14 16:51:46,662][535296] Learner 3 replacing cfg parameter 'ppo_clip_ratio' with new value 0.26001595106852216
[2022-12-14 16:51:46,663][535296] Optimizer lr value 0.0000171, betas: (0.9, 0.999)
[2022-12-14 16:51:46,664][535296] Loading state from checkpoint /gpfsssd/scratch/rech/ajs/utv52ia/godot_rl/godot_rl_agents/train_dir/20221214a_pbt_fps_hyp_scan_19/checkpoint_p1/checkpoint_000063520_1626112.pth...
[2022-12-14 16:51:46,760][535296] Loading model from checkpoint
[2022-12-14 16:51:46,767][535296] Loaded experiment state at self.train_step=78400, self.env_steps=2007040
[2022-12-14 16:51:46,774][535296] self.policy_id=3 batch has 100.00% of invalid samples
[2022-12-14 16:51:46,774][535296] No valid samples in the batch, with PBT this must mean we just replaced weights
[2022-12-14 16:51:46,783][536359] Updated weights for policy 3, policy_version 79401 (0.0012)
[2022-12-14 16:51:46,996][535381] Updated weights for policy 0, policy_version 75588 (0.0013)
[2022-12-14 16:51:47,006][536226] Updated weights for policy 1, policy_version 63583 (0.0011)
[2022-12-14 16:51:47,007][535383] Updated weights for policy 2, policy_version 71378 (0.0018)
[2022-12-14 16:51:47,266][535383] Updated weights for policy 2, policy_version 71391 (0.0011)
[2022-12-14 16:51:47,309][536226] Updated weights for policy 1, policy_version 63597 (0.0015)
[2022-12-14 16:51:47,564][535381] Updated weights for policy 0, policy_version 75602 (0.0012)
[2022-12-14 16:51:47,600][534903] Fps is (10 sec: 4913.4, 60 sec: 4607.7, 300 sec: 4658.3). Total num frames: 7399424. Throughput: 0: 1207.7, 1: 838.9, 2: 1111.6, 3: 1397.7. Samples: 7387668. Policy #0 lag: (min: 101.0, avg: 341.4, max: 608.0)
[2022-12-14 16:51:47,601][534903] Avg episode reward: [(0, '2.290'), (1, '10.690'), (2, '0.460'), (3, '0.320')]
[2022-12-14 16:51:47,619][535296] self.policy_id=3 batch has 97.22% of invalid samples
[2022-12-14 16:51:47,744][535296] EvtLoop [learner_proc3_evt_loop, process=learner_proc3] unhandled exception in slot='on_new_training_batch' connected to emitter=Emitter(object_id='Batcher_3', signal_name='training_batches_available'), args=(1,)
Traceback (most recent call last):
  File "/gpfsssd/scratch/rech/ajs/utv52ia/godot_rl/godot_rl_agents/venv/lib/python3.8/site-packages/signal_slot/signal_slot.py", line 355, in _process_signal
    slot_callable(*args)
  File "/gpfsssd/scratch/rech/ajs/utv52ia/godot_rl/godot_rl_agents/venv/lib/python3.8/site-packages/sample_factory/algo/learning/learner_worker.py", line 150, in on_new_training_batch
    stats = self.learner.train(self.batcher.training_batches[batch_idx])
  File "/gpfsssd/scratch/rech/ajs/utv52ia/godot_rl/godot_rl_agents/venv/lib/python3.8/site-packages/sample_factory/algo/learning/learner.py", line 1034, in train
    train_stats = self._train(buff, self.cfg.batch_size, experience_size, num_invalids)
  File "/gpfsssd/scratch/rech/ajs/utv52ia/godot_rl/godot_rl_agents/venv/lib/python3.8/site-packages/sample_factory/algo/learning/learner.py", line 761, in _train
    if kl_old.max().item() > 100:
RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.
[2022-12-14 16:51:47,784][535296] Unhandled exception max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument. in evt loop learner_proc3_evt_loop
[2022-12-14 16:51:47,793][535381] Updated weights for policy 0, policy_version 75613 (0.0020)
[2022-12-14 16:51:47,915][535383] Updated weights for policy 2, policy_version 71404 (0.0013)
[2022-12-14 16:51:48,013][535381] Updated weights for policy 0, policy_version 75624 (0.0012)
[2022-12-14 16:51:48,154][535383] Updated weights for policy 2, policy_version 71416 (0.0012)
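
For anyone hitting the same error, here is a minimal sketch of the failure mode and one way to guard against it. It is not the patch from #250 and does not use Sample Factory code: the function names max_kl_over_valid and kl_is_excessive are hypothetical, and plain Python lists stand in for tensors so the example runs without PyTorch. The only details taken from the traceback are the max over KL values and the comparison against 100. In PyTorch, the analogous guard would presumably be a check that kl_old.numel() > 0 before calling kl_old.max().

```python
from typing import Optional, Sequence

# Mirrors the "> 100" comparison visible in the traceback.
KL_THRESHOLD = 100.0


def max_kl_over_valid(
    kl_values: Sequence[float], valid_mask: Sequence[bool]
) -> Optional[float]:
    """Return the largest KL among valid samples, or None if none are valid.

    Hypothetical helper: it filters out invalid samples first, which is
    what can leave nothing to reduce over.
    """
    valid = [kl for kl, ok in zip(kl_values, valid_mask) if ok]
    # max() over an empty sequence raises, so fall back to a sentinel.
    return max(valid, default=None)


def kl_is_excessive(
    kl_values: Sequence[float],
    valid_mask: Sequence[bool],
    threshold: float = KL_THRESHOLD,
) -> bool:
    """True only if at least one valid sample exceeds the threshold."""
    max_kl = max_kl_over_valid(kl_values, valid_mask)
    return max_kl is not None and max_kl > threshold


if __name__ == "__main__":
    # One valid sample above the threshold.
    print(kl_is_excessive([0.5, 250.0], [True, True]))
    # Fully invalid batch: the check is skipped instead of raising.
    print(kl_is_excessive([0.5, 250.0], [False, False]))
```

Returning None rather than 0.0 for the empty case keeps "no valid samples" distinguishable from "KL is zero", which may matter if the value is also logged.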

edbeeching commented 1 year ago

Fixed in #250.