ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] Make `_check_if_diag_gaussian` available in the utils #39519

Open simonsays1980 opened 1 year ago

simonsays1980 commented 1 year ago

Description

Right now the function is in ppo_catalog.py but will be used by many RLModule subclasses. Make the function available in a more central place like rllib.utils.

Use case


from ray.rllib.utils.action_distributions import check_if_diag_gaussian

gresavage commented 1 year ago

I should probably raise this as a separate issue, but _check_if_diag_gaussian fails for MultiActionDistributions even though the child distributions will end up being DiagGaussians.

I think _check_if_diag_gaussian could handle this by instantiating a throw-away class (not ideal) or it could be handled in Distribution.get_partial_dist_cls with class attributes somehow and a slight modification to the assert statements in _check_if_diag_gaussian.

At the time the partial distribution is constructed it is possible to know the types of the final distributions - perhaps make a class attribute or property flat_distribution_types: List[Type[Distribution]] for DistributionPartial and then check:

assert issubclass(action_distribution_cls, TorchDiagGaussian) or (
    issubclass(action_distribution_cls, DistributionPartial)
    and all(
        issubclass(x, TorchDiagGaussian)
        for x in action_distribution_cls.flat_distribution_types
    )
)
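To make the idea concrete, here is a minimal, self-contained sketch of how such a check might behave. It does not import RLlib: Distribution, TorchDiagGaussian, TorchCategorical and DistributionPartial below are simplified stand-ins, and flat_distribution_types, make_partial and the non-underscore name check_if_diag_gaussian are hypothetical names used only for illustration, not existing RLlib API.

```python
from typing import List, Sequence, Type


class Distribution:
    """Stand-in for RLlib's Distribution base class (not the real one)."""


class TorchDiagGaussian(Distribution):
    """Stand-in for RLlib's TorchDiagGaussian."""


class TorchCategorical(Distribution):
    """Stand-in for a non-Gaussian distribution, used as a negative case."""


class DistributionPartial(Distribution):
    """Hypothetical partial distribution class that records its leaf types."""

    # Assumption: filled in when the partial class is constructed.
    flat_distribution_types: List[Type[Distribution]] = []


def make_partial(leaf_types: Sequence[Type[Distribution]]) -> Type[DistributionPartial]:
    """Hypothetical helper: build a partial class that carries its leaf types."""
    return type(
        "BoundDistributionPartial",
        (DistributionPartial,),
        {"flat_distribution_types": list(leaf_types)},
    )


def check_if_diag_gaussian(action_distribution_cls: Type[Distribution]) -> None:
    """Assert that the class is, or only contains, diagonal Gaussians."""
    is_diag_gaussian = issubclass(action_distribution_cls, TorchDiagGaussian)

    leaf_types = getattr(action_distribution_cls, "flat_distribution_types", [])
    # The non-empty guard is an addition to the assert proposed above:
    # all() over an empty list is True, which would accept a partial
    # class that has no recorded leaf types.
    is_partial_of_diag_gaussians = (
        issubclass(action_distribution_cls, DistributionPartial)
        and len(leaf_types) > 0
        and all(issubclass(t, TorchDiagGaussian) for t in leaf_types)
    )

    assert is_diag_gaussian or is_partial_of_diag_gaussians, (
        f"Expected TorchDiagGaussian or a partial distribution made only of "
        f"TorchDiagGaussian leaves, got {action_distribution_cls.__name__}."
    )


# Plain diagonal Gaussian: passes.
check_if_diag_gaussian(TorchDiagGaussian)

# Partial whose leaves are all diagonal Gaussians: passes.
check_if_diag_gaussian(make_partial([TorchDiagGaussian, TorchDiagGaussian]))

# Partial with a non-Gaussian leaf: raises AssertionError.
try:
    check_if_diag_gaussian(make_partial([TorchDiagGaussian, TorchCategorical]))
except AssertionError as err:
    print(f"Rejected as expected: {err}")
```

Keeping the leaf types as a class attribute means the check can stay a pure issubclass test and never needs to instantiate a throw-away distribution.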

Of course, I am sure you will come up with something more clever.