simab = torch.nn.functional.cosine_similarity(theta_a, theta_b, dim=0)
...
k = (simab - sims.min())/(sims.max() - sims.min())
k = k - alpha # what does this mean? Max k is 1, so this caps k at (1 - alpha)?
k = k.clip(min=0.0, max=1.0) # k can go negative after the subtraction, so clipping is needed
theta_0[key] = weighted_sum(theta_0[key], theta_1[key], k) # k is a tensor
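To make the effect of the `k - alpha` step concrete, here is a minimal scalar sketch. It assumes `k` has already been min-max normalized to [0, 1]; the helper name `original_weight` is hypothetical and does not appear in the original code.

```python
def original_weight(k_norm: float, alpha: float) -> float:
    """Subtract alpha from a normalized similarity, then clip to [0, 1]."""
    k = k_norm - alpha
    return min(max(k, 0.0), 1.0)

# With alpha = 0.3, every normalized similarity at or below 0.3 collapses to 0,
# and the largest reachable weight is 1 - alpha.
for k_norm in (0.0, 0.3, 0.5, 1.0):
    print(k_norm, original_weight(k_norm, alpha=0.3))
```

Under this reading, `alpha` acts as a threshold and not as a scale factor, which may be why the original step is hard to interpret.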
simplified fixed code
...
simab = torch.nn.functional.cosine_similarity(theta_a, theta_b, dim=0).abs() # use abs()
...
k = (simab - sims.min())/(sims.max() - sims.min()) # k is always within [0, 1], so no clip is needed
k = k.mean() * alpha if "Simple" in calcmode else k * alpha # k.mean() * alpha for Simple Cosine.
theta_0[key] = weighted_sum(theta_0[key], theta_1[key], k) # k is a scalar for Simple Cosine
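The fixed snippet can be tried out in isolation with a sketch like the following. Several parts are assumptions made for illustration: `weighted_sum` is taken to be plain linear interpolation, `cosine_k` is a hypothetical wrapper name, and `sims` is a stand-in built from the same two tensors; the real code collects it elsewhere.

```python
import torch
import torch.nn.functional as F

def weighted_sum(theta0, theta1, k):
    # Assumed definition: linear interpolation between the two tensors.
    return (1 - k) * theta0 + k * theta1

def cosine_k(theta_a, theta_b, sims, alpha, simple):
    # Absolute cosine similarity along dim 0, as in the fixed code.
    simab = F.cosine_similarity(theta_a, theta_b, dim=0).abs()
    # Min-max normalize against the reference similarities.
    k = (simab - sims.min()) / (sims.max() - sims.min())
    # Simple mode collapses k to a single scalar; otherwise k stays per column.
    return k.mean() * alpha if simple else k * alpha

torch.manual_seed(0)
theta_a = torch.randn(4, 3)
theta_b = torch.randn(4, 3)
# Stand-in for the collected similarities (assumption for this sketch).
sims = F.cosine_similarity(theta_a, theta_b, dim=0).abs()

k_simple = cosine_k(theta_a, theta_b, sims, alpha=0.5, simple=True)
k_full = cosine_k(theta_a, theta_b, sims, alpha=0.5, simple=False)
merged = weighted_sum(theta_a, theta_b, k_full)
print(k_simple.shape, k_full.shape, merged.shape)
```

In this sketch `alpha` scales the normalized similarity, so the weight stays between 0 and `alpha` without clipping.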
Cosine merge by recoilme https://github.com/recoilme/losslessmix/blob/master/weightedsim.py#L41-L56
Compared variants: original code, simplified, fixed code.
Simplified cosine merge is much faster.