Renovamen / metallic

A clean, lightweight and modularized PyTorch meta-learning library.
https://metallic-docs.vercel.app
MIT License

MAML or FOMAML implementation #1

Open FYYFU opened 3 years ago

FYYFU commented 3 years ago

In maml.py you compute the gradients manually and map them from the higher fast model back to the original model, but you only apply grad_list[-1] to the original model. Does this work well?

                if meta_train == True:
                    # query_loss.backward()
                    outer_grad = torch.autograd.grad(query_loss / n_tasks, fmodel.parameters(time=0))
                    grad_list.append(outer_grad)

        # When in the meta-training stage, update the model's meta-parameters to
        # optimize the query losses across all of the tasks sampled in this batch.
        if meta_train == True:
            # apply gradients to the original model parameters
            apply_grads(self.model, grad_list[-1])  # <---- about this line
            # outer loop update

The same thing also happens in reptile.py.
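
(A minimal sketch of what applying all of the collected gradients, rather than only grad_list[-1], could look like; the helper below is illustrative, not metallic's actual API.)

import torch

def sum_task_grads(grad_list):
    # element-wise sum of the per-task gradient tuples collected above;
    # each entry was already divided by n_tasks, so the sum is the task mean
    total = [g.clone() for g in grad_list[0]]
    for task_grads in grad_list[1:]:
        for acc, g in zip(total, task_grads):
            acc.add_(g)
    return total

# outer loop update with the accumulated gradients, e.g.:
#     apply_grads(self.model, sum_task_grads(grad_list))
#     meta_optimizer.step()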

Renovamen commented 3 years ago

Hi! I think you are right, my bad, I'll fix it later. Thank you for pointing it out!

FYYFU commented 3 years ago

> Hi! I think you are right, my bad, I'll fix it later. Thank you for pointing it out!

Thanks for your nice work! In the higher issue (Link), he dropped some code from higher to implement FOMAML. I wonder whether that is different from manually mapping the gradients to the original model? (I followed the issue's code setting and got a different result from your code.) I think your code uses the right setting, but the results are a little different from what I expected.

His change in higher.optim.py:

new_params = params[:]
for group, mapping in zip(self.param_groups, self._group_to_param_list):
    for p, index in zip(group['params'], mapping):
        if self._track_higher_grads:
            new_params[index] = p
        else:
            new_params[index] = p.detach().requires_grad_()

to

new_params = params[:]
for group, mapping in zip(self.param_groups, self._group_to_param_list):
    for p, index in zip(group['params'], mapping):
        new_params[index] = p

Renovamen commented 3 years ago

Oh, that issue was also written by me 😂. In my experiments, there appears to be no difference between these two approaches. Editing the code in higher.optim.py allows gradients to flow back to the original model, which, in my opinion, is equal to manually mapping the gradients to their corresponding parameters of the original model. I don't want to edit the source code of higher, so I chose to map the gradients manually in this project.

Could you please share your experiment code and results with me?
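
(For reference, a minimal toy sketch of the "manually map the gradients" route that needs no edits to higher; the model, data and hyperparameters are made up, and it takes the first-order gradients at the adapted weights rather than at time=0 as the maml.py snippet above does, which is one common FOMAML pattern with higher.)

import torch
import torch.nn.functional as F
import higher

model = torch.nn.Linear(4, 1)                          # toy meta-model
inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
meta_opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x_s, y_s = torch.randn(8, 4), torch.randn(8, 1)        # toy support set
x_q, y_q = torch.randn(8, 4), torch.randn(8, 1)        # toy query set

with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False,
                          track_higher_grads=False) as (fmodel, diffopt):
    for _ in range(3):                                  # inner loop
        diffopt.step(F.mse_loss(fmodel(x_s), y_s))
    query_loss = F.mse_loss(fmodel(x_q), y_q)

    # first-order: take the query gradients at the adapted weights and
    # map them back onto the original model's parameters by position
    grads = torch.autograd.grad(query_loss, fmodel.parameters())

for p, g in zip(model.parameters(), grads):
    p.grad = g if p.grad is None else p.grad + g
meta_opt.step()                                         # outer loop update
meta_opt.zero_grad()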

FYYFU commented 3 years ago

> Oh, that issue was also written by me 😂. In my experiments, there appears to be no difference between these two approaches. Editing the code in higher.optim.py allows gradients to flow back to the original model, which, in my opinion, is equal to manually mapping the gradients to their corresponding parameters of the original model. I don't want to edit the source code of higher, so I chose to map the gradients manually in this project.
>
> Could you please share your experiment code and results with me?

Sorry... it was my problem. I put grad_list inside the for loop, so only the final gradient was used. (So it may prove that using only the final gradient is not a good idea. :) ) Thanks for your reply, and thanks for your nice work again!

But these two ways do have some differences. If you just modify higher (as you mentioned in the issue), the inner steps can only be set to 1. If set to 2 or bigger, an error will occur. You can have a try :)

Renovamen commented 3 years ago

The accumulated inner-loop gradients (instead of only the final gradients) are used now (3d835ba). Thank you!

> But these two ways do have some differences. If you just modify higher (as you mentioned in the issue), the inner steps can only be set to 1. If set to 2 or bigger, an error will occur. You can have a try :)

I had a try. I am not sure whether modifying higher in that way will cause any other issues, but it seems that at least setting the inner steps to 2 or bigger is okay. Can I see your code?

FYYFU commented 3 years ago

> The accumulated inner-loop gradients (instead of only the final gradients) are used now (3d835ba). Thank you!
>
>> But these two ways do have some differences. If you just modify higher (as you mentioned in the issue), the inner steps can only be set to 1. If set to 2 or bigger, an error will occur. You can have a try :)
>
> I had a try. I am not sure whether modifying higher in that way will cause any other issues, but it seems that at least setting the inner steps to 2 or bigger is okay. Can I see your code?

The error is:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

The pseudo-code is:

for batch_list in dataloader:
    for support_batch, query_batch in batch_list:

        with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False, track_higher_grads=False) as (fast_model, diffopt):
            # inner loop: adapt the fast model on the support set (inner_steps is used here)
            for i in range(inner_steps):
                output = fast_model(support_batch)
                support_loss = loss_function(output)
                diffopt.step(support_loss)

            # outer loop: evaluate the adapted model on the query set and backpropagate
            output = fast_model(query_batch)
            query_loss = loss_function(output)
            query_loss.backward()

The batch_list contains more than one batch.

I only modified lines 252-258 of higher.optim.py. But the variable self._track_higher_grads also appears in line 232. Should line 232 be changed to True when self._track_higher_grads = False? (Just a guess... :) )

Renovamen commented 3 years ago

That's strange, your code seems the same as mine, but my code works well...

And create_graph (self._track_higher_grads) in line 232 is used for computing higher-order derivatives, so it should be set to False when using FOMAML (which uses first-order derivatives only). I don't think it needs to be changed.
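
(A tiny standalone illustration of what create_graph controls, not metallic or higher code: with create_graph=True the returned gradients are themselves differentiable, enabling the higher-order derivatives full MAML needs; with the default False they are not, which is all FOMAML requires.)

import torch

w = torch.tensor(2.0, requires_grad=True)
loss = w ** 3

# first-order gradient; create_graph=True keeps it differentiable
(g,) = torch.autograd.grad(loss, w, create_graph=True)   # g = 3 * w**2 = 12
# the second-order derivative is only possible because create_graph was True above
(g2,) = torch.autograd.grad(g, w)                         # g2 = 6 * w = 12
print(g.item(), g2.item())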

FYYFU commented 3 years ago

> That's strange, your code seems the same as mine, but my code works well...
>
> And create_graph (self._track_higher_grads) in line 232 is used for computing higher-order derivatives, so it should be set to False when using FOMAML (which uses first-order derivatives only). I don't think it needs to be changed.

I also think my code is the same as yours. Maybe there is something wrong in my code and I will check it carefully. If I find the mistake, I will give you feedback. Thanks for your reply :> Wish everything goes well with your work & coding!

Renovamen commented 3 years ago

> I also think my code is the same as yours. Maybe there is something wrong in my code and I will check it carefully. If I find the mistake, I will give you feedback. Thanks for your reply :> Wish everything goes well with your work & coding!

Okay. Same to you!