ynanwu / MetaGCD


The function meta_update_model #4

Open lichong952012 opened 8 months ago

lichong952012 commented 8 months ago

In contrastive_learning_based_MAML.py, calling meta_update_model() fails with "takes 4 positional arguments but 5 were given". How should the model and the projection head be updated?

mashijie1028 commented 5 months ago

Hi, I tried to reproduce this repo and rewrote the function meta_update_model() as follows:

import itertools

def meta_update_model(projection_head, model, optimizer, loss, gradients):
    # Register a hook on every trainable parameter of both modules that
    # replaces its autograd gradient with the pre-computed meta-gradient.
    hooks = []
    for (k, v) in itertools.chain(projection_head.named_parameters(),
                                  model.named_parameters()):
        def get_closure():
            # Bind the current parameter name now, so each hook looks up
            # its own entry in `gradients` rather than the last loop value.
            key = k

            def replace_grad(grad):
                return gradients[key]

            return replace_grad

        if v.requires_grad:
            hooks.append(v.register_hook(get_closure()))

    # Compute grads for current step; the hooks replace them with the
    # summed meta-gradients defined above
    optimizer.zero_grad()

    loss.backward()

    # Update the net parameters with the replaced gradients according to optimizer
    optimizer.step()

    # Remove the hooks before the next training phase
    for h in hooks:
        h.remove()
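A side note on why the nested get_closure() is needed: closures in Python capture variables, not values, so a plain lambda registered inside the loop would see only the last parameter name once loss.backward() fires the hooks. A minimal standalone sketch of this pattern (my own toy example, not code from the repo):

```python
gradients = {"w": 1, "b": 2}

# Naive closures all capture the *variable* k, so after the loop every
# callback resolves to the last name.
naive = [lambda: gradients[k] for k in gradients]

# The get_closure() pattern binds the current name immediately, giving
# each callback its own key.
def get_closure(k):
    key = k

    def replace_grad():
        return gradients[key]

    return replace_grad

bound = [get_closure(k) for k in gradients]

print([f() for f in naive])   # [2, 2] -- both resolve the last key "b"
print([f() for f in bound])   # [1, 2] -- each resolves its own key
```

This is exactly why meta_update_model() calls get_closure() once per parameter instead of passing a lambda straight to register_hook().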

The code above resolved the issue for me. Hope it works for you, too.