@XiaoWang-Github this example will help:
import torch
x = torch.tensor([1.])
y = x            # y is another name for the same tensor object as x
z = y            # z also refers to that original tensor
y = y + 1        # out-of-place: creates a new tensor and rebinds the name y
x = x + 2        # out-of-place: creates a new tensor and rebinds the name x
print(z, id(z))  # tensor([1.]) <addr1>  -- z still points at the original, unmodified tensor
print(y, id(y))  # tensor([2.]) <addr2>  -- a new object
print(x, id(x))  # tensor([3.]) <addr3>  -- another new object
# On the other hand, with an in-place operation:
import torch
x = torch.tensor([1.])
y = x            # y aliases the same tensor object as x
x.add_(1)        # in-place: mutates the shared storage, no new tensor is created
print(y, id(y))  # tensor([2.]) <addr>  -- y sees the in-place change
print(x, id(x))  # tensor([2.]) <addr>  -- same id as y: they are the same object
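If you want to check whether two names actually share storage (rather than just comparing Python object ids), Tensor.data_ptr() is handy; a minimal sketch:

import torch

x = torch.tensor([1.])
y = x
print(x.data_ptr() == y.data_ptr())  # True: both names share the same underlying storage

y = y + 1                            # out-of-place: y now owns freshly allocated storage
print(x.data_ptr() == y.data_ptr())  # False

x.add_(1)                            # in-place: x keeps its original storage
print(x.data_ptr())                  # same address as before the add_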
Ok, that makes sense. So the takeaway is that these operations are out-of-place rather than in-place, and therefore rebinding x out-of-place in later lines of the code will not affect the data stored for residual.
In line 1066 of x_transformers.py there is
residual = x
Should that be residual = x.clone() instead?
Immediately after that, in line 1071, you have
x = pre_branch_norm(x)
Without the clone, won't pre_branch_norm modify residual too, since residual and x point to the same underlying memory?
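A quick way to test that concern is the sketch below. It assumes pre_branch_norm behaves like a standard out-of-place module, with nn.LayerNorm used here as a hypothetical stand-in; the actual x_transformers implementation may differ.

import torch
import torch.nn as nn

x = torch.randn(2, 4)
residual = x                                 # residual aliases the same tensor as x
print(residual.data_ptr() == x.data_ptr())   # True: they share storage at this point

norm = nn.LayerNorm(4)                       # hypothetical stand-in for pre_branch_norm
x = norm(x)                                  # out-of-place: returns a new tensor, x is rebound to it
print(residual.data_ptr() == x.data_ptr())   # False: residual still holds the original, untouched values

If pre_branch_norm instead mutated its input in place, the clone would indeed be necessary to keep residual intact.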