heitorschueroff opened this issue 3 years ago
Re: the second option, making tensors contiguous is not a requirement for bmm; einsum is doing this needlessly, as illustrated by this script:
import time
import torch
import numpy as np

nb = 80
dtype = torch.float64
torch.set_num_threads(1)

mat1 = torch.rand((nb, nb, nb, nb), dtype=dtype)
mat2 = torch.rand((nb, nb), dtype=dtype)

# Version 1: einsum, which internally makes inputs contiguous before bmm.
with torch.autograd.profiler.profile(record_shapes=True) as p:
    res3 = torch.einsum("jk,ijkl->il", mat2, mat1)
print(p.key_averages(group_by_input_shape=True))

# Version 2: the same contraction expressed as a batched matmul on views,
# which needs no contiguous() or copy_() calls.
mat2_view = mat2.view(1, 1, nb*nb)
mat1_view = mat1.view(nb, nb*nb, nb)
with torch.autograd.profiler.profile(record_shapes=True) as p:
    res4 = torch.matmul(mat2_view, mat1_view).view(nb, nb)
print(p.key_averages(group_by_input_shape=True))

print(res3.size(), res4.size(), (res3-res4).abs().max())
The second version, using matmul, does not call contiguous or copy_, and is almost 10 times faster than the first version. I could have written the second version to use bmm directly, but that would require expanding mat2 to explicitly have the same batch dimension; it would not change the perf numbers (matmul does this internally), but it would make the code less clear.
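For reference, a direct bmm version might look like the sketch below (a self-contained sketch using the same view shapes as in the script above; the explicit expand is the extra step that matmul otherwise performs internally):

import torch

nb = 80
mat1 = torch.rand((nb, nb, nb, nb), dtype=torch.float64)
mat2 = torch.rand((nb, nb), dtype=torch.float64)

mat2_view = mat2.view(1, 1, nb*nb)
mat1_view = mat1.view(nb, nb*nb, nb)

# bmm requires both inputs to carry the batch dimension explicitly, so the
# (1, 1, nb*nb) view is expanded (a stride-0 view, no copy) to (nb, 1, nb*nb).
res_bmm = torch.bmm(mat2_view.expand(nb, 1, nb*nb), mat1_view).view(nb, nb)

res_ref = torch.einsum("jk,ijkl->il", mat2, mat1)
print((res_bmm - res_ref).abs().max())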
This won't work for case 4 in #57121, where the contraction dimensions are discontiguous.
@ngimel That is very interesting; I will take a look at where we are copying data unnecessarily in torch.einsum. But as you mentioned, case 4 wouldn't work because the contraction dimensions are discontiguous. In that case, it seems to me that we would still need a way to compute a sum of products over discontiguous dimensions to get better performance; is that correct?
Yes, einsum contracting over discontiguous dimensions probably needs custom kernels, and making a performant kernel for this is non-trivial.
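To make the discontiguous case concrete, here is a small illustrative sketch (the subscripts and shapes are made up for illustration and are not the exact case 4 from #57121): when the summed dimensions are not adjacent in memory, mapping the contraction onto a matmul currently requires a permute plus a contiguous copy before the dimensions can be flattened.

import torch

a = torch.rand(8, 9, 10, 11, dtype=torch.float64)   # dims i, j, k, l
b = torch.rand(9, 11, dtype=torch.float64)          # dims j, l

# out[i, k] = sum over j, l of a[i, j, k, l] * b[j, l]
ref = torch.einsum("ijkl,jl->ik", a, b)

# To feed this to a matmul, the contraction dims (j, l) must be merged into one
# dimension. They are not adjacent in a's memory layout, so the permuted tensor
# has to be materialized with contiguous() before view() can flatten it --
# that materialization is the extra copy being discussed.
a_mat = a.permute(0, 2, 1, 3).contiguous().view(8 * 10, 9 * 11)   # (i*k, j*l)
b_vec = b.reshape(9 * 11)                                         # (j*l,)
out = (a_mat @ b_vec).view(8, 10)

print((out - ref).abs().max())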
The first item (optimizing the contraction path) should now be landed with #84890.
Currently, torch.einsum is computed by performing tensor contractions from left to right and calling torch.bmm for each contraction, unless there are no dimensions to be summed out, in which case torch.mul is called instead. There are two main ways in which torch.einsum performance can be improved:

1. Optimize contraction path
Instead of performing contractions from left to right, we can compute an optimal (or near-optimal) contraction path. NumPy and TensorFlow do this by calling the opt_einsum library to get the path. The PR https://github.com/pytorch/pytorch/pull/60191 adds initial support for this by accepting a contraction path in the same format as given by opt_einsum.contract_path(...)[0]. However, we should consider either adding a dependency on opt_einsum or implementing our own algorithms for computing the optimal contraction path.
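For illustration, here is a minimal sketch of the path format that opt_einsum.contract_path(...)[0] produces, assuming opt_einsum (and NumPy for the operands) is installed; the shapes are chosen so that contracting left to right would build a large (1000, 1000) intermediate, while an optimized path avoids it:

import numpy as np
import opt_einsum

# Chained contraction "ij,jk,kl->il" where the order matters: left-to-right
# produces a (1000, 1000) intermediate, while the optimal order contracts the
# last two operands first and only ever builds a (2, 2) intermediate.
a = np.random.rand(1000, 2)
b = np.random.rand(2, 1000)
c = np.random.rand(1000, 2)

path, info = opt_einsum.contract_path("ij,jk,kl->il", a, b, c)
print(path)   # list of operand-index pairs giving the pairwise contraction order
print(info)   # FLOP estimates and intermediate sizes for the chosen path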
2. Optimize tensor contractions for non-BLAS or discontiguous cases

Instead of always calling into torch.bmm for contracting two tensors, which can cause expensive data movements for certain tensor contractions, we should do what NumPy does: if the contraction matches a BLAS operation and the dimensions are contiguous, call torch.tensordot; otherwise, call a custom function to compute a sum_of_products over discontiguous dimensions. The issue https://github.com/pytorch/pytorch/issues/57121 illustrates a case where PyTorch is much slower than NumPy because the dimensions are discontiguous. A sum_of_products function could replace https://github.com/pytorch/pytorch/blob/d9e7df707bdc737a10b3af15f08143cc300ace39/aten/src/ATen/native/Linear.cpp#L47, which is also used by torch.bilinear's forward and backward functions.
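As a rough illustration of the proposed dispatch (sum_of_products_naive below is a hypothetical placeholder, not the actual ATen sum-of-products routine, and is only a semantic sketch rather than a performant kernel):

import torch

def sum_of_products_naive(left, right, sum_dims):
    # Hypothetical fallback: broadcast-multiply, then reduce over the
    # contraction dimensions. Works regardless of memory layout, but is
    # only a semantic sketch, not an optimized kernel.
    return (left * right).sum(dim=sum_dims)

# Example contraction: out[i, l] = sum over j, k of a[i, j, k, l] * b[j, k]
a = torch.rand(4, 5, 6, 7, dtype=torch.float64)
b = torch.rand(5, 6, dtype=torch.float64)
ref = torch.einsum("ijkl,jk->il", a, b)

# BLAS-friendly route: the contraction dims (j, k) are adjacent and
# contiguous in a, so tensordot maps directly onto a matrix multiply.
via_tensordot = torch.tensordot(a, b, dims=([1, 2], [0, 1]))

# Fallback route: the same contraction as broadcast-multiply + sum,
# which does not require the summed dims to be contiguous.
via_fallback = sum_of_products_naive(a, b.view(1, 5, 6, 1), sum_dims=(1, 2))

print((via_tensordot - ref).abs().max(), (via_fallback - ref).abs().max())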
cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk @xwang233 @Lezcano @VitalyFedyunin @ngimel
see also: https://github.com/pytorch/pytorch/issues/21760