learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

Parameters of cloned learner are not updating #420

Open xunil17 opened 11 months ago

xunil17 commented 11 months ago

Hello, I am running the toy_example.py and printing out the loss and adapt_loss as shown below.

    for t in range(TASKS_PER_STEP):
        # Sample a task
        task_params = task_dist.sample()
        mu_i, sigma_i = task_params[:DIM], task_params[DIM:]

        # Adaptation: Instantiate a copy of the model
        learner = maml.clone()
        proposal = learner()

        # Adaptation: Compute and adapt to the task loss
        loss = (mu_i - proposal.mean).pow(2).sum() + (sigma_i - proposal.variance).pow(2).sum()
        learner.adapt(loss)
        print(loss)

        # Adaptation: Evaluate the effectiveness of adaptation
        adapt_loss = (mu_i - proposal.mean).pow(2).sum() + (sigma_i - proposal.variance).pow(2).sum()
        print(adapt_loss)

        # Accumulate the error over all tasks
        step_loss += adapt_loss

However, I'm receiving the same adapt_loss and loss at every time step. I also inspected the parameters inside maml, proposal, and learner: proposal and maml have the same parameters, but learner.module._parameters are different. Am I correct in understanding that proposal and maml should have different parameters after learner.adapt is called?

Thank you so much!

ImahnShekhzadeh commented 10 months ago

> However, I'm receiving the same adapt_loss and loss at every time step.

This is weird; you should indeed receive different loss values. Maybe you can replace `proposal` with a fresh call to `learner()` everywhere and see what happens?
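To illustrate why recomputing matters: in the snippet from the issue, `proposal` is built from the pre-adaptation parameters, so reusing that same object after `learner.adapt(loss)` recomputes the identical loss. Here is a minimal pure-Python sketch (no learn2learn or PyTorch; all names are hypothetical stand-ins) of a cached output versus a recomputed one:

```python
# Toy "model": the prediction is just the parameter itself.
def forward(theta):
    return theta

def mse(pred, target):
    return (pred - target) ** 2

target = 3.0
theta = 0.0

proposal = forward(theta)        # snapshot taken BEFORE adaptation
loss = mse(proposal, target)     # 9.0

# One SGD-like adaptation step: d/dtheta of (theta - target)^2 is 2*(theta - target)
lr = 0.1
theta = theta - lr * 2 * (theta - target)   # theta: 0.0 -> 0.6

stale_loss = mse(proposal, target)          # still 9.0: proposal was never recomputed
fresh_loss = mse(forward(theta), target)    # 5.76: output recomputed after the update

print(stale_loss, fresh_loss)
```

The stale value mirrors what the issue reports: identical `loss` and `adapt_loss`, even though the adapted parameters would produce a smaller loss if the output were recomputed.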

> I see that the proposal and maml have the same parameters but the learner.module._parameters are different.

Where exactly are you doing the print statement?

> Am I understanding that the proposal and maml model should have different parameters after learner.adapt is called?

I would expect that too, yes.
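For reference, the clone/adapt contract being discussed can be sketched with a hypothetical stand-in (this is not the learn2learn implementation): adapting a clone should change the clone's parameters while leaving the original meta-parameters untouched, which is why `learner.module._parameters` differ from `maml`'s after `adapt`.

```python
import copy

class TinyModel:
    """Hypothetical stand-in for a meta-learner with clone/adapt semantics."""
    def __init__(self, w):
        self.w = w

    def clone(self):
        # A clone is an independent copy; updating it must not touch the original.
        return copy.deepcopy(self)

    def adapt(self, grad, lr=0.5):
        # One inner-loop gradient step on this copy's parameters only.
        self.w = self.w - lr * grad

maml = TinyModel(w=1.0)
learner = maml.clone()
learner.adapt(grad=2.0)     # inner-loop step on the clone

print(maml.w, learner.w)    # 1.0 0.0 -- the original is unchanged
```

(In actual learn2learn, `clone()` keeps the computation graph differentiable so the outer loop can backpropagate through the inner-loop updates; the deep copy above is only a behavioral sketch.)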