GeorgeCazenavette / mtt-distillation

Official code for our CVPR '22 paper "Dataset Distillation by Matching Training Trajectories"
https://georgecazenavette.github.io/mtt-distillation/

Unrolled optimization #17

Closed: vadimkantorov closed this issue 1 year ago

vadimkantorov commented 2 years ago

Hi!

Do I understand correctly that the grand loss at the end will backprop through grad of grad of grad, i.e. not just a double backward but effectively a 20th-order backward?

That is, student_params[5] depends on both student_params[4] and grad(loss(target; student_params[4])), and the same goes further back, so the computation graph contains a path that goes through all 5 gradient computations.
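
Here is a minimal PyTorch sketch of what I mean (toy shapes and names, not this repo's actual code):

```python
import torch

# Toy stand-ins for the distilled data, student, and expert -- illustration only.
syn_images = torch.randn(10, 5, requires_grad=True)  # distilled data being optimized
syn_labels = torch.randn(10, 3)
theta = torch.randn(5, 3, requires_grad=True)        # student params (one linear layer)
expert_theta = torch.randn(5, 3)                     # stand-in for the expert's target params
lr = 0.01

student_params = [theta]
for _ in range(5):                                   # 5 unrolled inner SGD steps
    pred = syn_images @ student_params[-1]
    inner_loss = ((pred - syn_labels) ** 2).mean()
    # create_graph=True keeps this gradient itself differentiable, so the
    # next step's params depend on it in the autograd graph.
    (g,) = torch.autograd.grad(inner_loss, student_params[-1], create_graph=True)
    student_params.append(student_params[-1] - lr * g)

grand_loss = ((student_params[-1] - expert_theta) ** 2).sum()
grand_loss.backward()          # traverses all 5 nested grad computations
print(syn_images.grad.norm())  # nonzero: the distilled data received a gradient
```

With `create_graph=False` each `g` would be detached, and `grand_loss.backward()` could no longer reach `syn_images` through the gradient terms, which is why I think the higher-order path matters here.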

GeorgeCazenavette commented 2 years ago

Hi!

I think your intuition is correct, but it's a pretty confusing thing to put into words in general.

So here's a very rough sketch of the forward graph:

[image: rough sketch of the forward computation graph]
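
In symbols (my rough notation, eliding the paper's normalization term), each unrolled step and the final matching loss are:

```latex
% one unrolled student step on the synthetic data D_syn:
\theta_{t+1} = \theta_t - \alpha \,\nabla_{\theta_t}\,\ell\!\left(\mathcal{D}_{\mathrm{syn}};\,\theta_t\right)
% grand loss matching the final student params to the expert's:
\mathcal{L} = \left\lVert \theta_T - \theta_T^{*} \right\rVert^2
```

So $\theta_T$ sits at the end of a chain that passes through every one of the $T$ gradient nodes, which is the path you described.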

Does this match what you were thinking?