entity-neural-network / incubator

Collection of in-progress libraries for entity neural networks.
Apache License 2.0

Crash when action mask prevents all actions #186

Closed cswinter closed 2 years ago

cswinter commented 2 years ago

When an action mask prevents all actions for an actor, we currently get a fairly inscrutable error message:

Traceback (most recent call last):
  File "/home/costa/Documents/go/src/github.com/vwxyzjn/incubator/enn_ppo/enn_ppo/train.py", line 1017, in <module>
    main()
  File "/home/costa/.cache/pypoetry/virtualenvs/incubator-BZ1MCGMS-py3.9/lib/python3.9/site-packages/hyperstate/command.py", line 86, in _f
    return f(cfg)
  File "/home/costa/Documents/go/src/github.com/vwxyzjn/incubator/enn_ppo/enn_ppo/train.py", line 1013, in main
    train(cfg)
  File "/home/costa/Documents/go/src/github.com/vwxyzjn/incubator/enn_ppo/enn_ppo/train.py", line 753, in train
    next_obs, next_done, metrics = rollout.run(
  File "/home/costa/Documents/go/src/github.com/vwxyzjn/incubator/enn_ppo/enn_ppo/train.py", line 367, in run
    ) = self.agent.get_action_and_auxiliary(
  File "/home/costa/Documents/go/src/github.com/vwxyzjn/incubator/rogue_net/rogue_net/actor.py", line 168, in get_action_and_auxiliary
    action, count, logprob, entropy, logit = action_head(
  File "/home/costa/.cache/pypoetry/virtualenvs/incubator-BZ1MCGMS-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/costa/Documents/go/src/github.com/vwxyzjn/incubator/rogue_net/rogue_net/head_creator.py", line 85, in forward
    dist = Categorical(logits=logits)
  File "/home/costa/.cache/pypoetry/virtualenvs/incubator-BZ1MCGMS-py3.9/lib/python3.9/site-packages/torch/distributions/categorical.py", line 64, in __init__
    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
  File "/home/costa/.cache/pypoetry/virtualenvs/incubator-BZ1MCGMS-py3.9/lib/python3.9/site-packages/torch/distributions/distribution.py", line 55, in __init__
    raise ValueError(
ValueError: Expected parameter logits (Tensor of shape (10, 12)) of distribution Categorical(logits: torch.Size([10, 12])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([[   -inf, -0.4215, -1.0674,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf],
        [   -inf, -0.3849, -1.1411,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf],
        [-0.1876, -2.1427, -2.9235,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf],
        [   -inf, -0.4179, -1.0742,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf],
        [   -inf, -0.3849, -1.1411,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf],
        [    nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan],
        [   -inf,    -inf, -6.3404, -3.5986,    -inf,    -inf,    -inf,    -inf,
            -inf, -0.0296,    -inf,    -inf],
        [   -inf,  0.0000,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf],
        [-3.6184,    -inf, -6.3346,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf, -0.0290,    -inf,    -inf],
        [   -inf, -0.4293, -1.0526,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf]], device='cuda:0')

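The problem is the row of all nan values. It comes from a row of logits that were all set to -inf by the mask: normalizing that row subtracts its logsumexp, which is itself -inf, and -inf - (-inf) is nan.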
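The following minimal pure-Python sketch illustrates the failure. It mimics the normalization that torch's Categorical applies to its logits (subtracting their logsumexp). The function name normalize_logits is hypothetical, and this is not torch's actual implementation.

```python
import math
from typing import List


def normalize_logits(row: List[float]) -> List[float]:
    """Return row - logsumexp(row), i.e. log-probabilities (illustrative sketch)."""
    m = max(row)
    if m == -math.inf:
        # Every entry is masked: log(sum(exp(-inf))) = log(0) = -inf.
        lse = -math.inf
    else:
        # Numerically stable logsumexp.
        lse = m + math.log(sum(math.exp(x - m) for x in row))
    # For a fully masked row this computes -inf - (-inf), which is NaN.
    return [x - lse for x in row]


partially_masked = [-math.inf, 1.0, 1.0]
fully_masked = [-math.inf, -math.inf, -math.inf]
print(normalize_logits(partially_masked))  # [-inf, ~-0.693, ~-0.693]
print(normalize_logits(fully_masked))      # [nan, nan, nan]
```

A row with at least one unmasked entry normalizes fine, because the masked entries simply get probability zero. Only a fully masked row produces nan, which then fails Categorical's argument validation.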

It's not entirely clear what the best way to handle this is. When all actions are masked out, no action is valid, so simply returning a random action could break the environment. Here are the main options I can think of:

  1. Disallow this, but report a better error closer to the source.
  2. Filter out any actors whose action mask is all False. The advantage is that things "just work". The disadvantages are a performance cost (probably acceptable, since we don't expect Environment to be maximally efficient), potentially hiding a logic error in the environment implementation, and a surprising number of actions passed to the act method: it won't match the number of actor_ids specified by the environment, which could violate some assumption.
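Option 2 could look roughly like the sketch below. It is hypothetical: it assumes a single action type whose per-actor masks are plain lists of booleans, and the helper name filter_fully_masked is made up. It also emits the warning suggested later in the thread.

```python
import warnings
from typing import List, Sequence, Tuple


def filter_fully_masked(
    actor_ids: Sequence[int], masks: Sequence[Sequence[bool]]
) -> Tuple[List[int], List[List[bool]], List[int]]:
    """Drop actors whose action mask allows no action (hypothetical helper).

    Returns the kept actor ids, their masks, and the ids that were dropped.
    """
    kept_ids: List[int] = []
    kept_masks: List[List[bool]] = []
    dropped: List[int] = []
    for actor_id, mask in zip(actor_ids, masks):
        if any(mask):
            kept_ids.append(actor_id)
            kept_masks.append(list(mask))
        else:
            dropped.append(actor_id)
    if dropped:
        # Surface the filtering instead of hiding a possible environment bug.
        warnings.warn(f"Skipping actors with no valid actions: {dropped}")
    return kept_ids, kept_masks, dropped


ids, kept, dropped = filter_fully_masked(
    [10, 11, 12], [[True, False], [False, False], [False, True]]
)
print(ids, dropped)  # [10, 12] [11]
```

The example also shows the surprise mentioned above: the environment listed three actors, but only two would receive actions.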

@vwxyzjn what are your thoughts on what the ideal API would do?

vwxyzjn commented 2 years ago

Thanks for looking into it. Oh, this is interesting. I know the reason: in gym-microrts I determine whether an actor is available by checking whether it is "busy" (not executing any actions at the moment), but an actor can be not busy and still have no actions available. For example, a barracks may not be busy but still not have enough money to produce any units.

I'm OK with Option 1 but lean slightly towards Option 2 with a warning. I don't have a strong opinion on this, though.

vwxyzjn commented 2 years ago

Actually, for simplicity, maybe option 1 sounds more desirable :)

cswinter commented 2 years ago


Implementation-wise, it wouldn't be that complicated: it just needs one loop in the __post_init__ method of Observation.
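A rough sketch of that check (Option 1) follows. The Observation shown here is a simplified, hypothetical stand-in: the real class's fields may differ. This version maps each action name to its actor ids and per-actor boolean masks, then raises a descriptive error in __post_init__.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Observation:
    # Hypothetical, simplified fields; the real Observation class may differ.
    # For each action name: the ids of the actors that can perform it...
    actor_ids: Dict[str, List[int]]
    # ...and, per actor, a boolean mask over that action's choices.
    action_masks: Dict[str, List[List[bool]]]

    def __post_init__(self) -> None:
        # Option 1: reject fully masked actors here, close to the source,
        # instead of failing later with nan logits inside Categorical.
        for action_name, masks in self.action_masks.items():
            for actor_id, mask in zip(self.actor_ids[action_name], masks):
                if not any(mask):
                    raise ValueError(
                        f"Action mask for '{action_name}' allows no choices for "
                        f"actor {actor_id}; every actor listed for an action "
                        f"must have at least one valid choice."
                    )


# Passes: each actor has at least one allowed choice.
Observation(actor_ids={"move": [1, 2]}, action_masks={"move": [[True, False], [False, True]]})
```

Constructing an Observation where actor 2's mask is all False would raise a ValueError naming actor 2 and the action. That is much easier to trace back to the environment than the Categorical traceback above.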