gnobitab / RectifiedFlow

Official Implementation of Rectified Flow (ICLR2023 Spotlight)

Implementation of feature loss, Equation 4 in the paper. #15

Open: bfs18 opened this issue 7 months ago


Hi @gnobitab, I implemented the feature loss myself, but it does not work properly. Could you comment on my pseudocode?

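```python
import torch
import torch.nn.functional as F

def get_feature_weight(S, feature_extractor):
    # Jacobian of the features (summed over batch and spatial dims) w.r.t. the input S.

    def _feature_func(x):
        feature = feature_extractor(x)  # shape [batch_size, feature_dim, H, W]
        feature = feature.sum(dim=(0, 2, 3))
        return feature  # shape [feature_dim]

    # jacobian() tracks gradients on its own copy of S, so S.requires_grad_(True)
    # is unnecessary (and would flip the flag on z_t in place).
    w = torch.autograd.functional.jacobian(_feature_func, S)  # shape [feature_dim, batch_size, dim, H, W]
    return w.transpose(0, 1).detach()  # shape [batch_size, feature_dim, dim, H, W]

w = get_feature_weight(z_t, feature_extractor)
# Contracts over the channel dim only; a full Jacobian-vector product of the
# summed features would also sum over h and w ('bdchw,bchw->bd').
w_target = torch.einsum('bdchw,bchw->bdhw', w, target)
w_pred = torch.einsum('bdchw,bchw->bdhw', w, pred)
# Compare the Jacobian-weighted tensors, not the raw target and pred.
loss = F.mse_loss(w_pred, w_target)
```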
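A possible cheaper variant, sketched under the assumption that the goal is the MSE between the Jacobian-projected target and prediction, with the feature extractor's Jacobian J evaluated at z_t and held fixed. Because J is linear, J(target) - J(pred) = J(target - pred), so one Jacobian-vector product is enough and the [batch_size, feature_dim, dim, H, W] tensor never has to be built. By contrast, `torch.autograd.functional.jacobian` by default runs one backward pass per output element (feature_dim passes above) unless `vectorize=True`. This variant keeps the feature map's spatial dimensions instead of summing them, so it is not necessarily the exact form of Equation 4 in the paper. `linearized_feature_loss` and the toy convolutional extractor are hypothetical stand-ins, and the extractor is frozen so that only `pred` receives gradients.

```python
import torch
import torch.nn as nn


def linearized_feature_loss(feature_extractor, z_t, target, pred):
    """Hypothetical helper: mean((J @ (target - pred))**2), where J is the
    Jacobian of feature_extractor at z_t (constant if the extractor is frozen)."""
    z = z_t.detach().requires_grad_(True)           # evaluation point; no gradient flows back to z_t
    feat = feature_extractor(z)                     # [B, D, H', W']
    u = torch.zeros_like(feat, requires_grad=True)  # dummy cotangent for the double-backward trick
    # g = J^T u is linear in u, so differentiating <g, v> w.r.t. u yields J v.
    (g,) = torch.autograd.grad(feat, z, grad_outputs=u, create_graph=True)
    v = target - pred                               # gradients reach pred through v
    (jv,) = torch.autograd.grad(g, u, grad_outputs=v, create_graph=True)  # [B, D, H', W']
    return jv.pow(2).mean()


# Toy usage; this extractor is a hypothetical stand-in for a real feature network.
torch.manual_seed(0)
feature_extractor = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, 3, stride=2, padding=1),
).requires_grad_(False)  # frozen: only pred should receive gradients

z_t = torch.randn(2, 3, 8, 8)
target = torch.randn(2, 3, 8, 8)
pred = torch.randn(2, 3, 8, 8, requires_grad=True)  # stands in for the model's velocity output

loss = linearized_feature_loss(feature_extractor, z_t, target, pred)
loss.backward()
```

The double-backward trick here is the same one `torch.autograd.functional.jvp` relies on; forward-mode `torch.func.jvp` is another route to the same product. If the extractor's parameters still required grad, they would also receive gradients through J.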