LemonPi closed 4 years ago
Ah, thanks for pointing that out! I was using it at some point for slightly faster solves, but then changed the code around a bit and stopped using it. You're interested in using it just for the potential performance gains, right?
Most of the operations in the forward pass are already detached from the PyTorch computation graph while it searches for the fixed point, so this option would only disconnect the final LQR call from the graph as well. You'll probably see some gain, but it likely won't be significant, especially since the forward call to that LQR Function isn't traced and has a manually implemented backward pass.
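To illustrate the pattern in plain PyTorch (a toy, not the library's internals): search for the fixed point with gradient tracking off, then run one final step whose graph is kept only when a flag asks for it. The names newton_step, solve_sqrt and backprop are hypothetical; torch.no_grad and torch.set_grad_enabled are standard PyTorch context managers. Because the last step is a Newton step, its gradient at the converged point matches the true derivative of sqrt(theta); the analogy to differentiating only the final LQR call is an illustration, not a description of this code base.

```python
import torch


def newton_step(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # One Newton step for x**2 - theta = 0; its fixed point is sqrt(theta).
    return 0.5 * (x + theta / x)


def solve_sqrt(theta: torch.Tensor, n_iter: int = 30, backprop: bool = True) -> torch.Tensor:
    x = torch.ones_like(theta)  # does not inherit requires_grad from theta
    # Search for the fixed point without building an autograd graph.
    with torch.no_grad():
        for _ in range(n_iter):
            x = newton_step(x, theta)
    # Only the final step is attached to the graph, and only if backprop=True.
    with torch.set_grad_enabled(backprop):
        return newton_step(x, theta)
```

With backprop=False the returned tensor has requires_grad=False, so no graph is kept at all; with backprop=True, memory for the graph covers a single step regardless of n_iter.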
It'd be nice for the performance gains, but it could also relax the requirements on the input dynamics (allowing more than LinDx, Module, or Function). I have a dynamics function/class that implements its own linearization (which is differentiable; I think the dynamics implementation itself is also differentiable, but I haven't tested it).
Ah, if I remember correctly (it's been a while since I've gone through this code), I think you can use general dynamics like that as long as you implement the linearization in a grad_input method, which may be a relatively lightweight Module wrapper on top of your (numpy?) code.
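A minimal sketch of such a wrapper, assuming the controller calls forward(x, u) for the next state and grad_input(x, u) for the Jacobians (df/dx, df/du) with a leading batch dimension; the exact signature and shapes are assumptions and should be checked against the library's own dynamics examples. The numpy functions step_np and jacobians_np and the constant DT are hypothetical placeholders for your code (here a discretized double integrator).

```python
import numpy as np
import torch
from torch import nn

DT = 0.1  # hypothetical time step


def step_np(x: np.ndarray, u: np.ndarray) -> np.ndarray:
    """Hypothetical numpy dynamics: double integrator, x = [pos, vel], u = [acc]."""
    pos, vel, acc = x[:, 0], x[:, 1], u[:, 0]
    return np.stack([pos + DT * vel, vel + DT * acc], axis=1)


def jacobians_np(x: np.ndarray, u: np.ndarray):
    """Hypothetical analytic linearization of step_np: (df/dx, df/du) per batch item."""
    n_batch = x.shape[0]
    A = np.tile(np.array([[1.0, DT], [0.0, 1.0]]), (n_batch, 1, 1))  # (n_batch, 2, 2)
    B = np.tile(np.array([[0.0], [DT]]), (n_batch, 1, 1))  # (n_batch, 2, 1)
    return A, B


class NumpyDynamics(nn.Module):
    """Thin Module wrapper; the numpy calls themselves are not traced by autograd."""

    @staticmethod
    def _to_tensor(a: np.ndarray, like: torch.Tensor) -> torch.Tensor:
        return torch.as_tensor(a, dtype=like.dtype, device=like.device)

    def forward(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        x_next = step_np(x.detach().cpu().numpy(), u.detach().cpu().numpy())
        return self._to_tensor(x_next, x)

    def grad_input(self, x: torch.Tensor, u: torch.Tensor):
        A, B = jacobians_np(x.detach().cpu().numpy(), u.detach().cpu().numpy())
        return self._to_tensor(A, x), self._to_tensor(B, x)
```

Since everything crosses into numpy, nothing in this wrapper carries gradients; that is fine for a pure control use case, but not if you need derivatives with respect to the dynamics parameters.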
Ah ok; I saw the grad_input method but thought we would have to differentiate through grad_input.
(unrelated quick question: why do you have u_init[-2] = u_init[-3] in the pendulum example after taking an action?)
That is pretty strange; I'm not sure. I may have been trying to debug something in a weird way and accidentally left that in.
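In a receding-horizon loop, warm starting usually just shifts the previous solution forward by one step: drop the control that was just applied and fill the freed last slot. A minimal sketch, assuming the time-major (T, n_batch, n_ctrl) layout; shift_warm_start is a hypothetical helper, and repeating the last control versus padding with zeros is a design choice, not something the example prescribes.

```python
import numpy as np


def shift_warm_start(u_prev: np.ndarray, repeat_last: bool = True) -> np.ndarray:
    """Shift a (T, n_batch, n_ctrl) control sequence one step forward in time.

    The first control has just been applied, so it is dropped; the final slot
    is filled either by repeating the last control or with zeros.
    """
    tail = u_prev[-1:] if repeat_last else np.zeros_like(u_prev[-1:])
    return np.concatenate((u_prev[1:], tail), axis=0)
```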
Currently there is a backprop option that defaults to True but does nothing. It would be good if it actually controlled whether gradients are computed throughout. My intended use case is replacing an iLQR implementation with one that handles control constraints. If differentiability is too baked into the code, is there an alternative control-constrained (box constraints are sufficient) DP-based control implementation you would recommend (numpy or PyTorch are fine)?
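For reference, the core piece that adds box constraints to an iLQR/DDP backward pass is typically a small box-constrained QP over the control update at each time step (the control-limited DDP approach). A minimal sketch of that subproblem using plain projected gradient descent; box_qp is a hypothetical helper, and a practical solver would more likely use a projected Newton method for speed.

```python
import numpy as np


def box_qp(H, g, lower, upper, n_iter=500, tol=1e-10):
    """Minimize 0.5 * u^T H u + g^T u subject to lower <= u <= upper.

    Projected gradient descent with step 1/L, where L is the largest
    eigenvalue of H (assumed symmetric positive definite, as Quu is in a
    well-regularized iLQR backward pass).
    """
    H = np.asarray(H, dtype=float)
    g = np.asarray(g, dtype=float)
    step = 1.0 / np.linalg.eigvalsh(H)[-1]  # eigvalsh sorts ascending
    u = np.clip(np.zeros_like(g), lower, upper)
    for _ in range(n_iter):
        # Gradient step on the quadratic, then project back onto the box.
        u_new = np.clip(u - step * (H @ u + g), lower, upper)
        if np.max(np.abs(u_new - u)) < tol:
            return u_new
        u = u_new
    return u
```

In a control-limited backward pass, H and g would be Quu and Qu at one time step, with the bounds shifted by the nominal control so that the update keeps the control inside the box.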