Open sliorde opened 2 years ago
More details: I understand that `einsum` calls `bmm`. So I checked whether this problem happens with `bmm`, and indeed it does:
from torch import tensor, zeros, bmm
x = zeros((1, 6, 2))
x[0, 0, 0] = 1.0791796445846558
x[0, 0, 1] = 0.30579063296318054
y = zeros((1, 2, 34))
y[0, 0, 0] = -0.14987720549106598
y[0, 1, 0] = 0.9887046217918396
# the following four numbers should be equal, but they are not.
a = bmm(x, y )[0, 0, 0].item() # =0.14059218764305115
b = bmm(x, y[:, :, :33])[0, 0, 0].item() # =0.14059217274188995
c = bmm(x[:, :5], y )[0, 0, 0].item() # =0.14059217274188995
d = bmm(x[:, :5], y[:, :, :33])[0, 0, 0].item() # =0.14059217274188995
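To put the size of the discrepancy in context, here is a small sketch that uses only the Python standard library. It reinterprets the printed values of `a` and `b` as float32 bit patterns and counts how many representable float32 values separate them. The helper name `float32_bits` is made up for this illustration, and the float64 reference is simply the same two-term dot product evaluated with ordinary Python floats.

```python
import struct


def float32_bits(value: float) -> int:
    """Return the IEEE 754 binary32 bit pattern of `value` as an integer."""
    return struct.unpack("<I", struct.pack("<f", value))[0]


# Values printed by the bmm example above.
a = 0.14059218764305115
b = 0.14059217274188995

# For positive floats, the difference between the bit patterns is the
# distance in units in the last place (ULPs).
ulp_distance = float32_bits(a) - float32_bits(b)

# The same dot product evaluated in float64 (plain Python floats).
reference = (1.0791796445846558 * -0.14987720549106598
             + 0.30579063296318054 * 0.9887046217918396)

print(ulp_distance)       # how many float32 steps apart a and b are
print(b < reference < a)  # whether the float64 result lies between them
```

The two results are adjacent float32 values, and the float64 reference falls between them, so neither value is far from the mathematically exact answer.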
Not only is `a` different from `b`, it is also different from `c` and `d`.
If I replace `bmm` with `matmul` (and remove the batch dimension), then the problem does not occur.
If any of the dimensions is reduced (`34` is reduced, or `6` is reduced, or `2` is reduced), the problem does not occur.

Thanks @qqaatw. The link you sent indeed seems very related. It talks about getting different numerical results in these two cases: (1) apply a function to a slice of a tensor; (2) apply the function to the entire tensor and then take only the corresponding slice/element.
Although our current issue isn't an exact match to this description, it seems close enough and is probably a manifestation of the same underlying phenomenon.
Should we close this issue?
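The slice-versus-whole effect can be reproduced without PyTorch. Floating-point addition is not associative, so a reduction that visits the same numbers in a different grouping may round differently. The sketch below is only an illustration in plain Python; the helper `sequential_sum` and the choice of ten copies of `0.1` are assumptions made for the example, not a description of what PyTorch does internally.

```python
import math


def sequential_sum(values):
    """Add values left to right with plain float additions."""
    total = 0.0
    for v in values:
        total += v
    return total


values = [0.1] * 10

# One pass over all ten elements.
whole = sequential_sum(values)

# Same elements, but reduced in two chunks of five and then combined,
# as a blocked or vectorized kernel might do.
chunked = sequential_sum(values[:5]) + sequential_sum(values[5:])

print(whole == chunked)                             # bitwise comparison
print(math.isclose(whole, chunked, rel_tol=1e-15))  # tolerance comparison
```

The two totals are not bitwise equal, yet they agree to within a tiny relative tolerance, which is the kind of difference reported in this issue.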
Side note: I believe the linked notes page from @qqaatw has an error. Here is a quote from that page:
> E.g. let `A` be a 2-dimensional tensor. `A.sum(-1)[0]` is not guaranteed to be bitwise equal to `A[:,0].sum()`.

But these two are not the same thing at all. The former is the sum of a row of `A`, and the latter is the sum of a column. The latter should have been something like `A[0,:].sum()`.
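The difference between the two expressions is easy to check on a small non-square example. The sketch below uses nested Python lists as a stand-in for a 2-dimensional tensor so that it runs without any dependencies. The matrix values are arbitrary, and the comments name the tensor expression each line is meant to mirror.

```python
# A 2x3 "tensor" as nested lists: A[i][j] is row i, column j.
A = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]

# Mirrors A.sum(-1)[0]: reduce over the last dimension, take entry 0.
row_sums = [sum(row) for row in A]
first_of_row_sums = row_sums[0]

# Mirrors A[:, 0].sum(): the sum of the first column.
first_column_sum = sum(row[0] for row in A)

# Mirrors A[0, :].sum(): the sum of the first row.
first_row_sum = sum(A[0])

print(first_of_row_sums, first_column_sum, first_row_sum)
```

Only the first and third quantities describe the same mathematical sum, so only that pair is a meaningful candidate for a bitwise-equality caveat.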
(opened an issue for the side note... #80940)
I'm afraid this is indeed the same root cause. Depending on the sizes and number of dimensions, a different algorithm might get selected, leading to small differences.
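As an illustration of how two algorithms can disagree even on a two-term dot product, the sketch below emulates float32 arithmetic in plain Python in two ways: rounding each product before adding, and keeping the second product unrounded as a fused multiply-add would. This is only one plausible mechanism, not a confirmed description of which kernels PyTorch dispatches to here. The helper names `f32`, `dot_unfused`, and `dot_fused` are invented for the example.

```python
import struct


def f32(value: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


# Non-zero entries from the bmm example, stored as float32.
x0, x1 = f32(1.0791796445846558), f32(0.30579063296318054)
y0, y1 = f32(-0.14987720549106598), f32(0.9887046217918396)


def dot_unfused(x0, y0, x1, y1):
    """Round each product to float32, then round their sum."""
    return f32(f32(x0 * y0) + f32(x1 * y1))


def dot_fused(x0, y0, x1, y1):
    """Round the first product, then add the second product unrounded.

    The product of two float32 values is exact in float64, so this
    emulates acc = x0 * y0 followed by fma(x1, y1, acc).
    """
    acc = f32(x0 * y0)
    return f32(acc + x1 * y1)


print(dot_unfused(x0, y0, x1, y1))
print(dot_fused(x0, y0, x1, y1))
```

With these inputs the unfused variant reproduces the value of `b` and the fused variant reproduces the value of `a`. That is consistent with the explanation that different code paths round at different points, but it does not prove which path was taken.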
🐛 Describe the bug
I have two tensors, `x` and `y`. The first has shape `(2, 3, 2)`, the second has shape `(34, 2)`. I use `einsum` to calculate the dot product between each of the six 2-dimensional vectors that lie in the last dimension of `x`, and each of the 34 vectors that lie in the last dimension of `y`. The bug is that the result of the dot product between `x[0, 0]` and `y[0]` changes if we ignore the last vector of `y`, i.e. if we take `y[:33]` instead of `y`. This is undesired behavior (I think). See here:
I believe this is a minimal example (at least, a local minimum). If I take `x` to be 1d or 2d instead of 3d, the bug does not occur. If I take `x` and `y` to have a last dimension of size 1, the bug does not occur. If I change any of the non-zero entries of `x` and `y` to zero, the bug does not occur. If I do a "manual" dot product instead, like this: `(x[None,...]*y[:33,None,None,:]).sum(3)[0,1,0].item()`, the bug does not occur (and we get the value `0.14059217274188995`, equal to `b` above). If I put the tensors on GPU (`.to('cuda')`), the bug does not occur (and again we get the value of `b` above). If I use numpy, the bug does not occur (but we get a different value: `0.14059218275815866`).

Versions
PyTorch version: 1.11.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.22.5
Libc version: glibc-2.26

Python version: 3.7.13 (default, Apr 24 2022, 01:04:09) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.188+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: True
CUDA runtime version: 11.1.105
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 460.32.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.11.0+cu113
[pip3] torchaudio==0.11.0+cu113
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.12.0
[pip3] torchvision==0.12.0+cu113
[conda] Could not collect