aidan-curtis / firehose

We are trying to put out big fires
GNU General Public License v3.0

Logits don't satisfy Simplex #12

Closed: williamshen-nz closed this issue 2 years ago

williamshen-nz commented 2 years ago
Warning! We assume masking is supported for VecEnv
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 70         |
|    ep_rew_mean          | -23.7      |
| time/                   |            |
|    fps                  | 407        |
|    iterations           | 102        |
|    time_elapsed         | 8197       |
|    total_timesteps      | 3342336    |
| train/                  |            |
|    approx_kl            | 0.14997472 |
|    clip_fraction        | 0.598      |
|    clip_range           | 0.2        |
|    entropy_loss         | -3.58      |
|    explained_variance   | 0.975      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0384    |
|    n_updates            | 1010       |
|    policy_gradient_loss | -0.0328    |
|    value_loss           | 0.287      |
----------------------------------------
Traceback (most recent call last):
  File "cell2fire/rl_experiment_vectorized.py", line 288, in <module>
    train(args=parser.parse_args())
  File "cell2fire/rl_experiment_vectorized.py", line 219, in train
    trainer.train()
  File "cell2fire/rl_experiment_vectorized.py", line 209, in train
    raise e
  File "cell2fire/rl_experiment_vectorized.py", line 204, in train
    model.learn(
  File "/home/gridsan/wshen/.local/lib/python3.8/site-packages/sb3_contrib/ppo_mask/ppo_mask.py", line 569, in learn
    self.train()
  File "/home/gridsan/wshen/.local/lib/python3.8/site-packages/sb3_contrib/ppo_mask/ppo_mask.py", line 429, in train
    values, log_prob, entropy = self.policy.evaluate_actions(
  File "/home/gridsan/wshen/.local/lib/python3.8/site-packages/sb3_contrib/common/maskable/policies.py", line 280, in evaluate_actions
    distribution.apply_masking(action_masks)
  File "/home/gridsan/wshen/.local/lib/python3.8/site-packages/sb3_contrib/common/maskable/distributions.py", line 152, in apply_masking
    self.distribution.apply_masking(masks)
  File "/home/gridsan/wshen/.local/lib/python3.8/site-packages/sb3_contrib/common/maskable/distributions.py", line 62, in apply_masking
    super().__init__(logits=logits)
  File "/state/partition1/llgrid/pkg/anaconda/anaconda3-2022a/lib/python3.8/site-packages/torch/distributions/categorical.py", line 64, in __init__
    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
  File "/state/partition1/llgrid/pkg/anaconda/anaconda3-2022a/lib/python3.8/site-packages/torch/distributions/distribution.py", line 55, in __init__
    raise ValueError(
ValueError: Expected parameter probs (Tensor of shape (64, 1600)) of distribution MaskableCategorical(probs: torch.Size([64, 1600]), logits: torch.Size([64, 1600])) to satisfy the constraint Simplex(), but found invalid values:
tensor([[2.7397e-07, 2.0007e-07, 6.2320e-07,  ..., 2.7629e-07, 8.0404e-05,
         7.5556e-07],
        [2.4421e-07, 3.3548e-07, 6.6760e-07,  ..., 1.7289e-07, 2.6347e-06,
         4.3772e-07],
        [1.4054e-07, 1.0664e-07, 3.6739e-07,  ..., 1.5730e-07, 3.2728e-06,
         4.0184e-07],
        ...,
        [3.9183e-07, 2.9925e-07, 3.5064e-07,  ..., 2.0548e-07, 2.0982e-07,
         5.6737e-07],
        [6.9835e-06, 4.1832e-06, 1.2083e-05,  ..., 7.9078e-06, 4.3440e-04,
         1.1752e-05],
        [1.5757e-06, 5.5615e-06, 4.7831e-07,  ..., 6.5367e-07, 2.5431e-04,
         5.1062e-07]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/maskable/distributions.py#L41-L65
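For anyone debugging the same error: as far as I know, PyTorch's `Simplex()` constraint requires every row of `probs` to be non-negative and to sum to 1 within a tolerance of `1e-6`. With 1600 actions in float32, a row may fail because of accumulated rounding, and a NaN in the logits fails the same check. Below is a rough, standard-library-only sketch for finding which rows are off. `simplex_violations` is a hypothetical helper, not part of PyTorch or sb3_contrib, and it sums in double precision, so a row that is only marginally off in float32 may still pass here.

```python
import math


def simplex_violations(rows, tol=1e-6):
    """Return (row_index, reason) for every row that fails a simplex-style check.

    Hypothetical debugging helper. It mirrors the rule "all entries >= 0 and
    |sum - 1| < tol", but works on plain Python floats (double precision).
    """
    bad = []
    for i, row in enumerate(rows):
        if any(math.isnan(p) for p in row):
            # NaN usually points at diverged training rather than rounding.
            bad.append((i, "contains NaN"))
        elif any(p < 0 for p in row):
            bad.append((i, "negative entry"))
        else:
            total = math.fsum(row)
            if abs(total - 1.0) >= tol:
                bad.append((i, f"sums to {total!r}"))
    return bad


if __name__ == "__main__":
    # Toy rows standing in for the (64, 1600) tensor from the traceback.
    example = [
        [0.25, 0.25, 0.5],     # valid
        [0.5, 0.5, 1e-5],      # sum drifts past the tolerance
        [float("nan"), 1.0],   # NaN entry
    ]
    for index, reason in simplex_violations(example):
        print(f"row {index}: {reason}")
```

With a real tensor the rows could be passed as `probs.detach().cpu().tolist()`. If every flagged row is only slightly off 1, rounding is the likely cause; NaN rows would suggest the policy itself has diverged.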

williamshen-nz commented 2 years ago

Attempted fix here: https://github.com/williamshen-nz/stable-baselines3-contrib/commit/43b7fe6fcdda71c2c10de7af0d20f39e48ae5df3

It's running on SuperCloud right now, and it seems to have done the trick.