lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Possible Bug for Residual #127

Closed: XiaoWang-Github closed this issue 1 year ago

XiaoWang-Github commented 1 year ago

On line 1066 of x_transformers.py there is residual = x. Should it be residual = x.clone()?

Immediately after that, on line 1071, you have x = pre_branch_norm(x). Without the clone, won't pre_branch_norm modify residual as well, since residual and x point to the same underlying memory?
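
(For illustration only, not part of the original thread: right after the plain assignment, both names do refer to the same tensor object and the same storage, which is what the question is about.)

import torch

x = torch.randn(2, 4)
residual = x                                 # plain assignment, no copy

print(residual is x)                         # True  -> same Python object
print(residual.data_ptr() == x.data_ptr())   # True  -> same underlying storage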

lucidrains commented 1 year ago

@XiaoWang-Github

this will help

import torch
x = torch.tensor([1.])

y = x
z = y

y = y + 1 # out-of-place: builds a new tensor and rebinds the name y
x = x + 2 # out-of-place: builds a new tensor and rebinds the name x

print(z, id(z)) # [1.], <addr1>
print(y, id(y)) # [2.], <addr2>
print(x, id(x)) # [3.], <addr3>

# on the other hand

import torch
x = torch.tensor([1.])

y = x
x.add_(1) # in-place: mutates the tensor that both x and y reference

print(y, id(y)) # [2.], <addr>
print(x, id(x)) # [2.], <addr>
XiaoWang-Github commented 1 year ago

OK, that makes sense. So the takeaway is that those operations are out-of-place rather than in-place. Therefore, reassigning x out-of-place in later lines of the code will not affect the data stored in residual.
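
(A minimal sketch of that conclusion, assuming pre_branch_norm behaves like a standard out-of-place module such as nn.LayerNorm; this is illustrative and not the actual x_transformers code.)

import torch
import torch.nn as nn

x = torch.randn(2, 8)
saved = x.clone()                            # copy of the original values, for comparison only
residual = x                                 # plain assignment, no clone
norm = nn.LayerNorm(8)                       # stand-in for pre_branch_norm
x = norm(x)                                  # out-of-place: returns a new tensor and rebinds x

print(torch.equal(residual, saved))          # True  -> residual still holds the original values
print(residual.data_ptr() == x.data_ptr())   # False -> x now refers to new storage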