amzn / metalearn-leap

Original PyTorch implementation of the Leap meta-learner (https://arxiv.org/abs/1812.01054) along with code for running the Omniglot experiment presented in the paper.

Leap usage #5

Open semin-park opened 4 years ago

semin-park commented 4 years ago

In your Leap usage example, leap.update() is called inside the for loop. However, shouldn't it also be called right after the for loop? Since leap.update() does nothing but register _prev_state and _prev_loss on the first call, if the for loop runs five times, calling leap.update() five times only accumulates the first four segments of the gradient path, missing the last (fifth) gradient step.

I think the final loss (at $\theta_K$) should be computed, and leap.update() called, one more time after the for loop, roughly as in the sketch below. What do you guys think?

https://github.com/amzn/metalearn-leap/blob/1436ee3029bf5de7cd8e317b9bbf56ff02f46a6c/src/leap/leap/leap.py#L36-L60
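For concreteness, this is roughly the pattern I have in mind. Everything below (model, criterion, opt, the task batches) is a placeholder sketch rather than the actual example code, and leap stands for an already constructed Leap instance; the only call taken from the example is leap.update(loss, model):

import torch
from torch import nn

# Placeholder task model, loss, and optimizer for illustration only.
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Dummy data: K = 5 adaptation batches plus one extra batch for the final evaluation.
task_batches = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(5)]
final_x, final_y = torch.randn(8, 10), torch.randint(0, 2, (8,))

# `leap` is assumed to be a constructed Leap instance.
for x, y in task_batches:
    loss = criterion(model(x), y)
    leap.update(loss, model)   # first call only registers _prev_state / _prev_loss
    opt.zero_grad()
    loss.backward()
    opt.step()

# Proposed addition: evaluate the loss at theta_K and register the final step as well.
final_loss = criterion(model(final_x), final_y)
leap.update(final_loss, model)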

flennerhag commented 4 years ago

Hey, that's a great question!

You can definitely do what you propose. Essentially, the difference between that and the snippet above is that we sneak in an extra gradient step on the final loss, since it's already paid for :-)

Here's how they compare:

# K inner-loop adaptation steps
for i in range(K):
    # ... compute loss, call leap.update(loss, model), then loss.backward() and optimizer.step()

# what you propose: also register the final (K-th) step
final_loss = criterion(model(final_x), final_y)
leap.update(final_loss, model)

# what we do: additionally take an extra gradient step on the final loss, since it's already computed
final_loss.backward()
optimizer.step()

In the end, it won't matter unless K is very low: with, say, K = 100 inner steps only one of the 100 segments of the gradient path goes untracked, whereas with K = 5 you'd be dropping a fifth of the path.