imlixinyang / HiSD

Code for "Image-to-image Translation via Hierarchical Style Disentanglement" (CVPR 2021 Oral).

How many tags can this project train at the same time? #27

Open datar001 opened 3 years ago

datar001 commented 3 years ago

Hi, thanks for sharing your work. How many tags have you tried to train? What is the relation between the number of tags and the number of training iterations? And how many tags would you recommend training at once?

imlixinyang commented 3 years ago

I've successfully trained 6 tags at the same time. In my experiments, I found 50k iterations per tag to be enough (i.e., 20k for 6 tags). HiSD supports any number of tags, but you should increase the training iterations and the model capacity accordingly. Using gradient accumulation and training all tags within one iteration is also important (so you need to change the code a little).

datar001 commented 3 years ago

Thanks for your reply. Is my understanding of "gradient accumulation and all tags in one iteration" correct? [screenshots attached] Also, is '20k for 6 tags' a typo? The official repo uses 200k for 3 tags with 7 attributes. And is performance better when training fewer tags?

imlixinyang commented 3 years ago

Sorry for the typo, it should be 200k for 3 tags with 7 attributes. You've understood the gradient accumulation correctly, and you can write the update code like:

    def update(self, x, y, i, j, j_trg, iterations):

        this_model = self.models.module if self.multi_gpus else self.models

        # gen: freeze the discriminator so only generator gradients accumulate
        for p in this_model.dis.parameters():
            p.requires_grad = False
        for p in this_model.gen.parameters():
            p.requires_grad = True

        # generator losses for this tag (in HiSD the backward pass runs inside
        # the model's forward, so these gradients simply accumulate across calls)
        self.loss_gen_adv, self.loss_gen_sty, self.loss_gen_rec, \
        x_trg, x_cyc, s, s_trg = self.models((x, y, i, j, j_trg), mode='gen')

        self.loss_gen_adv = self.loss_gen_adv.mean()
        self.loss_gen_sty = self.loss_gen_sty.mean()
        self.loss_gen_rec = self.loss_gen_rec.mean()

        # dis: freeze the generator so only discriminator gradients accumulate
        for p in this_model.dis.parameters():
            p.requires_grad = True
        for p in this_model.gen.parameters():
            p.requires_grad = False

        self.loss_dis_adv = self.models((x, x_trg, x_cyc, s, s_trg, y, i, j, j_trg), mode='dis')
        self.loss_dis_adv = self.loss_dis_adv.mean()

        # step and clear the accumulated gradients only once every tag_num
        # iterations, i.e. after every tag has contributed to this update
        if (iterations + 1) % self.tag_num == 0:
            nn.utils.clip_grad_norm_(this_model.gen.parameters(), 100)
            nn.utils.clip_grad_norm_(this_model.dis.parameters(), 100)
            self.gen_opt.step()
            self.dis_opt.step()
            self.gen_opt.zero_grad()
            self.dis_opt.zero_grad()

            # refresh the moving-average generator used at test time
            update_average(this_model.gen_test, this_model.gen)

        return self.loss_gen_adv.item(), \
               self.loss_gen_sty.item(), \
               self.loss_gen_rec.item(), \
               self.loss_dis_adv.item()
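
The outer training loop then just cycles through the tags, so that each optimizer step has seen one accumulated gradient per tag. A minimal sketch (the `loaders` list, the batch unpacking, and the loop structure are illustrative, not the exact code in this repo):

    def train_loop(trainer, loaders, tag_num, total_iterations):
        # Cycle deterministically through the tags; trainer.update() only steps
        # the optimizers when (iterations + 1) % tag_num == 0, so every step
        # uses gradients accumulated from all tags.
        for iterations in range(total_iterations):
            i = iterations % tag_num                 # tag handled in this inner step
            x, y, j, j_trg = next(loaders[i])        # batch and labels for tag i (illustrative unpacking)
            trainer.update(x, y, i, j, j_trg, iterations)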

You also need to decrease the learning rate (maybe lr / tag_num), since the accumulated gradient is a sum over the tags rather than an average.
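
Concretely, that could mean building the optimizers with the scaled rate (or, equivalently, dividing each loss by tag_num before its backward pass). A minimal sketch, with illustrative names:

    import torch

    def build_optimizers(gen, dis, base_lr, tag_num):
        # Scale the learning rate by 1/tag_num so that gradients summed over
        # tag_num accumulation steps behave like an average.
        gen_opt = torch.optim.Adam(gen.parameters(), lr=base_lr / tag_num)
        dis_opt = torch.optim.Adam(dis.parameters(), lr=base_lr / tag_num)
        return gen_opt, dis_opt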