RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

Differentiate step function? #26

Open dzako opened 1 year ago

dzako commented 1 year ago

Hello, is it possible to return the differential of the step reward function (with respect to the action), at least for the simplest envs like pendulum or cartpole? Best, Jacek

RobertTLange commented 1 year ago

Hi @dzako, thanks for your kind words and appreciation. You are right: for now, obs and state are wrapped with a stop-gradient operation. While I agree that this is a desirable feature for certain environments, there are two main considerations:

  1. Not all environment step transitions are differentiable. E.g. respawning in the MinAtar implementations is essentially a discontinuous step function for certain pixel activations. Therefore this can't be a general feature.
  2. This can have subtle (or not so) downstream effects. While one may want to differentiate through step transitions in the context of model-based RL or control/MPC/etc., this can also cause problems for standard model-free RL pipelines (using JAX grad) which assume that the environment is not "accessible".

I will see if it makes sense to add a stop_gradient option when calling gymnax.make. Let me know if you have ideas/opinions and what your particular use case could be.
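As a rough illustration of what wrapping the returned state in a stop-gradient does, here is a minimal JAX sketch. `toy_step` and `second_reward` are hypothetical names with made-up one-dimensional dynamics, not gymnax code; the only real library calls are `jax.grad` and `jax.lax.stop_gradient`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def toy_step(state, action, block_gradients):
    """Hypothetical 1-D dynamics; NOT a gymnax environment."""
    next_state = state + 0.1 * action
    reward = -(next_state ** 2)
    if block_gradients:  # plain Python bool, resolved at trace time
        next_state = lax.stop_gradient(next_state)
    return next_state, reward


def second_reward(first_action, state, block_gradients):
    """Reward of the second step, viewed as a function of the first action."""
    state, _ = toy_step(state, first_action, block_gradients)
    _, reward = toy_step(state, jnp.float32(0.0), block_gradients)
    return reward


state, action = jnp.float32(1.0), jnp.float32(0.5)

# With the state wrapped in stop_gradient, the first action cannot
# influence later rewards as far as autodiff is concerned.
grad_blocked = jax.grad(second_reward)(action, state, True)

# Without it, the gradient flows through the state transition.
grad_open = jax.grad(second_reward)(action, state, False)

print(grad_blocked, grad_open)
```

In this toy setup the blocked gradient is exactly zero, while the open one is the analytic value -2 * (state + 0.1 * action) * 0.1.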

carlosgmartin commented 1 year ago

I think it makes sense to remove all stop_gradients from the environments themselves, so that RL algorithms downstream have the option to use those gradients if desired.

It seems to me that it is the downstream responsibility of an RL algorithm to impose a stop_gradient if it happens to require one.
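A minimal sketch of that division of responsibility, assuming a fully differentiable toy environment (`env_step`, `score_function_surrogate`, and `pathwise_objective` are hypothetical names, not part of gymnax): the model-free objective applies `lax.stop_gradient` itself, while a model-based objective differentiates through the very same step function.

```python
import jax
import jax.numpy as jnp
from jax import lax


def env_step(state, action):
    """Hypothetical differentiable environment with no stop_gradient inside."""
    next_state = state + 0.1 * action
    reward = -(next_state ** 2)
    return next_state, reward


def score_function_surrogate(theta, state, noise):
    """Model-free style: the algorithm blocks the gradient itself."""
    mean = theta * state
    # The sampled action is treated as data, so nothing flows into the env.
    action = lax.stop_gradient(mean + noise)
    _, reward = env_step(state, action)
    log_prob = -0.5 * (action - mean) ** 2  # unit-variance Gaussian, up to a constant
    return reward * log_prob


def pathwise_objective(theta, state, noise):
    """Model-based style: differentiate straight through the environment."""
    action = theta * state + noise
    _, reward = env_step(state, action)
    return reward


theta, state, noise = jnp.float32(0.2), jnp.float32(1.0), jnp.float32(0.3)
grad_model_free = jax.grad(score_function_surrogate)(theta, state, noise)
grad_pathwise = jax.grad(pathwise_objective)(theta, state, noise)
print(grad_model_free, grad_pathwise)
```

Both objectives call the same `env_step`; only the downstream code decides whether gradients pass through the transition.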

dominikstrb commented 9 months ago

I just wanted to bump this issue, because I think it would be very useful to have the ability to differentiate through the dynamics and observation functions. This would allow us to use gymnax for model-based control and for explicit modeling of partially observable environments.
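To make the model-based control use case concrete, here is a hedged sketch of gradient-based trajectory optimization through a hand-written pendulum model. The dynamics, cost weights, horizon, and step size are illustrative assumptions, and `pendulum_step` is a hypothetical function, not gymnax's Pendulum implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax

DT = 0.05  # assumed integration step


def pendulum_step(state, torque):
    """Hypothetical unit-mass, unit-length pendulum; angle measured from upright."""
    theta, theta_dot = state
    theta_dot = theta_dot + DT * (9.81 * jnp.sin(theta) + torque)
    theta = theta + DT * theta_dot
    cost = theta ** 2 + 0.1 * theta_dot ** 2 + 0.001 * torque ** 2
    return jnp.stack([theta, theta_dot]), cost


def rollout_cost(torques, init_state):
    """Total cost of an open-loop torque sequence, differentiable end to end."""
    _, costs = lax.scan(pendulum_step, init_state, torques)
    return costs.sum()


init_state = jnp.array([0.3, 0.0])
torques = jnp.zeros(10)
grad_fn = jax.jit(jax.grad(rollout_cost))

initial_cost = rollout_cost(torques, init_state)
grads = grad_fn(torques, init_state)

# Plain gradient descent on the action sequence.
for _ in range(100):
    torques = torques - 0.05 * grad_fn(torques, init_state)

final_cost = rollout_cost(torques, init_state)
print(initial_cost, final_cost)
```

This only works because every operation in the rollout is differentiable and no stop-gradient sits between consecutive steps.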

janakact commented 7 months ago

+1 Yeah. This would be a really nice feature. Does anyone know a library that offers a differentiable step function?

dominikstrb commented 7 months ago

@janakact

> Does anyone know a library that offers a differentiable step function?

Shameless self-plug: I have a package for non-linear inverse optimal control that makes use of differentiable step functions. However, the environments are custom partially observable stochastic environments and therefore do not completely correspond to standard environments from gym.

carlosgmartin commented 7 months ago

@janakact Some libraries you might want to look into:

Not sure which of them satisfy your criterion.