arogozhnikov / einops

Flexible and powerful tensor operations for readable and reliable code (for pytorch, jax, TF and others)
https://einops.rocks
MIT License

[BUG] Result of einops.repeat differs by length #346

Closed Phosphor-Bai closed 4 hours ago

Phosphor-Bai commented 4 hours ago

Describe the bug When using einops.repeat on a dimension of length 1, the values of the repeated items are linked (writing to one changes the others and the original tensor), but if the length is not 1, the values are independent. By comparison, with repeat_interleave in pytorch the values are independent in all cases.

Reproduction steps Steps to reproduce the behavior:

import torch
import einops

# repeated dimension length = 1
a = torch.LongTensor([[1], [2], [3]])
b = einops.repeat(a, 'n v -> n (v x)', x=3)
b[0, 0] = 9
print(b)   # You'll get [[9, 9, 9], [2, 2, 2], [3, 3, 3]]
print(a)   # You'll get [[9], [2], [3]]

# repeated dimension length > 1
c = torch.LongTensor([[1, 2], [3, 4], [5, 6]])
d = einops.repeat(c, 'n v -> n (v x)', x=3)
d[0, 0] = 9
print(d)   # You'll get [[9, 1, 1, 2, 2, 2], [3, 3, 3, 4, 4, 4], [5, 5, 5, 6, 6, 6]]
print(c)   # You'll get [[1, 2], [3, 4], [5, 6]]

# repeated dimension length = 1 in pytorch
e = torch.LongTensor([[1], [2], [3]])
f = e.repeat_interleave(3, dim=1)
f[0, 0] = 9
print(f)    # You'll get [[9, 1, 1], [2, 2, 2], [3, 3, 3]]
print(e)   # You'll get [[1], [2], [3]]

Expected behavior The results for a and b are not consistent with those for c and d. I'm confused about why the elements along dim 1 of b all share the same value, which is also shared with a.

Your platform I reproduced this problem in both python3.9 and python3.10, with einops version 0.8.0 in both cases.

arogozhnikov commented 4 hours ago

hi @Phosphor-Bai

That's how it should work.

When possible, einops creates a view of the existing tensor. To put it simply, the values in the output actually point to the values in the input tensor. Since you use repeat, several output elements can point to the same original value. This strategy spends ~zero additional memory and performs zero copying in most cases.

When creating a view is impossible, einops creates a new tensor (in other words, it materializes the result).
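
For example, you can check which of your two cases produced a view and which a materialized copy by comparing storage pointers and strides (a rough sketch using torch's data_ptr() / stride()):

```python
import torch
import einops

a = torch.LongTensor([[1], [2], [3]])            # shape (3, 1)
b = einops.repeat(a, 'n v -> n (v x)', x=3)
# b is a view: it shares storage with a, and its last stride is 0,
# so all three columns of a row alias the same memory location
print(b.data_ptr() == a.data_ptr())   # True
print(b.stride())                     # (1, 0)

c = torch.LongTensor([[1, 2], [3, 4], [5, 6]])   # shape (3, 2)
d = einops.repeat(c, 'n v -> n (v x)', x=3)
# here a view is impossible, so the result lives in newly allocated memory
print(d.data_ptr() == c.data_ptr())   # False
```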

For instance, np.reshape/torch.reshape have the same semantics: a view when possible, otherwise a copy.

To understand why creating a view is sometimes impossible, you need to understand how striding works: https://stackoverflow.com/questions/53097952/how-to-understand-numpy-strides-for-layman Then you should be able to see exactly how this played out in the examples you posted above.
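
Roughly, a sketch of what happens under the hood (using plain torch.expand/torch.reshape as an approximation of what einops does on the torch backend):

```python
import torch

# length-1 case: expand repeats the size-1 axis "for free" with stride 0,
# and the (3, 1, 3) -> (3, 3) reshape can stay a view
a = torch.LongTensor([[1], [2], [3]])
a_view = a.unsqueeze(-1).expand(3, 1, 3).reshape(3, 3)
print(a_view.stride())                      # (1, 0): columns alias one element
print(a_view.data_ptr() == a.data_ptr())    # True: no copy was made

# length-2 case: after expand the strides are (2, 1, 0); flattening the last
# two axes cannot be described by a single stride, so reshape must copy
c = torch.LongTensor([[1, 2], [3, 4], [5, 6]])
inter = c.unsqueeze(-1).expand(3, 2, 3)
print(inter.stride())                       # (2, 1, 0)
flat = inter.reshape(3, 6)
print(flat.data_ptr() == c.data_ptr())      # False: materialized copy
```

If you need values that never alias the input, cloning the result (for example einops.repeat(...).clone()) forces a copy, just like repeat_interleave does in your last example.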