jhoffman / cycada_release

Code to accompany ICML 2018 paper
BSD 2-Clause "Simplified" License

Should we train feature- and pixel-level adaptation alternately? #22

Closed: Luodian closed this issue 4 years ago

Luodian commented 5 years ago

I noticed that in your code, setting discrim_feat = True seems to train feature-level adaptation. Otherwise, is the adaptation done at the pixel level?

I am not sure I understand this procedure correctly. If I only run the default script train_fcn_adda.sh, do I get only a pixel-level trained model? And do I then need to train again, loading the former model into the pixel-level setting?

Also, 'drn26' does not appear to be implemented correctly for the discrim_feat = True case.

Maybe I have not fully understood the whole procedure, but I am really confused.

# extract features
if discrim_feat:
    score_s, feat_s = net_src(im_s)
    score_s = Variable(score_s.data, requires_grad=False)
    f_s = Variable(feat_s.data, requires_grad=False)
else:
    score_s = Variable(net_src(im_s).data, requires_grad=False)
    f_s = score_s
dis_score_s = discriminator(f_s)

if discrim_feat:
    score_t, feat_t = net(im_t)
    score_t = Variable(score_t.data, requires_grad=False)
    f_t = Variable(feat_t.data, requires_grad=False)
else:
    score_t = Variable(net(im_t).data, requires_grad=False)
    f_t = score_t
dis_score_t = discriminator(f_t)

dis_pred_concat = torch.cat((dis_score_s, dis_score_t))

# prepare real and fake labels
batch_t, _, h, w = dis_score_t.size()
batch_s, _, _, _ = dis_score_s.size()
dis_label_concat = make_variable(
    torch.cat(
        [torch.ones(batch_s, h, w).long(),
         torch.zeros(batch_t, h, w).long()]
    ), requires_grad=False)

# compute loss for discriminator
loss_dis = supervised_loss(dis_pred_concat, dis_label_concat)
(lambda_d * loss_dis).backward()
losses_dis.append(loss_dis.item())

# optimize discriminator
opt_dis.step()

Luodian commented 5 years ago

Or do we only need to train feature-level adaptation, since the stylized source images already contain the pixel-level information?

jhoffman commented 4 years ago

discrim_feat is a parameter given to the feature adaptation code. It determines whether the discriminator is fed the penultimate layer activations (discrim_feat=True) or the logits (discrim_feat=False).
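
As a rough illustration of that branch, here is a toy sketch in plain Python. It is not the repository's code: `ToyNet`, its `return_features` argument, and `discriminator_input` are hypothetical names, and lists of floats stand in for tensors. It only shows that the flag changes what the discriminator receives, assuming the network returns a (score, feat) pair when features are requested, as in the snippet quoted in the question.

```python
class ToyNet:
    """Hypothetical stand-in for a segmentation network (not the real model)."""

    def __init__(self, return_features):
        # Assumption: when True, the network returns (logits, features).
        self.return_features = return_features

    def __call__(self, image):
        feat = [2.0 * x for x in image]   # pretend penultimate-layer activations
        score = [x + 1.0 for x in feat]   # pretend logits computed from them
        return (score, feat) if self.return_features else score


def discriminator_input(net, image, discrim_feat):
    """Select what the discriminator sees, mirroring the quoted if/else branch."""
    if discrim_feat:
        _score, feat = net(image)  # the network must return a (score, feat) pair
        return feat                # penultimate-layer activations
    return net(image)              # logits only


image = [0.5, 1.0]
feat_input = discriminator_input(ToyNet(True), image, discrim_feat=True)
logit_input = discriminator_input(ToyNet(False), image, discrim_feat=False)
print(feat_input)   # activations go to the discriminator
print(logit_input)  # logits go to the discriminator
```

In either setting the same discriminator update from the quoted snippet applies; only its input differs.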