pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Optimize torch.einsum #60295

Open · heitorschueroff opened this issue 3 years ago

heitorschueroff commented 3 years ago

Currently, torch.einsum performs tensor contractions from left to right, calling torch.bmm for each pairwise contraction unless there are no dimensions to be summed out, in which case torch.mul is called instead. There are two main ways in which torch.einsum performance can be improved:

1. Optimize contraction path

Instead of performing contractions from left to right, we can compute an optimal (or near-optimal) contraction path. NumPy and TensorFlow do this by calling the opt_einsum library to get the path. PR https://github.com/pytorch/pytorch/pull/60191 adds initial support for this by accepting a contraction path in the same format as the one returned by opt_einsum.contract_path(...)[0]. However, we should consider either adding a dependency on opt_einsum or implementing our own algorithm for computing an optimal contraction path.
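
For reference, a small sketch of querying opt_einsum for such a path (this assumes the opt_einsum package is installed; the operand names and shapes are arbitrary and chosen so that the contraction order matters):

import numpy as np
import opt_einsum

a = np.random.rand(10000, 10)   # 'ik'
b = np.random.rand(10, 10000)   # 'kl'
c = np.random.rand(10000, 10)   # 'lj'

# A strict left-to-right order would first materialize a (10000, 10000)
# intermediate; a good path contracts b and c first and keeps every
# intermediate tiny.
path, info = opt_einsum.contract_path("ik,kl,lj->ij", a, b, c)
print(path)   # a list of operand-index pairs, e.g. [(1, 2), (0, 1)]
print(info)   # PathInfo, including naive vs. optimized FLOP counts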

2. Optimize tensor contractions for non-BLAS or discontiguous cases

Instead of always calling into torch.bmm for contracting two tensors, which can cause expensive data movements for certain contractions, we should do what NumPy does: if the contraction matches a BLAS operation and the dimensions are contiguous, call torch.tensordot; otherwise, call a custom function that computes a sum_of_products over the discontiguous dimensions. The issue https://github.com/pytorch/pytorch/issues/57121 illustrates a case where PyTorch is much slower than NumPy because the dimensions are discontiguous. A sum_of_products function could replace the helper at https://github.com/pytorch/pytorch/blob/d9e7df707bdc737a10b3af15f08143cc300ace39/aten/src/ATen/native/Linear.cpp#L47, which is also used by torch.bilinear's forward and backward functions.
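
As a rough Python reference for what such a sum_of_products means (a sketch only, with an illustrative helper name; it is neither the proposed kernel nor the existing C++ helper linked above):

import torch

def sum_of_products_ref(a, b, sum_dims):
    # Reference semantics only: broadcast-multiply the two already-aligned
    # operands and reduce over the contracted dimensions. The contracted
    # dimensions do not need to be adjacent or contiguous in memory.
    return (a * b).sum(dim=sum_dims)

# Example: contract dims 1 and 3 of a 4D operand against a broadcast operand.
a = torch.rand(10, 20, 30, 40)
b = torch.rand(1, 20, 1, 40)   # aligned to a's layout by inserting size-1 dims
out = sum_of_products_ref(a, b, sum_dims=(1, 3))
print(out.shape)               # torch.Size([10, 30])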

cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @xwang233 @Lezcano @VitalyFedyunin @ngimel

see also: https://github.com/pytorch/pytorch/issues/21760

ngimel commented 3 years ago

Re: the second option, making tensors contiguous is not a requirement for bmm; einsum is doing this needlessly, as illustrated by this script:

import time
import torch
import numpy as np

nb = 80
dtype = torch.float64
torch.set_num_threads(1)
mat1 = torch.rand((nb, nb, nb, nb), dtype=dtype)
mat2 = torch.rand((nb, nb), dtype=dtype)

# einsum path: the profile shows contiguous()/copy_ calls before bmm
with torch.autograd.profiler.profile(record_shapes=True) as p:
    res3 = torch.einsum("jk,ijkl->il", mat2, mat1)
print(p.key_averages(group_by_input_shape=True))

# the same contraction expressed directly as a batched matmul over views
mat2_view = mat2.view(1, 1, nb*nb)
mat1_view = mat1.view(nb, nb*nb, nb)
with torch.autograd.profiler.profile(record_shapes=True) as p:
    res4 = torch.matmul(mat2_view, mat1_view).view(nb, nb)
print(p.key_averages(group_by_input_shape=True))

# results match
print(res3.size(), res4.size(), (res3-res4).abs().max())

The second version, using matmul, does not call contiguous or copy_ and is almost 10 times faster than the first. I could have written the second version to use bmm directly, but that would require expanding mat2 to explicitly have the same batch dimension; this would not change the perf numbers (matmul does it internally) but would make the code less clear. This won't work for case 4 in #57121, where the contraction dimensions are discontiguous.
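
For completeness, a sketch of the bmm variant described above (mat2_batched and res5 are illustrative names; the setup is repeated from the script so the snippet runs on its own):

import torch

nb = 80
dtype = torch.float64
mat1 = torch.rand((nb, nb, nb, nb), dtype=dtype)
mat2 = torch.rand((nb, nb), dtype=dtype)

# Give mat2 an explicit batch dimension via expand (a view, no copy), then
# contract the flattened jk dimension with bmm.
mat2_batched = mat2.reshape(1, 1, nb * nb).expand(nb, 1, nb * nb)
res5 = torch.bmm(mat2_batched, mat1.reshape(nb, nb * nb, nb)).view(nb, nb)
print(res5.shape)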

heitorschueroff commented 3 years ago

@ngimel That is very interesting; I will take a look at where we are copying data unnecessarily in torch.einsum. But as you mentioned, this won't help case 4 because the contraction dimensions are discontiguous. In that case, it seems to me that we would still need a way to compute a sum of products over discontiguous dimensions to get better performance, is that correct?

ngimel commented 3 years ago

Yes, einsum contracting over discontiguous dimensions probably needs custom kernels; writing a performant kernel for this is non-trivial.
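
To make the discontiguous case concrete (an illustrative example, not case 4 from #57121): when the contracted dimensions are not adjacent in the operand, mapping the contraction onto bmm requires a permute/copy first, whereas a dedicated sum-of-products kernel could reduce over the strided layout directly. The snippet below only shows the intended semantics; a fused kernel would also avoid materializing the elementwise product.

import torch

a = torch.rand(64, 32, 64, 32, dtype=torch.float64)
b = torch.rand(32, 32, dtype=torch.float64)

# Contract over j and l, which are not adjacent dimensions of a, so flattening
# them into a single matmul dimension would require permuting and copying a.
out_einsum = torch.einsum("ijkl,jl->ik", a, b)

# Reference sum-of-products over the same (strided) dimensions.
out_ref = (a * b.view(1, 32, 1, 32)).sum(dim=(1, 3))
print(torch.allclose(out_einsum, out_ref))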

janeyx99 commented 2 years ago

The first item (optimizing the contraction path) should now be landed with #84890.