locuslab / mpc.pytorch

A fast and differentiable model predictive control (MPC) solver for PyTorch.
https://locuslab.github.io/mpc.pytorch/
MIT License

Use backprop option to control differentiability #24

Closed LemonPi closed 4 years ago

LemonPi commented 4 years ago

Currently there is a backprop option that defaults to True but does nothing. It would be good if this actually controlled whether gradients are computed throughout. My intended use case is replacing an iLQR implementation with one that handles control constraints. If differentiability is too baked into the code, is there an alternative control-constrained (box constraints are sufficient) DP-based control implementation you would recommend (numpy or pytorch are fine)?

bamos commented 4 years ago

Ah thanks for pointing that out! I was using it at some point for slightly faster solves but then changed the code around a bit and stopped using it. You're interested in using it just for the potential performance gains, right?

Most of the operations in the forward pass are already detached from the PyTorch compute graph while it's finding the fixed point; this option would just disconnect the final LQR call from the compute graph as well. You'll probably get some gain, but it probably won't be too significant, especially since the forward call to that LQR Function isn't traced and it has a manually implemented backward pass.
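For illustration, disconnecting a final solve from the compute graph would just mean gating it behind `torch.no_grad()`. This is a minimal sketch of that pattern, not the library's actual code; `solve_lqr` here is a hypothetical stand-in for the final LQR call:

```python
import torch

def solve_lqr(Q, p):
    # Hypothetical stand-in for the final LQR solve.
    return torch.linalg.solve(Q, -p)

def solve(Q, p, backprop=True):
    # Sketch of how a backprop flag could control differentiability:
    # when backprop is False, detach the final solve from the graph too.
    if backprop:
        return solve_lqr(Q, p)
    with torch.no_grad():
        return solve_lqr(Q, p)

Q = torch.eye(2)
p = torch.tensor([1.0, 2.0], requires_grad=True)
u_detached = solve(Q, p, backprop=False)
u_attached = solve(Q, p, backprop=True)
# u_detached.requires_grad is False; u_attached.requires_grad is True.
```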

LemonPi commented 4 years ago

It'd be nice for the performance gains, but it could also relax the requirement that the input dynamics be a LinDx, Module, or Function. I have a dynamics function/class that implements its own linearization (which is differentiable; I think the implementation is also differentiable, but I haven't tested that).

bamos commented 4 years ago

Ah, if I remember correctly (it has been a while since I've gone through this code), I think you can use general dynamics like that as long as you implement the linearization in a grad_input method -- which may be a relatively lightweight Module wrapper on top of your (numpy?) code.
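As a rough sketch of that wrapper idea: a Module whose forward evaluates the dynamics and whose grad_input returns the Jacobians of the next state with respect to the state and control. The dynamics and the exact grad_input signature below are assumptions based on this thread, not verified against the library:

```python
import torch

class MyDynamics(torch.nn.Module):
    # Hypothetical linear dynamics x' = A x + B u, standing in for a
    # dynamics class that supplies its own linearization.
    def __init__(self):
        super().__init__()
        self.A = torch.tensor([[1.0, 0.1], [0.0, 1.0]])
        self.B = torch.tensor([[0.0], [0.1]])

    def forward(self, x, u):
        # x: (n_batch, n_state), u: (n_batch, n_ctrl)
        return x @ self.A.T + u @ self.B.T

    def grad_input(self, x, u):
        # Return Jacobians of the next state w.r.t. x and u, batched
        # over the first dimension (assumed interface, not verified).
        n_batch = x.shape[0]
        R = self.A.expand(n_batch, *self.A.shape)
        S = self.B.expand(n_batch, *self.B.shape)
        return R, S

dyn = MyDynamics()
x = torch.zeros(4, 2)
u = torch.ones(4, 1)
x_next = dyn(x, u)
R, S = dyn.grad_input(x, u)
```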

LemonPi commented 4 years ago

Ah ok; I saw the grad_input method but thought we would have to differentiate through grad_input. (unrelated quick question: why do you have u_init[-2] = u_init[-3] in the pendulum example after taking an action?)

bamos commented 4 years ago

That is pretty strange; I'm not sure. I may have been trying to debug something in a weird way and accidentally left that in.
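For reference, the conventional warm start in receding-horizon control after executing the first action is to shift the remaining controls forward by one step and pad the tail, rather than overwriting a single interior entry. A minimal sketch (the zero padding is one common choice; repeating the last control is another):

```python
import torch

T, n_ctrl = 5, 1
u_opt = torch.arange(float(T)).reshape(T, n_ctrl)  # solved control sequence

# Shift the sequence forward after applying u_opt[0], padding with zeros.
u_init = torch.cat((u_opt[1:], torch.zeros(1, n_ctrl)))
```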