fancompute / wavetorch

🌊 Numerically solving and backpropagating through the wave equation
https://advances.sciencemag.org/content/5/12/eaay6946
MIT License
518 stars, 83 forks

Implementing an adjoint calculation for backprop-ing through time #1

Open ianwilliamson opened 5 years ago

ianwilliamson commented 5 years ago

We should consider the performance benefit of implementing an adjoint calculation for the backward pass through the forward() method in WaveCell. This could save memory during gradient computation, because PyTorch wouldn't need to construct as large a graph.

The approach is described here: https://pytorch.org/docs/stable/notes/extending.html

parenthetical-e commented 5 years ago

Sorry to pop in, but on the off chance you folks haven't seen this library/paper:

https://github.com/rtqichen/torchdiffeq https://arxiv.org/pdf/1806.07366.pdf

It implements an ODE solver and uses the adjoint method for the backward pass. Is this what you need?

I was already thinking about porting WaveCell to it for my own use. Collaborate?

ianwilliamson commented 5 years ago

Thanks for your interest in this! We are aware of that paper, but unfortunately we can't apply the scheme they propose here because the wave equation with loss (from the absorbing layer) is not reversible.

The "adjoint calculation" I'm referring to here is just hard-coding the gradient for the time step using the PyTorch extension API (https://pytorch.org/docs/stable/notes/extending.html). The motivation is that we can potentially save a lot of memory, because PyTorch doesn't need to store the fields at every sub-operation of each time step. However, it still needs to store the fields at each time step (there's no getting around this when the differential equation isn't reversible). In contrast, the neural ODE paper reconstructs these fields by reversing the forward equation during backpropagation, thus avoiding the need to store the fields from the forward pass.
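A minimal sketch of what hard-coding the gradient of a single time step might look like with `torch.autograd.Function`. It assumes a 1D damped wave equation, y_tt + b y_t = c² y_xx, discretized with central differences on a periodic grid; this is not wavetorch's actual `WaveCell` update, and `WaveStep` and `laplacian_1d` are hypothetical names used only for illustration. Only the step's inputs are saved, and the Laplacian is recomputed in `backward` instead of being stored.

```python
import torch


def laplacian_1d(u: torch.Tensor, h: float) -> torch.Tensor:
    """Central-difference Laplacian on a periodic 1D grid.

    With periodic wrap-around this operator is symmetric, so it is its own adjoint.
    """
    return (torch.roll(u, 1, dims=-1) - 2 * u + torch.roll(u, -1, dims=-1)) / h**2


class WaveStep(torch.autograd.Function):
    """One leapfrog step of an (assumed) damped wave equation with a hand-coded backward.

    Forward: y_next = (2/dt^2 * y - (1/dt^2 - b/(2 dt)) * y_prev + c^2 * lap(y)) / a,
             where a = 1/dt^2 + b/(2 dt).
    y, c and b must be tensors of the same shape; dt and h are Python floats.
    """

    @staticmethod
    def forward(ctx, y, y_prev, c, b, dt, h):
        a = 1 / dt**2 + b / (2 * dt)
        y_next = (2 / dt**2 * y - (1 / dt**2 - b / (2 * dt)) * y_prev
                  + c**2 * laplacian_1d(y, h)) / a
        # Save only the inputs needed for backward; y_prev is not needed
        # because y_next depends on it linearly with a known coefficient.
        ctx.save_for_backward(y, c, b)
        ctx.dt, ctx.h = dt, h
        return y_next

    @staticmethod
    def backward(ctx, grad_out):
        y, c, b = ctx.saved_tensors
        dt, h = ctx.dt, ctx.h
        a = 1 / dt**2 + b / (2 * dt)
        g = grad_out / a
        grad_y = grad_y_prev = grad_c = None
        if ctx.needs_input_grad[0]:
            # Transpose of (2/dt^2 * I + diag(c^2) @ lap), scaled by 1/a;
            # lap is symmetric, so its transpose is lap itself.
            grad_y = 2 / dt**2 * g + laplacian_1d(c**2 * g, h)
        if ctx.needs_input_grad[1]:
            grad_y_prev = -(1 / dt**2 - b / (2 * dt)) * g
        if ctx.needs_input_grad[2]:
            # d(c^2 * lap(y))/dc = 2 c * lap(y), recomputed from the saved y
            grad_c = 2 * c * laplacian_1d(y, h) * g
        # b (the damping profile) and the float arguments are treated as fixed.
        return grad_y, grad_y_prev, grad_c, None, None, None
```

Calling `WaveStep.apply(y, y_prev, c, b, dt, h)` once per step inside the time loop means autograd records one node, with its saved inputs, per time step rather than one per arithmetic sub-operation. The fields at every step are still kept alive, which matches the limitation described above.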

We actually have this adjoint approach implemented, I just need to push the commits to this repository.

ianwilliamson commented 5 years ago

I'm definitely interested to learn about your project and what you hope to do. We would certainly be open to collaboration if there's an opportunity.

parenthetical-e commented 5 years ago

Ah. I understand. Thanks for the explanation.

parenthetical-e commented 5 years ago

I sent you an email about the project I'm pondering. :)

twhughes commented 5 years ago

Hey Eric, could you forward that email to me as well, please? I'm interested in what you have planned. Thanks!

On Tue, Aug 27, 2019, 2:23 AM Erik notifications@github.com wrote:

I sent you an email about the project I'm pondering. :)

parenthetical-e commented 5 years ago

Done, @twhughes

ianwilliamson commented 5 years ago

This is now partially implemented. Currently, each individual time step is a primitive with a hand-coded gradient. This seems to help with memory utilization during training, especially with nonlinearity. Perhaps we could investigate whether there would be significant performance benefits from adjoint-ing the time loop as well.
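A hypothetical sketch of what adjoint-ing the whole time loop could look like: a single `torch.autograd.Function` runs all `n_steps` in `forward` and, in `backward`, walks the adjoint recurrence in reverse time. It assumes the same illustrative 1D damped, periodic leapfrog update (not wavetorch's actual `WaveCell` code), and `WaveLoop` and `laplacian_1d` are made-up names. Since the lossy equation is not run backward, the field at each time step is still saved, but no per-operation intermediates are.

```python
import torch


def laplacian_1d(u: torch.Tensor, h: float) -> torch.Tensor:
    """Central-difference Laplacian on a periodic 1D grid (a symmetric operator)."""
    return (torch.roll(u, 1, dims=-1) - 2 * u + torch.roll(u, -1, dims=-1)) / h**2


class WaveLoop(torch.autograd.Function):
    """Runs n_steps of an (assumed) damped leapfrog update; backward is a reverse-time adjoint loop.

    Update: y[k+1] = (2/dt^2 * y[k] - (1/dt^2 - b/(2 dt)) * y[k-1] + c^2 * lap(y[k])) / a,
    with a = 1/dt^2 + b/(2 dt). Returns the field after the final step.
    """

    @staticmethod
    def forward(ctx, y_init, y_prev_init, c, b, dt, h, n_steps):
        assert n_steps >= 1
        a = 1 / dt**2 + b / (2 * dt)
        coef_prev = -(1 / dt**2 - b / (2 * dt))
        y_prev, y = y_prev_init, y_init
        history = []  # one field per time step: the storage that cannot be avoided
        for _ in range(n_steps):
            history.append(y)
            y_next = (2 / dt**2 * y + coef_prev * y_prev + c**2 * laplacian_1d(y, h)) / a
            y_prev, y = y, y_next
        ctx.save_for_backward(torch.stack(history), c, b)
        ctx.dt, ctx.h, ctx.n_steps = dt, h, n_steps
        return y

    @staticmethod
    def backward(ctx, grad_out):
        history, c, b = ctx.saved_tensors
        dt, h = ctx.dt, ctx.h
        a = 1 / dt**2 + b / (2 * dt)
        coef_prev = -(1 / dt**2 - b / (2 * dt))
        lam_next = grad_out                   # adjoint of y[k+1], already complete
        lam_cur = torch.zeros_like(grad_out)  # partial adjoint of y[k]
        grad_c = torch.zeros_like(c)
        for k in reversed(range(ctx.n_steps)):
            g = lam_next / a
            # Gradient w.r.t. c from step k uses the stored field y[k]
            grad_c = grad_c + 2 * c * laplacian_1d(history[k], h) * g
            lam_prev = coef_prev * g  # contribution to the adjoint of y[k-1]
            # lap is symmetric, so applying it again gives the transposed step
            lam_cur = lam_cur + 2 / dt**2 * g + laplacian_1d(c**2 * g, h)
            lam_next, lam_cur = lam_cur, lam_prev
        # After the loop: lam_next is dL/dy_init and lam_cur is dL/dy_prev_init.
        needs = ctx.needs_input_grad
        return (lam_next if needs[0] else None,
                lam_cur if needs[1] else None,
                grad_c if needs[2] else None,
                None, None, None, None)  # b, dt, h, n_steps treated as fixed
```

Compared with a per-step primitive, autograd records a single node for the whole simulation. Whether that actually yields a significant speedup is left open here; it would need to be measured.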