Closed: philip-bl closed this issue 3 years ago
That's a really good question. PyTorch will trace the function calls to build up a gradient history and apply each operation sequentially, so there is at least some assurance that it will not incur naive memory or computational costs. However, I don't think the backward propagation has the same optimal contraction order as the forward contraction; I don't follow the ML literature closely enough to give you bounding statements here.
Keywords like memory_limit will be "somewhat" respected, in the sense that large intermediates are likely avoided, but they are not guaranteed to be skipped.
@jcmgray probably has more to add here.
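As a rough illustration of what "tracing" means here (a minimal sketch; the expression and shapes are made up, not from the issue): when given torch tensors, opt_einsum dispatches to torch operations, so autograd records each pairwise contraction of the chosen path and differentiates back through them.

```python
import torch
import opt_einsum as oe

# Hypothetical three-tensor contraction; shapes are arbitrary placeholders.
a = torch.randn(16, 16, requires_grad=True)
b = torch.randn(16, 16, requires_grad=True)
c = torch.randn(16, 16, requires_grad=True)

# opt_einsum picks a pairwise order for the forward pass and executes it with
# torch ops, so autograd records each intermediate contraction and replays
# them in reverse when .backward() is called.
out = oe.contract('ij,jk,kl->il', a, b, c)
out.sum().backward()

print(a.grad.shape)  # torch.Size([16, 16])
```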
My rough (possibly faulty!) understanding is this: when you back-propagate you basically need all the intermediates computed during the forward pass, plus the backwards-propagating gradients (of equal size), so the memory will be at least twice as high. Also, opt_einsum usually clears intermediates out of memory once they are no longer needed, so given that these are now all retained, the memory could in fact be quite a bit higher (in the worst case linear in the number of inputs).
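A quick way to see this effect (a sketch, assuming a CUDA device so the peak-memory counters are available; the expression and sizes are arbitrary placeholders) is to run the same contraction with and without gradient tracking and compare the peaks.

```python
import torch
import opt_einsum as oe

d = 2048
expr = 'ab,bc,cd,de->ae'
ops = [torch.randn(d, d, device='cuda', requires_grad=True) for _ in range(4)]

# Without gradient tracking, intermediates are freed as soon as they are used.
with torch.no_grad():
    torch.cuda.reset_peak_memory_stats()
    oe.contract(expr, *ops)
    peak_no_grad = torch.cuda.max_memory_allocated()

# With gradient tracking, autograd keeps the intermediates alive for backward,
# and the backward pass allocates gradient buffers on top of them.
torch.cuda.reset_peak_memory_stats()
out = oe.contract(expr, *ops)
out.sum().backward()
peak_grad = torch.cuda.max_memory_allocated()

print(peak_no_grad, peak_grad)  # the second peak is noticeably higher
```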
Unfortunately the backprop will not be handled by opt_einsum, so the opt_einsum-based caching will not come into play.
Re memory: the contraction is at least played through pairwise, so it doesn't try to do the whole thing as one massive blob, which could cause naive (massive) memory requirements. So you are likely saved from using many orders of magnitude more memory than expected.
Hi. I use opt_einsum and I love it. I use it in PyTorch to train tensor-network-based models using SGD with back-propagation. I have a question. Suppose a layer of my model performs a complicated tensor network contraction in the forward pass. The backward pass is automatically generated by PyTorch. Is the backward pass going to be efficient, or can it be significantly improved by rewriting it by hand and using opt_einsum?
For example, consider
I think many people would be interested in the answer. So if you could add the answer to your documentation, that would be great.
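For concreteness, here is a minimal sketch of the kind of setup described in the question; the layer, expression, and shapes are hypothetical placeholders rather than anything from the original post.

```python
import torch
import opt_einsum as oe

class ContractionLayer(torch.nn.Module):
    """Hypothetical layer whose forward pass is a tensor-network contraction."""

    def __init__(self, d=8):
        super().__init__()
        self.core1 = torch.nn.Parameter(torch.randn(d, d, d))
        self.core2 = torch.nn.Parameter(torch.randn(d, d, d))

    def forward(self, x):
        # x has shape (batch, d, d); the expression is a made-up example.
        return oe.contract('bij,ijk,jkl->bkl', x, self.core1, self.core2)

layer = ContractionLayer()
x = torch.randn(32, 8, 8)
loss = layer(x).pow(2).mean()
loss.backward()  # backward pass generated automatically by autograd
```

Calling `loss.backward()` here differentiates through whatever pairwise contractions opt_einsum chose for the forward pass, which is the situation the replies above discuss.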