williamshen-nz closed this issue 2 years ago.
```
Warning! We assume masking is supported for VecEnv
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 70         |
|    ep_rew_mean          | -23.7      |
| time/                   |            |
|    fps                  | 407        |
|    iterations           | 102        |
|    time_elapsed         | 8197       |
|    total_timesteps      | 3342336    |
| train/                  |            |
|    approx_kl            | 0.14997472 |
|    clip_fraction        | 0.598      |
|    clip_range           | 0.2        |
|    entropy_loss         | -3.58      |
|    explained_variance   | 0.975      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0384    |
|    n_updates            | 1010       |
|    policy_gradient_loss | -0.0328    |
|    value_loss           | 0.287      |
----------------------------------------
Traceback (most recent call last):
  File "cell2fire/rl_experiment_vectorized.py", line 288, in <module>
    train(args=parser.parse_args())
  File "cell2fire/rl_experiment_vectorized.py", line 219, in train
    trainer.train()
  File "cell2fire/rl_experiment_vectorized.py", line 209, in train
    raise e
  File "cell2fire/rl_experiment_vectorized.py", line 204, in train
    model.learn(
  File "/home/gridsan/wshen/.local/lib/python3.8/site-packages/sb3_contrib/ppo_mask/ppo_mask.py", line 569, in learn
    self.train()
  File "/home/gridsan/wshen/.local/lib/python3.8/site-packages/sb3_contrib/ppo_mask/ppo_mask.py", line 429, in train
    values, log_prob, entropy = self.policy.evaluate_actions(
  File "/home/gridsan/wshen/.local/lib/python3.8/site-packages/sb3_contrib/common/maskable/policies.py", line 280, in evaluate_actions
    distribution.apply_masking(action_masks)
  File "/home/gridsan/wshen/.local/lib/python3.8/site-packages/sb3_contrib/common/maskable/distributions.py", line 152, in apply_masking
    self.distribution.apply_masking(masks)
  File "/home/gridsan/wshen/.local/lib/python3.8/site-packages/sb3_contrib/common/maskable/distributions.py", line 62, in apply_masking
    super().__init__(logits=logits)
  File "/state/partition1/llgrid/pkg/anaconda/anaconda3-2022a/lib/python3.8/site-packages/torch/distributions/categorical.py", line 64, in __init__
    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
  File "/state/partition1/llgrid/pkg/anaconda/anaconda3-2022a/lib/python3.8/site-packages/torch/distributions/distribution.py", line 55, in __init__
    raise ValueError(
ValueError: Expected parameter probs (Tensor of shape (64, 1600)) of distribution MaskableCategorical(probs: torch.Size([64, 1600]), logits: torch.Size([64, 1600])) to satisfy the constraint Simplex(), but found invalid values:
tensor([[2.7397e-07, 2.0007e-07, 6.2320e-07, ..., 2.7629e-07, 8.0404e-05, 7.5556e-07],
        [2.4421e-07, 3.3548e-07, 6.6760e-07, ..., 1.7289e-07, 2.6347e-06, 4.3772e-07],
        [1.4054e-07, 1.0664e-07, 3.6739e-07, ..., 1.5730e-07, 3.2728e-06, 4.0184e-07],
        ...,
        [3.9183e-07, 2.9925e-07, 3.5064e-07, ..., 2.0548e-07, 2.0982e-07, 5.6737e-07],
        [6.9835e-06, 4.1832e-06, 1.2083e-05, ..., 7.9078e-06, 4.3440e-04, 1.1752e-05],
        [1.5757e-06, 5.5615e-06, 4.7831e-07, ..., 6.5367e-07, 2.5431e-04, 5.1062e-07]],
       device='cuda:0', grad_fn=<SoftmaxBackward0>)
```
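The `ValueError` comes from torch's argument validation: `probs` must lie on the probability simplex, meaning every entry is non-negative and each row sums to 1. All the entries printed above are positive, so a failing row most likely contains NaN/inf or sums to something outside the tolerance, and it may be hidden behind the `...`. A minimal pure-Python sketch of that check for locating the bad rows, assuming the tolerance torch applies for `Simplex()` is `1e-6` on the absolute deviation of the row sum (an assumption about the installed torch version); `SIMPLEX_TOL`, `satisfies_simplex` and `invalid_rows` are hypothetical helpers, not part of torch or sb3-contrib:

```python
import math
from typing import List, Sequence

# ASSUMPTION: tolerance mirroring torch's simplex constraint check
# (|row.sum() - 1| < 1e-6); verify against your torch version.
SIMPLEX_TOL = 1e-6


def satisfies_simplex(probs: Sequence[float], tol: float = SIMPLEX_TOL) -> bool:
    """Return True if all entries are >= 0 and they sum to 1 within tol.

    NaN makes both comparisons False, so a row containing NaN fails.
    """
    nonneg = all(p >= 0 for p in probs)
    # math.fsum sums the float64 values exactly (torch sums in float32).
    return nonneg and abs(math.fsum(probs) - 1.0) < tol


def invalid_rows(batch: Sequence[Sequence[float]], tol: float = SIMPLEX_TOL) -> List[int]:
    """Indices of rows that fail the check (hypothetical diagnostic helper)."""
    return [i for i, row in enumerate(batch) if not satisfies_simplex(row, tol)]


if __name__ == "__main__":
    batch = [[0.5, 0.5], [0.5, 0.5 + 1e-3], [float("nan"), 1.0]]
    print(invalid_rows(batch))
```

For example, `invalid_rows(probs.detach().cpu().tolist())` on the 64 x 1600 tensor from the error would list the failing row indices. Because `math.fsum` is exact while torch reduces in float32, a row sitting right at the tolerance may pass here and still fail inside torch.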
Relevant code (`MaskableCategorical.apply_masking`): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/maskable/distributions.py#L41-L65
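The linked `apply_masking` code re-initializes the categorical distribution from masked logits. A minimal pure-Python sketch of the general idea, assuming masked-out actions receive a large negative logit and the result goes through a numerically stable softmax; the `-1e8` value of `HUGE_NEG` and the helper name `masked_probs` are assumptions for illustration, and this is not the sb3-contrib implementation:

```python
import math
from typing import List, Sequence

# ASSUMPTION: large negative logit assigned to masked-out actions.
HUGE_NEG = -1e8


def masked_probs(logits: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Replace masked logits with HUGE_NEG, then apply a stable softmax."""
    if len(logits) != len(mask):
        raise ValueError("logits and mask must have the same length")
    masked = [x if keep else HUGE_NEG for x, keep in zip(logits, mask)]
    # Subtract the max before exponentiating so exp() cannot overflow.
    top = max(masked)
    exps = [math.exp(x - top) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    print(masked_probs([1.0, 2.0, 3.0], [True, False, True]))
```

Masked entries underflow to exactly `0.0` in `math.exp`, and subtracting the max keeps the largest exponent at `exp(0) = 1`. If every entry is masked, `top` equals `HUGE_NEG` and the sketch returns a uniform distribution.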
Attempted fix here: https://github.com/williamshen-nz/stable-baselines3-contrib/commit/43b7fe6fcdda71c2c10de7af0d20f39e48ae5df3
Running it on SuperCloud right now; it seems to have done the trick.