Just one tip: the proposed environment doesn't work for me, so I built a newer environment with CUDA 12, transformers 4.24.0, clip-anytorch 2.6.0, etc. At first, something went wrong with torch.autograd.grad(loss.requires_grad_(True), [x])[0] in the function update_loss_self_cross. I made a modification and replaced the loss with its mathematical equivalent, and now it works for me. Here is the code I use to replace the original torch.autograd.grad(loss.requires_grad_(True), [x])[0]:
grad_cond1 = torch.autograd.grad(loss1.requires_grad_(True), [x], retain_graph=True)[0]
grad_cond2 = torch.autograd.grad(loss2.requires_grad_(True), [x], retain_graph=True)[0]
grad_cond3 = torch.autograd.grad(loss3.requires_grad_(True), [x])[0]
x = x - grad_cond1 - grad_cond2 - grad_cond3
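For anyone wondering why this substitution is safe: by linearity of differentiation, the gradient of a sum of losses equals the sum of the per-loss gradients, so computing three separate grad calls and subtracting each result matches the original single call on the combined loss. A minimal sketch (the loss expressions here are placeholders, not the actual losses from update_loss_self_cross):

```python
import torch

# Sketch: grad(l1 + l2 + l3) == grad(l1) + grad(l2) + grad(l3).
# The losses below are arbitrary stand-ins for illustration only.
x = torch.randn(4, requires_grad=True)
loss1 = (x ** 2).sum()
loss2 = (x ** 3).sum()
loss3 = x.sum()

# Gradient of the combined loss in one autograd call.
grad_total = torch.autograd.grad(loss1 + loss2 + loss3, [x], retain_graph=True)[0]

# Per-term gradients; retain_graph=True keeps the graph alive for the
# later calls, mirroring the replacement snippet above.
g1 = torch.autograd.grad(loss1, [x], retain_graph=True)[0]
g2 = torch.autograd.grad(loss2, [x], retain_graph=True)[0]
g3 = torch.autograd.grad(loss3, [x])[0]

print(torch.allclose(grad_total, g1 + g2 + g3))
```

Note that the in-place update x = x - grad_cond1 - grad_cond2 - grad_cond3 then applies the same total gradient step as subtracting the single combined gradient would.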