google / trajax

Apache License 2.0

Bug fixes and improvement suggestions #16

Open joaospinto opened 2 weeks ago

joaospinto commented 2 weeks ago

I recently implemented a multiple-shooting variant of iLQR here: https://github.com/joaospinto/primal_dual_ilqr

Much of my implementation (apart from the core algorithmic changes) is inspired by trajax. While implementing it, I found a few issues and possible improvements in trajax itself, which I list below.

1) In https://github.com/google/trajax/blob/c94a637c5a397b3d4100153f25b4b165507b5b20/trajax/optimizers.py#L798, you are projecting the Q matrices to make them positive semi-definite. When the M matrices (i.e. the cross state-and-control quadratic terms) are non-zero, this is not sufficient: the stage cost is convex only if the full block matrix [[Q, M], [M^T, R]] is positive semi-definite. You should do this instead: https://github.com/joaospinto/primal_dual_ilqr/blob/main/primal_dual_ilqr/optimizers.py#L60

2) Perhaps because of the issue above, you resort to least-squares solves in https://github.com/google/trajax/blob/c94a637c5a397b3d4100153f25b4b165507b5b20/trajax/tvlqr.py#L94, when you should be able to use Cholesky solves instead (I didn't benchmark this part in JAX, but Cholesky solves are generally faster); see https://github.com/joaospinto/primal_dual_ilqr/blob/main/primal_dual_ilqr/primal_tvlqr.py#L38.

3) In https://github.com/google/trajax/blob/c94a637c5a397b3d4100153f25b4b165507b5b20/trajax/tvlqr.py#L61, some of the terms in your code are mathematically guaranteed to be zero and can be removed; this is addressed in https://github.com/joaospinto/primal_dual_ilqr/blob/main/primal_dual_ilqr/primal_tvlqr.py#L38.

4) In places like https://github.com/google/trajax/blob/c94a637c5a397b3d4100153f25b4b165507b5b20/trajax/tvlqr.py#L57, you use fori_loop, whereas using a scan results in significant speed-ups. You can find almost-drop-in replacements here: https://github.com/joaospinto/primal_dual_ilqr/blob/main/primal_dual_ilqr/primal_tvlqr.py

5) It would be interesting to add support for a GPU-accelerated implementation of LQR; see https://github.com/joaospinto/primal_dual_ilqr/blob/main/primal_dual_ilqr/primal_tvlqr.py#L96.
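
To illustrate point 1, here is a minimal sketch of projecting the joint stage Hessian rather than Q on its own. It is not the code from either repository: the helper names `project_psd` and `project_joint_hessian` and the tiny example matrices are assumptions made up for this illustration.

```python
import jax.numpy as jnp


def project_psd(S, delta=0.0):
    """Clamp the eigenvalues of a symmetric matrix S to be at least delta."""
    S = 0.5 * (S + S.T)  # symmetrize to guard against round-off
    w, V = jnp.linalg.eigh(S)
    return (V * jnp.maximum(w, delta)) @ V.T


def project_joint_hessian(Q, M, R, delta=0.0):
    """Project [[Q, M], [M^T, R]] onto the PSD cone, then split it again."""
    n = Q.shape[0]
    H = project_psd(jnp.block([[Q, M], [M.T, R]]), delta)
    return H[:n, :n], H[:n, n:], H[n:, n:]


# Q and R are PSD on their own, yet the cross term M makes the joint
# Hessian indefinite, so projecting Q (or Q and R) separately is not enough.
Q = jnp.eye(2)
R = jnp.eye(1)
M = jnp.array([[2.0], [0.0]])

H_before = jnp.block([[Q, M], [M.T, R]])
Qp, Mp, Rp = project_joint_hessian(Q, M, R)
H_after = jnp.block([[Qp, Mp], [Mp.T, Rp]])

min_eig_before = jnp.linalg.eigvalsh(H_before).min()
min_eig_after = jnp.linalg.eigvalsh(H_after).min()
```

In this example the smallest eigenvalue of the joint Hessian is negative before the projection and non-negative (up to floating-point error) afterwards.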
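
To illustrate point 2, once the matrix being inverted is known to be symmetric positive definite, a Cholesky factorization can stand in for a least-squares solve. The sketch below only shows that the two agree on a small made-up system; `solve_spd`, `G`, and `rhs` are hypothetical names, and no timing claim is made here.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def solve_spd(G, rhs):
    """Solve G x = rhs for a symmetric positive definite G via Cholesky."""
    factor = cho_factor(G)
    return cho_solve(factor, rhs)


# Small symmetric positive definite system, made up for illustration.
G = jnp.array([[4.0, 1.0], [1.0, 3.0]])
rhs = jnp.array([[1.0, 0.0], [2.0, 1.0]])

x_chol = solve_spd(G, rhs)
x_lstsq = jnp.linalg.lstsq(G, rhs)[0]  # the more general (and costlier) route
```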
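
To illustrate point 4, here is a minimal sketch of the same backward recursion written once with `jax.lax.fori_loop` and once with `jax.lax.scan`. It uses a simplified Riccati step with no cross or affine terms, which is an assumption for brevity and not the recursion in trajax; all function names and problem data are hypothetical.

```python
import jax
import jax.numpy as jnp


def riccati_step(P, A, B, Q, R):
    """One backward step of a simplified LQR Riccati recursion."""
    G = R + B.T @ P @ B
    K = -jnp.linalg.solve(G, B.T @ P @ A)
    return Q + A.T @ P @ (A + B @ K), K


def backward_fori(A, B, Q, R, P_T):
    """fori_loop version: gains are written into a preallocated buffer."""
    T, n, m = B.shape
    Ks = jnp.zeros((T, m, n))

    def body(i, carry):
        P, Ks = carry
        t = T - 1 - i  # walk backwards in time
        P, K = riccati_step(P, A[t], B[t], Q[t], R[t])
        return P, Ks.at[t].set(K)

    return jax.lax.fori_loop(0, T, body, (P_T, Ks))


def backward_scan(A, B, Q, R, P_T):
    """scan version: per-step inputs are sliced and outputs stacked for us."""

    def body(P, inputs):
        return riccati_step(P, *inputs)

    return jax.lax.scan(body, P_T, (A, B, Q, R), reverse=True)


# Made-up double-integrator data, repeated over T time steps.
T = 5
A = jnp.tile(jnp.array([[1.0, 0.1], [0.0, 1.0]]), (T, 1, 1))
B = jnp.tile(jnp.array([[0.0], [0.1]]), (T, 1, 1))
Q = jnp.tile(jnp.eye(2), (T, 1, 1))
R = jnp.tile(jnp.eye(1), (T, 1, 1))
P_T = jnp.eye(2)

P0_fori, Ks_fori = backward_fori(A, B, Q, R, P_T)
P0_scan, Ks_scan = backward_scan(A, B, Q, R, P_T)
```

Both versions return the same cost-to-go matrix and gains; the scan version avoids the indexed reads and in-place buffer updates inside the loop body.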

I'd be happy to try to merge my code (https://github.com/joaospinto/primal_dual_ilqr) into this repository (either just these improvements, or actually adding the new algorithm itself), if you find that interesting. Let me know!