mkshing / ziplora-pytorch

Implementation of "ZipLoRA: Any Subject in Any Style by Effectively Merging LoRAs"
MIT License
481 stars · 33 forks

Do not use Cosine Similarity #10

Closed · crj1998 closed this issue 7 months ago

crj1998 commented 7 months ago

Eq. 3 in the paper uses the dot product, while this codebase uses cosine similarity. This causes the gradient to vanish.
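
For concreteness, the two per-layer penalties being compared are (my notation; m_1 and m_2 denote a layer's column-wise merger vectors of length d):

\mathcal{L}_{\mathrm{cos}} = \frac{\lvert m_1 \cdot m_2 \rvert}{\lVert m_1 \rVert \, \lVert m_2 \rVert},
\qquad
\mathcal{L}_{\mathrm{prod}} = \frac{1}{d} \sum_{i=1}^{d} \bigl\lvert m_{1,i} \, m_{2,i} \bigr\rvert

Here L_cos is what the codebase currently computes and L_prod is the elementwise-product alternative used in the snippet below.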

import torch
import torch.nn.functional as F

# Two merger vectors of length 8, initialized to all ones, i.e. starting out
# perfectly parallel.
chs = 8
merger_1 = torch.ones((chs,), requires_grad=True)
merger_2 = torch.ones((chs,), requires_grad=True)
optimizer = torch.optim.AdamW([merger_1, merger_2], lr=0.1)

for i in range(10):
    optimizer.zero_grad()
    # Elementwise-product alternative:
    # loss = torch.abs(merger_1 * merger_2).mean()
    # What the codebase currently uses: absolute cosine similarity.
    loss = F.cosine_similarity(merger_1, merger_2, dim=0).abs()
    loss.backward()
    optimizer.step()
    print(i, loss.item())      # loss at this step
    print(merger_1.grad)       # gradient that was just applied
    print(merger_1.detach())   # updated values

If you use loss = F.cosine_similarity(merger_1, merger_2, dim=0).abs(), you get:

0 0.9999999403953552
tensor([7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09,
        7.4506e-09, 7.4506e-09])
tensor([0.9563, 0.9563, 0.9563, 0.9563, 0.9563, 0.9563, 0.9563, 0.9563])
1 1.0000001192092896
tensor([-4.4703e-08, -4.4703e-08, -4.4703e-08, -4.4703e-08, -4.4703e-08,
        -4.4703e-08, -4.4703e-08, -4.4703e-08])
tensor([1.0029, 1.0029, 1.0029, 1.0029, 1.0029, 1.0029, 1.0029, 1.0029])
2 0.9999999403953552
tensor([1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08,
        1.4901e-08, 1.4901e-08])
tensor([1.0209, 1.0209, 1.0209, 1.0209, 1.0209, 1.0209, 1.0209, 1.0209])
3 0.9999999403953552
tensor([0., 0., 0., 0., 0., 0., 0., 0.])
tensor([1.0348, 1.0348, 1.0348, 1.0348, 1.0348, 1.0348, 1.0348, 1.0348])
4 0.9999999403953552
tensor([0., 0., 0., 0., 0., 0., 0., 0.])
tensor([1.0459, 1.0459, 1.0459, 1.0459, 1.0459, 1.0459, 1.0459, 1.0459])
5 0.9999999403953552
tensor([2.2352e-08, 2.2352e-08, 2.2352e-08, 2.2352e-08, 2.2352e-08, 2.2352e-08,
        2.2352e-08, 2.2352e-08])
tensor([1.0393, 1.0393, 1.0393, 1.0393, 1.0393, 1.0393, 1.0393, 1.0393])
6 0.9999999403953552
tensor([0., 0., 0., 0., 0., 0., 0., 0.])
tensor([1.0334, 1.0334, 1.0334, 1.0334, 1.0334, 1.0334, 1.0334, 1.0334])
7 1.0000001192092896
tensor([-7.4506e-09, -7.4506e-09, -7.4506e-09, -7.4506e-09, -7.4506e-09,
        -7.4506e-09, -7.4506e-09, -7.4506e-09])
tensor([1.0329, 1.0329, 1.0329, 1.0329, 1.0329, 1.0329, 1.0329, 1.0329])
8 0.9999999403953552
tensor([7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09, 7.4506e-09,
        7.4506e-09, 7.4506e-09])
tensor([1.0279, 1.0279, 1.0279, 1.0279, 1.0279, 1.0279, 1.0279, 1.0279])
9 0.9999999403953552
tensor([1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08, 1.4901e-08,
        1.4901e-08, 1.4901e-08])
tensor([1.0151, 1.0151, 1.0151, 1.0151, 1.0151, 1.0151, 1.0151, 1.0151])

whereas with loss = torch.abs(merger_1 * merger_2).mean(), you get:

0 1.0
tensor([0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250])
tensor([0.8990, 0.8990, 0.8990, 0.8990, 0.8990, 0.8990, 0.8990, 0.8990])
1 0.8082009553909302
tensor([0.1124, 0.1124, 0.1124, 0.1124, 0.1124, 0.1124, 0.1124, 0.1124])
tensor([0.7985, 0.7985, 0.7985, 0.7985, 0.7985, 0.7985, 0.7985, 0.7985])
2 0.6376327872276306
tensor([0.0998, 0.0998, 0.0998, 0.0998, 0.0998, 0.0998, 0.0998, 0.0998])
tensor([0.6989, 0.6989, 0.6989, 0.6989, 0.6989, 0.6989, 0.6989, 0.6989])
3 0.48847696185112
tensor([0.0874, 0.0874, 0.0874, 0.0874, 0.0874, 0.0874, 0.0874, 0.0874])
tensor([0.6006, 0.6006, 0.6006, 0.6006, 0.6006, 0.6006, 0.6006, 0.6006])
4 0.3607187271118164
tensor([0.0751, 0.0751, 0.0751, 0.0751, 0.0751, 0.0751, 0.0751, 0.0751])
tensor([0.5041, 0.5041, 0.5041, 0.5041, 0.5041, 0.5041, 0.5041, 0.5041])
5 0.2540968358516693
tensor([0.0630, 0.0630, 0.0630, 0.0630, 0.0630, 0.0630, 0.0630, 0.0630])
tensor([0.4099, 0.4099, 0.4099, 0.4099, 0.4099, 0.4099, 0.4099, 0.4099])
6 0.16804958879947662
tensor([0.0512, 0.0512, 0.0512, 0.0512, 0.0512, 0.0512, 0.0512, 0.0512])
tensor([0.3188, 0.3188, 0.3188, 0.3188, 0.3188, 0.3188, 0.3188, 0.3188])
7 0.10166114568710327
tensor([0.0399, 0.0399, 0.0399, 0.0399, 0.0399, 0.0399, 0.0399, 0.0399])
tensor([0.2315, 0.2315, 0.2315, 0.2315, 0.2315, 0.2315, 0.2315, 0.2315])
8 0.05361458286643028
tensor([0.0289, 0.0289, 0.0289, 0.0289, 0.0289, 0.0289, 0.0289, 0.0289])
tensor([0.1489, 0.1489, 0.1489, 0.1489, 0.1489, 0.1489, 0.1489, 0.1489])
9 0.022164179012179375
tensor([0.0186, 0.0186, 0.0186, 0.0186, 0.0186, 0.0186, 0.0186, 0.0186])
tensor([0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717, 0.0717])

Clearly, the cosine-similarity loss suffers from a vanishing-gradient problem.
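
The reason is that cosine similarity is scale-invariant. A quick sketch of its gradient (same notation as above):

\frac{\partial}{\partial m_1} \, \frac{m_1 \cdot m_2}{\lVert m_1 \rVert \, \lVert m_2 \rVert}
= \frac{m_2}{\lVert m_1 \rVert \, \lVert m_2 \rVert}
- \frac{(m_1 \cdot m_2) \, m_1}{\lVert m_1 \rVert^{3} \, \lVert m_2 \rVert}

This is exactly zero whenever m_1 and m_2 are parallel, as with the all-ones initialization above, so the ~1e-8 gradients in the first log are just floating-point noise. For the elementwise-product loss, \partial \mathcal{L}_{\mathrm{prod}} / \partial m_{1,i} = \mathrm{sign}(m_{1,i} m_{2,i}) \, m_{2,i} / d, which is 1/8 = 0.125 at initialization, matching the second log, and keeps pushing the overlap toward zero.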

mkshing commented 7 months ago

@crj1998 thanks for your investigation. I agree with your point. Fixed in https://github.com/mkshing/ziplora-pytorch/pull/12
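
For reference, a minimal sketch of the kind of change discussed in this thread (an illustration of the suggestion above, not the actual diff in the linked PR; merger_1 and merger_2 stand for a layer's two column-wise merger vectors):

import torch

def merger_similarity_loss(merger_1: torch.Tensor, merger_2: torch.Tensor) -> torch.Tensor:
    # Overlap penalty between the two merger vectors of one layer.
    # Uses the mean absolute elementwise product instead of cosine similarity,
    # so the gradient does not vanish when the vectors are (near-)parallel.
    return (merger_1 * merger_2).abs().mean()

Summed over all LoRA layers and scaled by whatever similarity weight the training script uses, this term would take the place of the F.cosine_similarity call in the merge loss.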