dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License

Does having an efficient contraction path imply backward being efficient? #132

Closed philip-bl closed 3 years ago

philip-bl commented 4 years ago

Hi. I use opt_einsum and I love it. I use it in PyTorch to train tensor-network-based models with SGD and backpropagation. I have a question.

Suppose a layer of my model performs a complicated tensor network contraction in the forward pass, and the backward pass is automatically generated by PyTorch. Is the backward pass going to be efficient, or can it be significantly improved by rewriting it by hand using opt_einsum?

For example, consider

import torch
import opt_einsum as oe

inputs = [torch.randn(512, 2) for i in range(9)]
for input in inputs:
    input.requires_grad_()
parameters = [
    torch.randn(2, 7),
    torch.randn(7, 2, 7),
    torch.randn(7, 2, 7),
    torch.randn(7, 2, 7),
    torch.randn(7, 2, 10, 7),
    torch.randn(7, 2, 7),
    torch.randn(7, 2, 7),
    torch.randn(7, 2, 7),
    torch.randn(7, 2),
]
for param in parameters:
    param.requires_grad_()

# We know the following statement will be evaluated pretty efficiently.
output = oe.contract(
    "bc,bd,be,bf,bg,bh,bi,bj,bk,cm,mdn,neo,ofp,pgzq,qhr,ris,sjt,tk->bz",
    *inputs,
    *parameters,
    optimize="auto"
)
# What about the following statement? Will it be efficient?
# If I add memory_limit to the contract invocation, what memory limit will I have here?
output.backward(torch.ones_like(output))

I think many people would be interested in the answer. So if you could add the answer to your documentation, that would be great.
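
For what it's worth, I can at least inspect the forward path and its largest intermediate with oe.contract_path (this just reuses the arrays above; I am not sure how it maps onto the backward pass):

path, info = oe.contract_path(
    "bc,bd,be,bf,bg,bh,bi,bj,bk,cm,mdn,neo,ofp,pgzq,qhr,ris,sjt,tk->bz",
    *inputs,
    *parameters,
    optimize="auto",
)
# Printing the PathInfo shows the pairwise steps, the FLOP count, and the
# size of the largest intermediate in the forward pass.
print(info)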

dgasmith commented 4 years ago

That's a really good question. PyTorch will trace the function calls to build up a gradient history and apply each operation's backward sequentially, so there is at least some assurance that it will not incur naive memory or computational costs. However, I don't think a backward propagation necessarily has the same optimal contraction order as the forward contraction. I don't follow the ML literature closely enough to give you bounding statements here.
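
To make that concrete, here is a small sketch (a toy contraction, not your example) that walks the autograd graph PyTorch recorded for an oe.contract call; each pairwise step chosen by opt_einsum appears as its own backward node, and backward simply replays them in reverse:

import torch
import opt_einsum as oe

# Toy contraction, just to inspect what autograd recorded.
a = torch.randn(8, 4, requires_grad=True)
b = torch.randn(4, 5, requires_grad=True)
c = torch.randn(5, 3, requires_grad=True)
out = oe.contract("ij,jk,kl->il", a, b, c, optimize="auto")

def walk(fn, depth=0):
    # Print each backward node; the pairwise contractions show up here.
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for nxt, _ in fn.next_functions:
        walk(nxt, depth + 1)

walk(out.grad_fn)

So the backward is built from the same pairwise decomposition the forward used; it is not re-optimized separately for the gradient direction.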

Keywords like memory_limit will be "somewhat" followed, in the sense that large intermediates are likely to be avoided, but they are not guaranteed to be skipped.
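
For example (just a sketch with a smaller stand-in expression, not your exact one), memory_limit can be passed to contract or contract_path, and the resulting path report shows the pairwise steps and the largest forward intermediate; it says nothing about what the autograd engine keeps alive for the backward pass:

import numpy as np
import opt_einsum as oe

# Smaller stand-in expression; the shapes are illustrative only.
ops = [
    np.random.rand(512, 2), np.random.rand(512, 2), np.random.rand(512, 2),
    np.random.rand(2, 7), np.random.rand(7, 2, 7), np.random.rand(7, 2),
]
path, info = oe.contract_path("bc,bd,be,cm,mdn,ne->b", *ops, memory_limit=50_000)

# The PathInfo summary lists each pairwise contraction, the FLOP count,
# and the size of the largest intermediate produced along the way.
print(info)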

@jcmgray probably has more here.

jcmgray commented 4 years ago

My rough (possibly faulty!) understanding is this: when you back-propagate, you basically need all the intermediates computed during the forward pass, plus the backward-propagating gradients (of equal size), so the memory will be at least twice as high. Also, opt_einsum usually clears intermediates from memory once they are no longer needed, so given that these are now all retained, the memory could in fact be quite a bit higher (in the worst case, linear in the number of inputs).
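
If that retained-intermediate memory becomes the bottleneck, one possible workaround (purely a sketch using PyTorch's torch.utils.checkpoint, not anything opt_einsum provides itself) is to recompute the forward intermediates during the backward pass instead of storing them:

import torch
from torch.utils.checkpoint import checkpoint
import opt_einsum as oe

# Illustrative smaller expression; shapes are placeholders.
expr = oe.contract_expression(
    "bc,bd,be,cm,mdn,ne->b",
    (512, 2), (512, 2), (512, 2), (2, 7), (7, 2, 7), (7, 2),
)

tensors = [torch.randn(512, 2, requires_grad=True) for _ in range(3)]
tensors += [
    torch.randn(2, 7, requires_grad=True),
    torch.randn(7, 2, 7, requires_grad=True),
    torch.randn(7, 2, requires_grad=True),
]

# checkpoint() discards the intermediates of the wrapped forward and re-runs
# it during backward, trading extra compute for lower peak memory.
# (use_reentrant=False needs a reasonably recent PyTorch.)
out = checkpoint(expr, *tensors, use_reentrant=False)
out.sum().backward()

Whether the extra forward pass is worth it depends on how expensive the contraction is relative to the memory you need to reclaim.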

dgasmith commented 4 years ago

Unfortunately, the backprop will not be handled by opt_einsum, so the opt_einsum-based caching will not come into play.

Re memory: the backward is at least played through operation by operation, so it doesn't try to do the contraction as one massive blob, which could cause naive (massive) memory requirements. So you are likely saved from memory costs many orders of magnitude larger than expected.
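
If you do want to hand-write part of the backward with opt_einsum, note that the gradient of an einsum with respect to one operand is itself an einsum: replace that operand with the output gradient and make the operand's indices the output indices. A tiny sketch (ordinary contractions checked against autograd, not a special opt_einsum feature):

import torch
import opt_einsum as oe

A = torch.randn(5, 3, requires_grad=True)
B = torch.randn(3, 4, requires_grad=True)

out = oe.contract("ij,jk->ik", A, B)
grad_out = torch.ones_like(out)
out.backward(grad_out)

# Hand-written gradients, routed through opt_einsum as well:
# d/dA of "ij,jk->ik" is "ik,jk->ij"; d/dB is "ij,ik->jk".
grad_A = oe.contract("ik,jk->ij", grad_out, B.detach())
grad_B = oe.contract("ij,ik->jk", A.detach(), grad_out)

print(torch.allclose(A.grad, grad_A), torch.allclose(B.grad, grad_B))

For a large expression you could generate these gradient contractions once and optimize each of them (e.g. via contract_expression), which is essentially the "rewriting it by hand" option from the original question.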