harryin212 opened this issue 3 years ago
This happens because the tensors that do not need gradient computation are still recorded for the backward pass. The following code works:
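```python
import torch

def cobi_from_similarity(cx_feat, cx_sp, weight_sp):
    # Hypothetical wrapper name; cx_feat, cx_sp and weight_sp come from the caller.
    # no_grad records no autograd graph, so intermediates are freed immediately;
    # the returned loss has requires_grad=False, so this is for evaluation only.
    with torch.no_grad():
        cx_combine = (1. - weight_sp) * cx_feat + weight_sp * cx_sp
        k_max_NC, _ = torch.max(cx_combine, dim=2, keepdim=True)
        cx = k_max_NC.mean(dim=1)
        cx_loss = torch.mean(-torch.log(cx + 1e-5))
    return cx_loss
```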
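For a sense of scale, here is a rough sketch of how much memory one dense similarity tensor takes. It assumes (this is not confirmed in the thread) that the similarity is computed between every pair of spatial positions of a 128x128 map, stored as float32. `pairwise_matrix_bytes` and `gib` are hypothetical helpers for illustration only.

```python
# Rough memory estimate for one dense pairwise similarity tensor, ASSUMING
# every spatial position of an h x w map is compared with every other one
# and values are float32 (hypothetical helpers, not from the CoBi code).


def pairwise_matrix_bytes(h: int, w: int, batch: int = 1, bytes_per_elem: int = 4) -> int:
    """Bytes needed for a batch x (h*w) x (h*w) similarity tensor."""
    n = h * w  # number of positions compared against each other
    return batch * n * n * bytes_per_elem


def gib(num_bytes: int) -> float:
    """Convert a byte count to GiB."""
    return num_bytes / 2**30


if __name__ == "__main__":
    size = pairwise_matrix_bytes(128, 128)
    print(f"one 128x128 similarity tensor: {gib(size):.1f} GiB")
    # cx_feat, cx_sp and cx_combine each have this size, and autograd may keep
    # further same-sized intermediates alive when the loss is backpropagated.
    print(f"three such tensors: {gib(3 * size):.1f} GiB")
```

Under that assumption a single tensor is already 1 GiB, and the loss creates several tensors of that size, so an 8 GB GTX 1080 can plausibly run out of memory once autograd also keeps intermediates for backward. Halving each side to 64x64 shrinks each tensor 16 times, to 64 MiB.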
Hi, when I test the CoBi loss on my SRCNN model, it runs out of memory. My image size is 128x128 and the batch size is 1, tested on a GTX 1080 GPU. Can you tell me how to avoid the OOM? Here's my error message: