Eq. 3 in the paper uses a dot product, while this codebase uses cosine similarity. This causes vanishing gradients.
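In symbols (writing $m_1, m_2$ for the two merger vectors and omitting Eq. 3's exact summation/scaling), the paper penalizes the raw dot product while the code penalizes the normalized one:

$$
L_{\text{paper}} \propto \lvert m_1 \cdot m_2 \rvert, \qquad L_{\text{code}} \propto \left\lvert \frac{m_1 \cdot m_2}{\lVert m_1 \rVert\,\lVert m_2 \rVert} \right\rvert .
$$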
```python
import torch
import torch.nn.functional as F

chs = 8
merger_1 = torch.ones((chs,), requires_grad=True)
merger_2 = torch.ones((chs,), requires_grad=True)
optimizer = torch.optim.AdamW([merger_1, merger_2], lr=0.1)

for i in range(10):
    optimizer.zero_grad()
    # loss = torch.abs(merger_1 * merger_2).mean()
    loss = F.cosine_similarity(merger_1, merger_2, dim=0).abs()
    loss.backward()
    optimizer.step()
    print(i, loss.item())
    print(merger_1.grad)
    print(merger_1.detach())
```
If you use

```python
loss = F.cosine_similarity(merger_1, merger_2, dim=0).abs()
```

you will get:
```
0 0.9999999403953552
tensor([7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09])
tensor([0.9563, 0.9563, 0.9563, 0.9563, 0.9563, 0.9563, 0.9563, 0.9563])
1 1.0000001192092896
tensor([-4.4703e-08, -4.4703e-08, -4.4703e-08, -4.4703e-08, -4.4703e-08, -4.4703e-08, -4.4703e-08, -4.4703e-08])
tensor([1.0029, 1.0029, 1.0029, 1.0029, 1.0029, 1.0029, 1.0029, 1.0029])
2 0.9999999403953552
tensor([1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08])
tensor([1.0209, 1.0209, 1.0209, 1.0209, 1.0209, 1.0209, 1.0209, 1.0209])
3 0.9999999403953552
tensor([0., 0., 0., 0., 0., 0., 0., 0.])
tensor([1.0348, 1.0348, 1.0348, 1.0348, 1.0348, 1.0348, 1.0348, 1.0348])
4 0.9999999403953552
tensor([0., 0., 0., 0., 0., 0., 0., 0.])
tensor([1.0459, 1.0459, 1.0459, 1.0459, 1.0459, 1.0459, 1.0459, 1.0459])
5 0.9999999403953552
tensor([2.2352e-08, 2.2352e-08, 2.2352e-08, 2.2352e-08, 2.2352e-08, 2.2352e-08, 2.2352e-08, 2.2352e-08])
tensor([1.0393, 1.0393, 1.0393, 1.0393, 1.0393, 1.0393, 1.0393, 1.0393])
6 0.9999999403953552
tensor([0., 0., 0., 0., 0., 0., 0., 0.])
tensor([1.0334, 1.0334, 1.0334, 1.0334, 1.0334, 1.0334, 1.0334, 1.0334])
7 1.0000001192092896
tensor([-7.4506e-09, -7.4506e-09, -7.4506e-09, -7.4506e-09, -7.4506e-09, -7.4506e-09, -7.4506e-09, -7.4506e-09])
tensor([1.0329, 1.0329, 1.0329, 1.0329, 1.0329, 1.0329, 1.0329, 1.0329])
8 0.9999999403953552
tensor([7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09])
tensor([1.0279, 1.0279, 1.0279, 1.0279, 1.0279, 1.0279, 1.0279, 1.0279])
9 0.9999999403953552
tensor([1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08])
tensor([1.0151, 1.0151, 1.0151, 1.0151, 1.0151, 1.0151, 1.0151, 1.0151])
```
Whereas if you use

```python
loss = torch.abs(merger_1 * merger_2).mean()
```

you will get:
```
0 1.0
tensor([0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250])
tensor([0.8990, 0.8990, 0.8990, 0.8990, 0.8990, 0.8990, 0.8990, 0.8990])
1 0.8082009553909302
tensor([0.1124, 0.1124, 0.1124, 0.1124, 0.1124, 0.1124, 0.1124, 0.1124])
tensor([0.7985, 0.7985, 0.7985, 0.7985, 0.7985, 0.7985, 0.7985, 0.7985])
2 0.6376327872276306
tensor([0.0998, 0.0998, 0.0998, 0.0998, 0.0998, 0.0998, 0.0998, 0.0998])
tensor([0.6989, 0.6989, 0.6989, 0.6989, 0.6989, 0.6989, 0.6989, 0.6989])
3 0.48847696185112
tensor([0.0874, 0.0874, 0.0874, 0.0874, 0.0874, 0.0874, 0.0874, 0.0874])
tensor([0.6006, 0.6006, 0.6006, 0.6006, 0.6006, 0.6006, 0.6006, 0.6006])
4 0.3607187271118164
tensor([0.0751, 0.0751, 0.0751, 0.0751, 0.0751, 0.0751, 0.0751, 0.0751])
tensor([0.5041, 0.5041, 0.5041, 0.5041, 0.5041, 0.5041, 0.5041, 0.5041])
5 0.2540968358516693
tensor([0.0630, 0.0630, 0.0630, 0.0630, 0.0630, 0.0630, 0.0630, 0.0630])
tensor([0.4099, 0.4099, 0.4099, 0.4099, 0.4099, 0.4099, 0.4099, 0.4099])
6 0.16804958879947662
tensor([0.0512, 0.0512, 0.0512, 0.0512, 0.0512, 0.0512, 0.0512, 0.0512])
tensor([0.3188, 0.3188, 0.3188, 0.3188, 0.3188, 0.3188, 0.3188, 0.3188])
7 0.10166114568710327
tensor([0.0399, 0.0399, 0.0399, 0.0399, 0.0399, 0.0399, 0.0399, 0.0399])
tensor([0.2315, 0.2315, 0.2315, 0.2315, 0.2315, 0.2315, 0.2315, 0.2315])
8 0.05361458286643028
tensor([0.0289, 0.0289, 0.0289, 0.0289, 0.0289, 0.0289, 0.0289, 0.0289])
tensor([0.1489, 0.1489, 0.1489, 0.1489, 0.1489, 0.1489, 0.1489, 0.1489])
9 0.022164179012179375
tensor([0.0186, 0.0186, 0.0186, 0.0186, 0.0186, 0.0186, 0.0186, 0.0186])
tensor([0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717])
```
Clearly, there is a vanishing-gradient problem with the cosine-similarity loss.
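The reason (a short derivation, nothing specific to this repo): cosine similarity is scale-invariant, so its gradient with respect to $m_1$ only has a component orthogonal to $m_1$:

$$
\nabla_{m_1} \cos(m_1, m_2) = \frac{1}{\lVert m_1 \rVert}\left(\frac{m_2}{\lVert m_2 \rVert} - \cos(m_1, m_2)\,\frac{m_1}{\lVert m_1 \rVert}\right).
$$

In the script above both vectors start as all ones, i.e. exactly parallel, so $\cos(m_1, m_2) = 1$ and $m_2 / \lVert m_2 \rVert = m_1 / \lVert m_1 \rVert$, making the gradient exactly zero; the ~1e-8 values in the first log are just floating-point noise. The un-normalized $\lvert m_1 \cdot m_2 \rvert$ form has no such degeneracy, so the optimizer can actually drive the overlap toward zero.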
@crj1998 thanks for your investigation. I agree with your point. Fixed it on https://github.com/mkshing/ziplora-pytorch/pull/12