yurunsheng1 opened this issue 5 years ago
Hi, first of all, thank you for providing us such nice work!
But I have a question and really need your help:
In your MeLU.py lines 71-79:
```python
grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
# local update
for i in range(self.weight_len):
    if self.weight_name[i] in self.local_update_target_weight_name:
        self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
    else:
        self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
self.model.load_state_dict(self.fast_weights)
query_set_y_pred = self.model(query_set_x)
```
I understand this is the standard MAML approach (inner loop).
However, load_state_dict() erases (breaks) the gradient graph (https://discuss.pytorch.org/t/loading-a-state-dict-seems-to-erase-grad/56676), so the global update will no longer take the local-update gradient into account in the final optimization. Hence create_graph=True may have no effect, and the algorithm may no longer be standard MAML. I am wondering whether I am missing some insight behind this.
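In case it helps to illustrate the point, here is a tiny standalone demo (my own toy example, not code from this repo) showing that the graph built through the fast weights is dropped once they are copied back into the module:

```python
import torch
import torch.nn as nn

# Toy model and a dummy loss, just to obtain grads with create_graph=True.
model = nn.Linear(2, 1)
x = torch.randn(4, 2)
loss = model(x).sum()
grad = torch.autograd.grad(loss, model.parameters(), create_graph=True)

# Fast weights computed from the grads carry a graph back to the parameters.
fast = {
    name: p.detach() - 0.1 * g
    for (name, p), g in zip(model.named_parameters(), grad)
}
print(fast["weight"].grad_fn)  # <SubBackward0 ...>: graph is alive here

# load_state_dict copies only the *values* into the existing leaf Parameters
# (under no_grad), so the graph attached to `fast` is discarded.
model.load_state_dict(fast)
print(model.weight.grad_fn)    # None: the inner-update graph is gone
```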
Looking forward to your reply!
I believe you are right and the original code is wrong.
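For reference, here is a minimal sketch of one way to keep the inner-loop graph alive: run the query forward pass functionally with the fast weights instead of copying them into the module. The names support_x, support_y, query_x, local_lr, and loss_fn are placeholders mirroring the issue's snippet, and torch.func.functional_call requires PyTorch >= 2.0, which may be newer than what this repo targets:

```python
import torch
from torch.func import functional_call  # PyTorch >= 2.0

def inner_update_then_query(model, support_x, support_y, query_x, local_lr, loss_fn):
    # Differentiable functional forward pass on the support set.
    params = dict(model.named_parameters())
    support_pred = functional_call(model, params, (support_x,))
    loss = loss_fn(support_pred, support_y)
    grads = torch.autograd.grad(loss, params.values(), create_graph=True)

    # Fast weights remain graph nodes; nothing is copied into the module.
    fast_weights = {
        name: p - local_lr * g
        for (name, p), g in zip(params.items(), grads)
    }

    # The query forward uses the fast weights directly, so backprop through
    # the query loss reaches the original parameters via the inner update.
    return functional_call(model, fast_weights, (query_x,))
```

Older MAML codebases achieve the same effect by writing the forward pass out manually (e.g. with F.linear) on the fast-weight tensors rather than through the module's own parameters.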