In LFTNet.py, in the block commented "update model parameters according to model_loss":
    meta_grad = torch.autograd.grad(model_loss, self.split_model_parameters()[0], create_graph=True)
    for k, weight in enumerate(self.split_model_parameters()[0]):
        weight.fast = weight - self.model_optim.param_groups[0]['lr'] * meta_grad[k]
    meta_grad = [g.detach() for g in meta_grad]
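A toy sketch of what create_graph=True changes in the snippet above. It assumes a single scalar weight w = 2.0, a learning rate of 0.1, and made-up quadratic losses standing in for model_loss and ft_loss; none of these come from LFTNet.py. With create_graph=True, the inner gradient g = d(model_loss)/dw is itself recorded in the autograd graph, so a later backward pass through the updated weight w - lr*g also differentiates g (a second-order term). Analytically, d(fast)/dw is 1 - 2*lr = 0.8 with the graph and 1 without it, so the outer gradients come out to about 2.56 and 3.2 respectively.

```python
import torch

# Hypothetical toy setup (not from LFTNet.py): one scalar weight,
# an inner loss standing in for model_loss, an outer loss standing in for ft_loss.
lr = 0.1


def outer_grad(create_graph: bool) -> float:
    w = torch.tensor(2.0, requires_grad=True)
    model_loss = w ** 2  # inner loss, d/dw = 2w
    # With create_graph=True, g keeps a grad_fn, so it can be differentiated again.
    (g,) = torch.autograd.grad(model_loss, w, create_graph=create_graph)
    fast = w - lr * g  # one SGD step, stored separately from w
    ft_loss = fast ** 2  # outer loss evaluated at the updated ("fast") weight
    ft_loss.backward()  # gradient w.r.t. the ORIGINAL w lands in w.grad
    return w.grad.item()


with_graph = outer_grad(create_graph=True)  # includes d(g)/dw: 2 * 1.6 * 0.8
without_graph = outer_grad(create_graph=False)  # g treated as a constant: 2 * 1.6 * 1
print(with_graph, without_graph)
```

Without create_graph=True, the inner gradient behaves like a constant, which amounts to a first-order approximation of the meta-gradient.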
What's the purpose of adding "create_graph=True"? Why is weight.fast updated rather than weight itself?
Does this have anything to do with "ft_loss.backward()"?
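A sketch of the general "fast weight" pattern, assuming a hypothetical FastLinear layer whose forward() uses weight.fast and bias.fast when they are set. MAML-style code often follows a similar idea, but this class, the SGD optimizer, and the random data are illustrations, not LFTNet's actual modules. Because the update is stored in p.fast rather than written into p, the real parameters stay unchanged, while a backward pass on the outer loss (the role ft_loss plays here) still propagates through p.fast, and, thanks to create_graph=True, through the inner gradient, into p.grad.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FastLinear(nn.Linear):
    """Hypothetical layer: if weight.fast / bias.fast are set, forward() uses them
    instead of the real parameters, so the parameters are never overwritten."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features)
        self.weight.fast = None
        self.bias.fast = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.weight.fast is not None and self.bias.fast is not None:
            return F.linear(x, self.weight.fast, self.bias.fast)
        return super().forward(x)


torch.manual_seed(0)
layer = FastLinear(3, 1)
optim = torch.optim.SGD(layer.parameters(), lr=0.1)  # hypothetical outer optimizer
x_inner, x_outer = torch.randn(4, 3), torch.randn(4, 3)  # made-up data
weight_before = layer.weight.detach().clone()

# Inner step (analogous to model_loss): compute gradients but keep their graph.
params = [layer.weight, layer.bias]
model_loss = layer(x_inner).pow(2).mean()
grads = torch.autograd.grad(model_loss, params, create_graph=True)
for p, g in zip(params, grads):
    p.fast = p - optim.param_groups[0]["lr"] * g  # differentiable "virtual" update

# Outer step (analogous to ft_loss): the forward pass uses the fast weights...
ft_loss = layer(x_outer).pow(2).mean()
optim.zero_grad()
ft_loss.backward()  # ...but the gradients accumulate on the original parameters

weights_untouched = torch.equal(layer.weight.detach(), weight_before)
has_grad = layer.weight.grad is not None
print(weights_untouched, has_grad)
# A subsequent optim.step() would update layer.weight using this gradient.
```

Overwriting p in place (for example through p.data) would instead discard the pre-update values and remove the dependence of the outer loss on the inner gradient step.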
Could you please give me a more detailed explanation? Thanks!