In contrastive_learning_based_MAML.py, meta_update_model() takes 4 positional arguments but 5 were given. How to update both the model and the head?

lichong952012 opened 11 months ago

Hi, while trying to reproduce this repo I hit this error: the call site passes five arguments, but the function signature only accepted four. I rewrote meta_update_model() so that it also takes the projection head, as follows:
```python
def meta_update_model(projection_head, model, optimizer, loss, gradients):
    hooks = []
    for (k, v) in projection_head.named_parameters():
        def get_closure():
            key = k
            def replace_grad(grad):
                return gradients[key]
            return replace_grad
        if v.requires_grad:
            hooks.append(v.register_hook(get_closure()))
    for (k, v) in model.named_parameters():
        def get_closure():
            key = k
            def replace_grad(grad):
                return gradients[key]
            return replace_grad
        if v.requires_grad:
            hooks.append(v.register_hook(get_closure()))
    # Compute grads for the current step, replaced with the summed gradients as defined by the hooks
    optimizer.zero_grad()
    loss.backward()
    # Update the net parameters with the accumulated gradients according to the optimizer
    optimizer.step()
    # Remove the hooks before the next training phase
    for h in hooks:
        h.remove()
```
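A couple of notes on why it is written this way: the get_closure() indirection exists because key = k is evaluated when get_closure() is called inside the loop, so each replace_grad hook captures its own parameter name rather than the final value of the loop variable k. Tensor.register_hook() then lets the hook substitute the stored gradient for whatever loss.backward() would have computed, and optimizer.step() applies the substituted gradients. One caveat: both loops index the same gradients dict by bare parameter name, so this assumes the parameter names of projection_head and model do not overlap.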
The code above has addressed my issue. Hope this works for you, too.
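For anyone hitting the same TypeError, here is a minimal, self-contained sketch of how the five-argument call might fit together with the meta_update_model() above. The stand-in networks, sizes, optimizer, and the two-task accumulation loop are my own placeholders for illustration, not code from the repo:

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Stand-in networks; the real ones live in contrastive_learning_based_MAML.py.
# Distinct submodule names ("backbone" vs "proj") keep the two modules'
# parameter keys from colliding in the shared `gradients` dict.
model = nn.Sequential(OrderedDict(backbone=nn.Linear(8, 4)))
projection_head = nn.Sequential(OrderedDict(proj=nn.Linear(4, 2)))
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(projection_head.parameters()), lr=1e-3)

# Accumulate per-task gradients into a dict keyed by parameter name, the same
# keys the hooks inside meta_update_model() use for their lookup.
named = [(k, v) for k, v in
         list(projection_head.named_parameters()) + list(model.named_parameters())
         if v.requires_grad]
gradients = {k: torch.zeros_like(v) for k, v in named}
params = [v for _, v in named]
for _ in range(2):  # two stand-in "tasks"
    task_loss = projection_head(model(torch.randn(16, 8))).pow(2).mean()
    for (k, _), g in zip(named, torch.autograd.grad(task_loss, params)):
        gradients[k] = gradients[k] + g

# Any loss that touches all parameters works here: its backward() only serves
# to trigger the hooks, which swap in the accumulated gradients before step().
loss = projection_head(model(torch.randn(16, 8))).pow(2).mean()
meta_update_model(projection_head, model, optimizer, loss, gradients)
```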