hytseng0509 / CrossDomainFewShot

Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation (ICLR 2020 spotlight)

More detailed explanation of "create_graph=True" and "weight.fast" #40

Open · fikry102 opened this issue 2 years ago

fikry102 commented 2 years ago

In LFTNet.py ("update model parameters according to model_loss"):

```python
meta_grad = torch.autograd.grad(model_loss, self.split_model_parameters()[0], create_graph=True)
for k, weight in enumerate(self.split_model_parameters()[0]):
    weight.fast = weight - self.model_optim.param_groups[0]['lr'] * meta_grad[k]
meta_grad = [g.detach() for g in meta_grad]
```

What is the purpose of adding `create_graph=True`? Why is `weight.fast` updated rather than `weight`? Does this have anything to do with `ft_loss.backward()`?

Could you please give me a more detailed explanation? Thanks!
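A minimal PyTorch sketch of the general mechanism the snippet relies on, under a toy setup that is an assumption and not the code in LFTNet.py: `FastLinear`, `ft_scale`, `meta_step` and the random batches are hypothetical. `ft_scale` stands in for a feature-wise transformation that is active only in the inner loss. With `create_graph=True` the inner gradient `meta_grad` is itself part of the autograd graph, so a later `ft_loss.backward()` can reach `ft_scale` through `weight.fast`. Storing the update in `weight.fast` instead of overwriting `weight` leaves the original leaf parameter intact (an in-place update of a leaf that requires grad would raise an error outside `torch.no_grad()`), so the step is "virtual".

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FastLinear(nn.Linear):
    """Hypothetical layer: uses `weight.fast` / `bias.fast` in forward when they are set."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        # Plain Python attributes attached to the Parameters (not registered parameters).
        self.weight.fast = None
        self.bias.fast = None

    def forward(self, x):
        if self.weight.fast is not None:
            return F.linear(x, self.weight.fast, self.bias.fast)
        return super().forward(x)


def meta_step(create_graph):
    torch.manual_seed(0)
    layer = FastLinear(4, 3)
    ft_scale = nn.Parameter(torch.ones(4))  # hypothetical stand-in for an FT layer
    x_seen, y_seen = torch.randn(8, 4), torch.randint(0, 3, (8,))
    x_unseen, y_unseen = torch.randn(8, 4), torch.randint(0, 3, (8,))
    lr = 0.1
    params = [layer.weight, layer.bias]
    w_before = layer.weight.detach().clone()

    # Inner loss on the "seen" batch, with the FT-like scaling applied to the features.
    model_loss = F.cross_entropy(layer(x_seen * ft_scale), y_seen)
    meta_grad = torch.autograd.grad(model_loss, params, create_graph=create_graph)
    for k, weight in enumerate(params):
        # Virtual SGD step stored on the side; `weight` itself is not modified.
        weight.fast = weight - lr * meta_grad[k]
    # Same values, but cut from the graph (e.g. usable as ordinary first-order gradients).
    meta_grad = [g.detach() for g in meta_grad]

    # Outer loss on the "unseen" batch: fast weights, no FT scaling.
    ft_loss = F.cross_entropy(layer(x_unseen), y_unseen)
    ft_loss.backward()

    weight_unchanged = torch.equal(layer.weight.detach(), w_before)
    return weight_unchanged, ft_scale.grad


for flag in (True, False):
    unchanged, g = meta_step(create_graph=flag)
    print(f"create_graph={flag}: weight unchanged={unchanged}, ft_scale.grad is None={g is None}")
```

In this toy setup the weight stays unchanged in both runs, while `ft_scale.grad` is populated only with `create_graph=True`. With `create_graph=False` the inner gradient is treated as a constant, so the outer loss has no path back to the FT-like parameter and only reaches `weight` and `bias` directly.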
fikry102 commented 2 years ago

I have created a QQ group (group number 693337454). Anyone interested in these questions is welcome to join us for further discussion.