The main reason is that it is possible to efficiently reconstruct the forward pass during the backward pass. Doing so means we don't need to hold intermediate values in memory.
This is actually a very special case of the "continuous adjoint method" sometimes used in differential equations (e.g. as popularised for neural ODEs; also see Chapter 5 of https://arxiv.org/abs/2202.02435). Although in our case, because of the piecewise linear structure, we can recompute things without suffering any numerical truncation error. (Only floating point error, which usually isn't that bad.)
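For anyone finding this thread later, here is a minimal sketch of the pattern being described, not the library's actual code: a custom torch.autograd.Function whose backward recomputes what it needs from the saved inputs instead of keeping intermediate tensors around. The class name and the tensors a, b, t are made up for illustration.

```python
import torch


class RecomputedInterp(torch.autograd.Function):
    """Toy piecewise-linear evaluation y = a + t * b.

    Only the inputs are saved; anything else the backward needs is
    recomputed from them rather than held in memory.
    """

    @staticmethod
    def forward(ctx, a, b, t):
        ctx.save_for_backward(b, t)      # save the (small) inputs only
        return torch.addcmul(a, t, b)    # a + t * b

    @staticmethod
    def backward(ctx, grad_out):
        b, t = ctx.saved_tensors
        # Redo the arithmetic needed for the gradients from the saved inputs.
        grad_a = grad_out
        grad_b = grad_out * t
        grad_t = grad_out * b
        return grad_a, grad_b, grad_t


a = torch.randn(8, requires_grad=True)
b = torch.randn(8, requires_grad=True)
t = torch.rand(8, requires_grad=True)
RecomputedInterp.apply(a, b, t).sum().backward()
```

In this toy there is barely anything to recompute, so it only shows the shape of the idea; the memory savings matter when the forward pass would otherwise stash a large number of intermediate tensors across many evaluation points.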
Really insightful views! I'll read this paper carefully.
Cheers!
Hi, Patrick,
Thanks a lot for your nice work and detailed code comments! I have a very naive question: why not use PyTorch autograd directly for the backward pass? From what I can see, the tensor operations in the forward pass are all things like `addcmul`. I have some simple guesses as to why, but I am not sure they are correct, and I am a newbie when it comes to custom PyTorch functions. I would appreciate it if you could kindly share your opinion. Thank you again!
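To make the comparison in this question concrete, here is a rough sketch, with hypothetical tensors and a hypothetical `block` helper, of the two alternatives: letting autograd record the ops directly (which keeps whatever each op needs for its backward alive in memory), versus torch.utils.checkpoint, PyTorch's built-in way of discarding intermediates and rerunning the forward during backward. Neither is the library's actual mechanism; they just frame the trade-off discussed above.

```python
import torch
from torch.utils.checkpoint import checkpoint

a = torch.randn(1000, requires_grad=True)
b = torch.randn(1000, requires_grad=True)
t = torch.rand(1000)

# Option 1: "use autograd directly". Each op saves the tensors its own
# backward needs, and they stay in memory until backward runs.
y = torch.tanh(torch.addcmul(a, t, b))   # tanh(a + t * b)
y.sum().backward()


# Option 2: checkpointing. Only the inputs to `block` are kept; its
# forward is rerun during backward to regenerate the intermediates.
def block(a, b, t):
    return torch.tanh(torch.addcmul(a, t, b))


a.grad = b.grad = None
z = checkpoint(block, a, b, t, use_reentrant=False)
z.sum().backward()
```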