bfs18 opened 7 months ago
Hi @gnobitab, I implemented the feature loss myself, but it does not work properly. Could you comment on my pseudocode?
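```python
import torch
import torch.nn.functional as F

# feature_extractor, z_t, target and pred are defined elsewhere in the training code.


def get_feature_weight(S):  # S: shape [batch_size, dim, H, W]
    def _feature_func(x):
        feature = feature_extractor(x)  # shape [batch_size, feature_dim, H, W]
        feature = feature.sum(dim=(0, 2, 3))
        return feature  # shape [feature_dim]

    # jacobian() differentiates w.r.t. S internally, so calling
    # S.requires_grad_(True) here is unnecessary (and would mutate z_t).
    w = torch.autograd.functional.jacobian(_feature_func, S)  # shape [feature_dim, batch_size, dim, H, W]
    return w.transpose(0, 1).detach()  # shape [batch_size, feature_dim, dim, H, W]


w = get_feature_weight(z_t)
w_target = torch.einsum('bdchw,bchw->bdhw', w, target)
w_pred = torch.einsum('bdchw,bchw->bdhw', w, pred)
# Compare the feature-weighted tensors; the original line used the raw target/pred,
# which left w_target and w_pred unused.
loss = F.mse_loss(w_pred, w_target)
```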
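For comparison, here is a hedged sketch of one alternative: computing J·(pred - target) with a Jacobian-vector product instead of materializing the full [feature_dim, batch_size, dim, H, W] Jacobian. Because the product is linear in its vector argument, J·pred - J·target = J·(pred - target), so a single JVP is enough. This sketch assumes the intended loss is the squared norm of that feature-space difference contracted over channels *and* spatial positions (unlike the per-pixel einsum `'bdchw,bchw->bdhw'` above, which keeps H and W). It also assumes `feature_extractor` is frozen and processes samples independently, which allows summing only over H and W instead of over the batch. `feature_jvp_loss` is a hypothetical helper name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def feature_jvp_loss(feature_extractor: nn.Module,
                     z_t: torch.Tensor,
                     pred: torch.Tensor,
                     target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean over the batch of ||J(z_t) @ (pred - target)||^2,
    where J is the Jacobian of the spatially pooled features at z_t.

    Assumes feature_extractor maps [B, C, H, W] -> [B, D, H', W'] and treats
    samples independently (e.g. eval mode, no batch statistics).
    """
    def pooled_features(x):
        # Sum over spatial dims only, so each sample keeps its own [D] vector
        # and the JVP stays per-sample without summing over the batch.
        return feature_extractor(x).sum(dim=(2, 3))  # [B, D]

    # Detach z_t so the Jacobian acts as a fixed weight (like .detach() on w).
    _, jv = torch.autograd.functional.jvp(
        pooled_features,
        z_t.detach(),
        pred - target,
        create_graph=True,  # keep the result differentiable w.r.t. pred
    )  # jv: [B, D], one J @ (pred - target) per sample
    return jv.pow(2).sum(dim=1).mean()
```

The JVP costs roughly one extra forward/backward pass through `feature_extractor`, whereas `jacobian()` runs one backward pass per output feature and stores the full Jacobian in memory.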