ynanwu / MetaGCD


The function meta_update_model #4

Open lichong952012 opened 11 months ago

lichong952012 commented 11 months ago

In contrastive_learning_based_MAML.py, the call fails with "meta_update_model() takes 4 positional arguments but 5 were given". How should the model and the projection head be updated?

mashijie1028 commented 8 months ago

Hi, I tried to reproduce this repo and rewrote the function meta_update_model() as follows:

def meta_update_model(projection_head, model, optimizer, loss, gradients):
    hooks = []
    for (k, v) in projection_head.named_parameters():
        def get_closure():
            key = k

            def replace_grad(grad):
                return gradients[key]

            return replace_grad

        if v.requires_grad:
            # Register a hook that swaps in the accumulated gradient for this parameter
            hooks.append(v.register_hook(get_closure()))

    for (k, v) in model.named_parameters():
        def get_closure():
            key = k

            def replace_grad(grad):
                return gradients[key]

            return replace_grad

        if v.requires_grad:
            # Register a hook that swaps in the accumulated gradient for this parameter
            hooks.append(v.register_hook(get_closure()))

    # Compute grads for current step, replace with summed gradients as defined by hook
    optimizer.zero_grad()
    loss.backward()

    # Update the net parameters with the accumulated gradient according to optimizer
    optimizer.step()

    # Remove the hooks before next training phase
    for h in hooks:
        h.remove()

The code above has addressed my issue. Hope this works for you, too.
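
A side note on why the function wraps each hook in get_closure(): Python closures look up loop variables when they are called, not when they are defined, so a hook defined directly in the loop would read whatever `k` holds at call time. The sketch below shows the difference in plain Python, without PyTorch. The parameter names and the string "gradients" are hypothetical stand-ins for illustration only and are not taken from the repo.

```python
# Hypothetical parameter names and placeholder "gradients" (illustration only).
names = ["layer1.weight", "layer1.bias", "head.weight"]
gradients = {name: f"summed grad for {name}" for name in names}

# Naive version: each closure reads the loop variable `k` when it is called,
# so once the loop has finished, every hook sees the last name.
naive_hooks = []
for k in names:
    def replace_grad(grad):
        return gradients[k]
    naive_hooks.append(replace_grad)

# Factory version, following the pattern in meta_update_model(): `key` is
# bound once per call to get_closure(), so each hook keeps its own name.
bound_hooks = []
for k in names:
    def get_closure():
        key = k

        def replace_grad(grad):
            return gradients[key]

        return replace_grad
    bound_hooks.append(get_closure())

# Call the hooks after both loops, the way backward() would call them later.
naive_results = [hook(None) for hook in naive_hooks]
bound_results = [hook(None) for hook in bound_hooks]

print(naive_results)
print(bound_results)
```

In the naive list every hook returns the entry for the last name, while the factory version returns one entry per name. This is the likely reason the hook is built through get_closure() rather than defined inline, since the hooks only run later, during loss.backward().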