Chemellia / AtomicGraphNets.jl

Atomic graph models for molecules and crystals in Julia
MIT License

Make DEQ layer usable #37

Open rkurchin opened 3 years ago

rkurchin commented 3 years ago

We merged it in because it technically works, but the forward pass takes 80-90 minutes and the backward pass takes several hours. It would be awesome to get this to a point where it can actually be used. I'm not totally clear on what is needed to make that happen.

Also, because of the aforementioned slowness, there are no layer tests for it right now, so that should eventually be addressed too.

ChrisRackauckas commented 3 years ago

Yup, I'll copy over the discussion on that:

That said, the DEQ algorithm as done right now is memory efficient but not compute efficient. I have some ideas, a few mentioned but not detailed in the original paper and a few not mentioned, which would drastically decrease the compute times.

One thing I found interesting was that Newton methods were by far the most efficient for the forward pass. You can see the remnants of how to use Anderson acceleration: that was a no. DynamicSS was also a big no. Both took around 5,000 seconds for a forward pass vs. 1,800 for Newton. A quasi-Newton method would probably be best, since it avoids the LU costs: the LU factorization was by far the dominant cost.
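To make the quasi-Newton suggestion concrete, here is a minimal sketch in plain Python. It is not the package's code. The toy layer `f`, its weights, and all helper names are made up for illustration, and the method shown is Broyden's "good" update, picked here as one common quasi-Newton scheme. It keeps an approximate inverse Jacobian of the residual `z - f(z)` and refreshes it with a rank-one correction, so no LU factorization is ever performed.

```python
import math

def f(z):
    """Toy layer z -> tanh(W z + b). W and b are illustrative values only."""
    W = [[0.2, -0.1, 0.0],
         [0.1, 0.3, -0.2],
         [0.0, 0.1, 0.25]]
    b = [0.5, -0.3, 0.1]
    return [math.tanh(sum(w * zj for w, zj in zip(row, z)) + bi)
            for row, bi in zip(W, b)]

def residual(z):
    """g(z) = z - f(z); a root of g is a fixed point of f."""
    return [zi - fi for zi, fi in zip(z, f(z))]

def matvec(H, v):
    return [sum(h * vj for h, vj in zip(row, v)) for row in H]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def broyden_fixed_point(z0, tol=1e-10, max_iter=50):
    """Solve z = f(z) with Broyden's 'good' method.

    H approximates the inverse Jacobian of g and is updated with a
    rank-one (Sherman-Morrison) correction instead of a factorization.
    Returns the iterate and the number of steps taken.
    """
    n = len(z0)
    H = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    z = list(z0)
    g = residual(z)
    for k in range(max_iter):
        if max(abs(gi) for gi in g) < tol:
            return z, k
        dz = [-s for s in matvec(H, g)]          # quasi-Newton step
        z_new = [zi + di for zi, di in zip(z, dz)]
        g_new = residual(z_new)
        dg = [a - b for a, b in zip(g_new, g)]
        Hdg = matvec(H, dg)
        denom = dot(dz, Hdg)
        if denom != 0.0:
            # Row vector dz^T H, needed for the rank-one update.
            dzH = [sum(dz[i] * H[i][j] for i in range(n)) for j in range(n)]
            u = [dz[i] - Hdg[i] for i in range(n)]
            for i in range(n):
                for j in range(n):
                    H[i][j] += u[i] * dzH[j] / denom
        z, g = z_new, g_new
    return z, max_iter

z_star, steps = broyden_fixed_point([0.0, 0.0, 0.0])
print(steps, max(abs(r) for r in residual(z_star)))
```

Starting from `H = I` makes the first step a plain fixed-point iteration, and later steps cost only matrix-vector products. For a real layer the dense `H` would be too large, so a limited-memory variant that stores only the update vectors would be the natural next step. That is an assumption about scale, not something measured here.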

For the reverse pass, the autodiff = false is a bit of a red herring: even with true, it's a direct construction of the Jacobian followed by inversion, so numerical vs. forward-mode differentiation wouldn't reduce the computational complexity there. Judging from the paper and the Python code, (10) isn't a great way to do it because it makes a ton of assumptions that are violated by our minibatching. What they mention as (11) is that the nonlinear solve can be done via a Krylov subspace method. This is also described in the ODE section of https://youtu.be/KCTfPyVIxpc, where you can see the adjoint of a nonlinear solve. That adjoint is also derived in the DEQ paper but is a much older result; the best source is SGJ's teaching materials: https://math.mit.edu/~stevenj/18.336/adjoint.pdf. Specifically, instead of doing that Newton solve directly, we can use the vjp of f in a GMRES to handle that much more effectively. But I'll leave that for another time.
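For reference, the standard adjoint of a fixed-point solve can be written out as below. The notation is assumed for this sketch ($z^*$ for the equilibrium, $\theta$ for the parameters, $L$ for the loss, $J$ for the Jacobian of $f$ with respect to $z$) and is not matched to the paper's numbered equations.

```latex
\begin{aligned}
z^* &= f(z^*, \theta)
  && \text{(equilibrium condition)} \\
\frac{dz^*}{d\theta} &= (I - J)^{-1}\,\frac{\partial f}{\partial \theta},
  \qquad J = \left.\frac{\partial f}{\partial z}\right|_{z^*}
  && \text{(implicit differentiation)} \\
(I - J^{\top})\,\lambda &= \left(\frac{\partial L}{\partial z^*}\right)^{\top}
  && \text{(adjoint linear system)} \\
\frac{dL}{d\theta} &= \lambda^{\top}\,\frac{\partial f}{\partial \theta}
  && \text{(one more vjp of } f\text{)}
\end{aligned}
```

A Krylov method such as GMRES only needs the action $v \mapsto v - J^{\top} v$, and $J^{\top} v$ is exactly one vjp of $f$ at $z^*$. So the adjoint system can be solved without ever building or factorizing $J$.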

My general assessment is that, as is, this method is very memory efficient but not compute efficient. There is a pretty major change that could be made to improve the training performance, which I'll see if I can take up as a research project. But at least this is "usable" now.