Open datar001 opened 3 years ago
I've succeeded in training 6 tags at the same time. In my experiments, I found 50k iterations per tag is enough (i.e., 20k for 6 tags). HiSD supports various numbers of tags, but you should increase the number of training iterations and the model capacity. Using gradient accumulation and training all tags in one iteration is also important (so you need to change the code a little).
Thanks for your reply. Am I right about "gradient accumulation and all tags in one iteration"? And is '20k for 6 tags' a typo? The official repo uses 200k for 3 tags with 7 attributes. Also, is the performance better when we train fewer tags?
Sorry for the typo, it should be 200k for 3 tags with 7 attributes. You have the right idea about gradient accumulation, and you can change the update code like this:
def update(self, x, y, i, j, j_trg, iterations):
    this_model = self.models.module if self.multi_gpus else self.models

    # Generator pass: freeze the discriminator, train the generator.
    for p in this_model.dis.parameters():
        p.requires_grad = False
    for p in this_model.gen.parameters():
        p.requires_grad = True

    self.loss_gen_adv, self.loss_gen_sty, self.loss_gen_rec, \
        x_trg, x_cyc, s, s_trg = self.models((x, y, i, j, j_trg), mode='gen')

    self.loss_gen_adv = self.loss_gen_adv.mean()
    self.loss_gen_sty = self.loss_gen_sty.mean()
    self.loss_gen_rec = self.loss_gen_rec.mean()

    # Discriminator pass: freeze the generator, train the discriminator.
    for p in this_model.dis.parameters():
        p.requires_grad = True
    for p in this_model.gen.parameters():
        p.requires_grad = False

    self.loss_dis_adv = self.models((x, x_trg, x_cyc, s, s_trg, y, i, j, j_trg), mode='dis')
    self.loss_dis_adv = self.loss_dis_adv.mean()

    # Step the optimizers only once every tag_num iterations,
    # so that gradients from all tags are accumulated first.
    if (iterations + 1) % self.tag_num == 0:
        nn.utils.clip_grad_norm_(this_model.gen.parameters(), 100)
        nn.utils.clip_grad_norm_(this_model.dis.parameters(), 100)
        self.gen_opt.step()
        self.dis_opt.step()
        self.gen_opt.zero_grad()
        self.dis_opt.zero_grad()

        update_average(this_model.gen_test, this_model.gen)

    return self.loss_gen_adv.item(), \
        self.loss_gen_sty.item(), \
        self.loss_gen_rec.item(), \
        self.loss_dis_adv.item()
You also need to decrease the learning rate before backward (maybe lr/tag_num), since the accumulated gradient is a 'sum' rather than an 'average'.
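To make the 'sum' versus 'average' point concrete, here is a minimal toy sketch in plain Python. It is not HiSD code: the scalar weight, the quadratic per-tag loss, the `targets` values, the round-robin tag order, and plain SGD are all assumptions chosen for illustration. It only shows that, under those assumptions, summing gradients over `tag_num` iterations with `lr / tag_num` takes the same steps as averaging them with `lr`.

```python
# Toy illustration (not HiSD code): one scalar weight, plain SGD,
# and a hypothetical per-tag loss 0.5 * (w - target)**2.

def grad(w, target):
    """Gradient of 0.5 * (w - target)**2 with respect to w."""
    return w - target


def train(targets, lr, steps, average):
    """Accumulate one gradient per tag, then take a single SGD step."""
    tag_num = len(targets)
    w = 0.0
    accumulated = 0.0
    for iteration in range(steps * tag_num):
        tag = iteration % tag_num  # assumed round-robin over tags
        g = grad(w, targets[tag])
        # Either average the per-tag gradients or simply sum them.
        accumulated += g / tag_num if average else g
        if (iteration + 1) % tag_num == 0:  # step once every tag_num iterations
            w -= lr * accumulated
            accumulated = 0.0
    return w


targets = [1.0, 2.0, 6.0]  # hypothetical values, one per tag
lr = 0.1

w_avg = train(targets, lr, steps=50, average=True)
w_sum = train(targets, lr / len(targets), steps=50, average=False)
w_big = train(targets, lr, steps=50, average=False)  # summed, lr unchanged

print(w_avg, w_sum, w_big)
```

In this toy, `w_avg` and `w_sum` agree up to floating-point rounding, while `w_big` moves with an effective step `tag_num` times larger. The equivalence holds for plain SGD; with an adaptive optimizer such as Adam, rescaling the gradient and rescaling the learning rate are not interchangeable, so lr/tag_num is probably best treated as a starting point to tune.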
Hi, thanks for sharing. How many tags have you tried to train? What's the relationship between the number of tags and the number of training iterations? And how many tags would you recommend training at once?