ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

Parametric Action Space crash with A2C [rllib] #7924

Closed jdhorwood closed 2 years ago

jdhorwood commented 4 years ago

What is the problem?

When running A2C with a custom model for masked actions, RLlib crashes with the following error:

Traceback (most recent call last):
  File "random_env.py", line 140, in <module>
    results = trainer.train()
  File "[..]lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 505, in train
    raise e
  File "[..]lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 491, in train
    result = Trainable.train(self)
  File "[..]lib/python3.6/site-packages/ray/tune/trainable.py", line 261, in train
    result = self._train()
  File "[..]lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 142, in _train
    return self._train_exec_impl()
  File "[..]lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 174, in _train_exec_impl
    res = next(self.train_exec_impl)
  File "[..]lib/python3.6/site-packages/ray/util/iter.py", line 634, in __next__
    return next(self.built_iterator)
  File "[..]lib/python3.6/site-packages/ray/util/iter.py", line 644, in apply_foreach
    for item in it:
  File "[..]lib/python3.6/site-packages/ray/util/iter.py", line 685, in apply_filter
    for item in it:
  File "[..]lib/python3.6/site-packages/ray/util/iter.py", line 644, in apply_foreach
    for item in it:
  File "[..]lib/python3.6/site-packages/ray/util/iter.py", line 718, in apply_flatten
    for item in it:
  File "[..]lib/python3.6/site-packages/ray/util/iter.py", line 670, in add_wait_hooks
    item = next(it)
  File "[..]lib/python3.6/site-packages/ray/util/iter.py", line 644, in apply_foreach
    for item in it:
  File "[..]lib/python3.6/site-packages/ray/util/iter.py", line 644, in apply_foreach
    for item in it:
  File "[..]lib/python3.6/site-packages/ray/rllib/utils/experimental_dsl.py", line 110, in sampler
    yield workers.local_worker().sample()
  File "[..]lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 492, in sample
    batches = [self.input_reader.next()]
  File "[..]lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 53, in next
    batches = [self.get_data()]
  File "[..]lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 96, in get_data
    item = next(self.rollout_provider)
  File "[..]lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 354, in _env_runner
    active_episodes)
  File "[..]lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 602, in _do_policy_eval
    timestep=policy.global_timestep)
  File "[..]lib/python3.6/site-packages/ray/rllib/policy/eager_tf_policy.py", line 58, in _func
    return func(*args, **kwargs)
  File "[..]lib/python3.6/site-packages/ray/rllib/policy/eager_tf_policy.py", line 66, in _func
    out = func(*args, **kwargs)
  File "[..]lib/python3.6/site-packages/ray/rllib/policy/eager_tf_policy.py", line 132, in compute_actions
    info_batch, episodes, explore, timestep, **kwargs)
  File "[..]lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 441, in __call__
    results = self._stateful_fn(*args, **kwds)
  File "[..]lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1822, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "[..]lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1141, in _filtered_call
    self.captured_inputs)
  File "[..]lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
  File "[..]lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 511, in call
    ctx=ctx)
  File "[..]lib/python3.6/site-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Received a label value of 90 which is outside the valid range of [0, 90).  Label values: 90
     [[{{node cond_1/then/_9/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}}]] [Op:__inference_compute_actions_578]

Function call stack:
compute_actions

Ray version and other system information (Python version, TensorFlow version, OS):

Python: 3.6.10, Ray: 0.8.4, TensorFlow: 1.15.0, OS: macOS Mojave

Reproduction (REQUIRED)

After some time, I was able to locate the trainer config parameter causing the error. Setting entropy_coeff > 1 appears to trigger this crash. Things run fine when the mask is less stable (for example, using tf.float16.min) or removed, when entropy_coeff < 1, or when using PPO. I expect this issue would also occur with A3C. A script which reproduces the issue can be found here.
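For readers without access to the script, the setting described above can be sketched as a trainer config. This is only an illustration: the environment and model names are hypothetical placeholders, the "eager" flag is an assumption based on eager_tf_policy.py appearing in the traceback, and the key names are recalled from RLlib 0.8.x and should be checked against the installed version.

```python
# Hypothetical A2C trainer config illustrating the reported trigger.
# Key names are assumptions recalled from RLlib 0.8.x, not taken from
# the original reproduction script.
config = {
    "env": "my_masked_env",  # hypothetical registered environment name
    "model": {
        "custom_model": "my_masked_model",  # hypothetical action-masking model
    },
    "eager": True,  # assumed, since the traceback goes through eager_tf_policy.py
    "entropy_coeff": 1.5,  # any value > 1 reportedly triggers the crash
}

# Variant reported to run without crashing: entropy_coeff < 1.
working_config = dict(config, entropy_coeff=0.01)
```

Only entropy_coeff differs between the two dictionaries, which matches the observation that this single parameter decides whether the crash occurs.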

jdhorwood commented 4 years ago

As an added comment, there may be other issues surrounding entropy. For the problem I am working on, tuning the entropy coefficient changes nothing, regardless of the coefficient value. With the PyTorch version of A2C, the above crash does not occur, but optimization is identical whether entropy_coeff is set to 0.01 or 10.

jdhorwood commented 4 years ago

I'm not sure where this is in the code exactly, but my intuition for both these issues is the following:

This could be explained by entropy_coeff multiplying the logits prior to the policy's softmax, leading to NaNs when these logits contain tf.float32.min, which in turn causes the above error message.

Using the entropy_coeff at this stage and returning the resulting value as the policy's entropy would additionally result in identical values regardless of entropy_coeff, which could explain the second observation.
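The numerical part of this intuition can be checked in isolation. The sketch below is not RLlib code and does not show where (or whether) such a multiplication happens in the library. It only shows that a masked logit equal to tf.float32.min overflows to -inf when multiplied by a factor greater than 1 in float32, and that an entropy written as sum(p * (log(z) - a)), one common formulation, then evaluates to NaN. The helper names to_float32 and categorical_entropy are made up for this example, and float32 rounding is emulated with the struct module.

```python
import math
import struct

# Most negative finite float32 value, i.e. the value of tf.float32.min.
FLOAT32_MIN = -3.4028234663852886e38


def to_float32(x):
    """Round a Python float to float32; values out of range become +/-inf."""
    try:
        return struct.unpack("f", struct.pack("f", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def categorical_entropy(logits):
    """Entropy of softmax(logits), written as sum(p * (log(z) - a))."""
    m = max(logits)
    shifted = [to_float32(l - m) for l in logits]  # a = logits - max
    exps = [math.exp(a) for a in shifted]          # exp(-inf) is 0.0
    z = sum(exps)
    return sum((e / z) * (math.log(z) - a) for e, a in zip(exps, shifted))


# Two valid actions and one masked action.
masked_logits = [1.0, 2.0, FLOAT32_MIN]
print(categorical_entropy(masked_logits))  # finite

# Hypothetical scaling by a coefficient > 1 before the softmax.
scaled_logits = [to_float32(2.0 * l) for l in masked_logits]
print(scaled_logits[2])                    # -inf
print(categorical_entropy(scaled_logits))  # nan, from 0.0 * inf
```

With the unscaled mask the masked action contributes 0 * (large finite number) = 0, so the entropy stays finite. Once the logit is -inf, the same term becomes 0 * inf, which is NaN under IEEE 754 arithmetic.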

stale[bot] commented 3 years ago

Hi, I'm a bot from the Ray team :)

To help human contributors to focus on more relevant issues, I will automatically add the stale label to issues that have had no activity for more than 4 months.

If there is no further activity in the next 14 days, the issue will be closed!

You can always ask for help on our discussion forum or Ray's public slack channel.

jdhorwood commented 3 years ago

Hi, did anybody get a chance to look at this issue?

avnishn commented 2 years ago

Going to close this for now, as it is stale and requires a repro script.