simonsays1980 opened 1 year ago
I should probably raise this as a separate issue, but `_check_if_diag_gaussian` fails for `MultiActionDistribution`s even though the child distributions will end up being `DiagGaussian`s.
I think `_check_if_diag_gaussian` could handle this by instantiating a throw-away class (not ideal), or it could be handled in `Distribution.get_partial_dist_cls` with class attributes somehow, plus a slight modification to the `assert` statements in `_check_if_diag_gaussian`.
At the time the partial distribution is constructed, the types of the final distributions are already known. Perhaps add a class attribute or property `flat_distribution_types: List[Type[Distribution]]` to `DistributionPartial` and then check:
assert issubclass(action_distribution_cls, TorchDiagGaussian) or (issubclass(action_distribution_cls, DistributionPartial) and all(issubclass(x, TorchDiagGaussian) for x in action_distribution_cls.flat_distribution_types))
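A minimal, self-contained sketch of the `flat_distribution_types` idea. `Distribution`, `TorchDiagGaussian`, `TorchCategorical`, and `TorchMultiDistribution` below are hypothetical stubs, not RLlib's real classes, and the `child_distribution_cls_struct` keyword and the `is_diag_gaussian` helper are illustrative names only. It shows why a plain `issubclass` check fails for a multi-distribution partial and how recording the leaf types when the partial is built makes the check possible.

```python
from typing import List, Type


class Distribution:
    """Hypothetical stand-in for an RLlib-style Distribution base class."""

    # Leaf distribution types; empty on plain classes, filled in on partials.
    flat_distribution_types: List[Type["Distribution"]] = []

    @classmethod
    def get_partial_dist_cls(cls, **partial_kwargs) -> Type["Distribution"]:
        # Assumption: a multi-distribution receives its children via this kwarg.
        # (A real implementation would also forward partial_kwargs; omitted here.)
        children = partial_kwargs.get("child_distribution_cls_struct")
        if children is None:
            flat = [cls]
        else:
            flat = []
            for child in children:
                # Nested partials contribute their own leaves; plain classes are leaves.
                flat.extend(child.flat_distribution_types or [child])

        class DistributionPartial(cls):
            # Proposed attribute: the final distribution types, known at construction time.
            flat_distribution_types = flat

        return DistributionPartial


class TorchDiagGaussian(Distribution):
    pass


class TorchCategorical(Distribution):
    pass


class TorchMultiDistribution(Distribution):
    pass


def is_diag_gaussian(action_distribution_cls: Type[Distribution]) -> bool:
    """The check from the assert above, returning a bool instead of asserting."""
    if issubclass(action_distribution_cls, TorchDiagGaussian):
        return True
    leaves = action_distribution_cls.flat_distribution_types
    return bool(leaves) and all(issubclass(x, TorchDiagGaussian) for x in leaves)


gaussians = TorchMultiDistribution.get_partial_dist_cls(
    child_distribution_cls_struct=[TorchDiagGaussian, TorchDiagGaussian]
)
mixed = TorchMultiDistribution.get_partial_dist_cls(
    child_distribution_cls_struct=[TorchDiagGaussian, TorchCategorical]
)

print(issubclass(gaussians, TorchDiagGaussian))  # False: why a plain issubclass check fails
print(is_diag_gaussian(gaussians))  # True
print(is_diag_gaussian(mixed))  # False
```

Because `DistributionPartial` is created inside the method, it cannot easily be referenced from outside for an `issubclass(..., DistributionPartial)` test. The sketch therefore gives the base class an empty `flat_distribution_types` default and reads the attribute directly.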
Of course, I am sure you will come up with something more clever.
Description
Right now the function lives in `ppo_catalog.py`, but it will be used by many `RLModule` subclasses. Make the function available in a more central place, like `rllib.utils`.
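A rough sketch of how the helper might look once moved into a shared module such as `rllib.utils`. The signature, the `diag_gaussian_cls` parameter, and the duck-typed `flat_distribution_types` lookup are assumptions for illustration, not RLlib's actual API, and the classes at the bottom are hypothetical stubs. Passing the Gaussian class in as an argument keeps the shared module free of torch/tf imports, so modules built on either framework could reuse it.

```python
from typing import Sequence


def check_if_diag_gaussian(action_distribution_cls: type, diag_gaussian_cls: type) -> None:
    """Raise ValueError unless the class is, or is composed only of, diag Gaussians.

    The caller supplies its framework's Gaussian class, so this helper
    does not depend on any deep-learning framework.
    """
    if issubclass(action_distribution_cls, diag_gaussian_cls):
        return
    # Assumption: partial multi-distributions expose their leaf types this way.
    leaves: Sequence[type] = getattr(action_distribution_cls, "flat_distribution_types", ())
    if leaves and all(issubclass(leaf, diag_gaussian_cls) for leaf in leaves):
        return
    raise ValueError(
        f"{action_distribution_cls.__name__} is not a {diag_gaussian_cls.__name__} "
        f"and is not composed only of {diag_gaussian_cls.__name__} distributions."
    )


# Hypothetical stubs standing in for framework-specific distribution classes.
class TorchDiagGaussian:
    pass


class TorchCategorical:
    pass


class MultiGaussianPartial:
    # Mimics a partial multi-distribution that records its leaf types.
    flat_distribution_types = (TorchDiagGaussian, TorchDiagGaussian)


check_if_diag_gaussian(TorchDiagGaussian, TorchDiagGaussian)  # passes silently
check_if_diag_gaussian(MultiGaussianPartial, TorchDiagGaussian)  # passes silently
try:
    check_if_diag_gaussian(TorchCategorical, TorchDiagGaussian)
except ValueError as err:
    print(err)
```

Raising `ValueError` instead of using `assert` is a deliberate choice here: `assert` statements are skipped when Python runs with `-O`, so a shared validation helper should not rely on them.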
Use case