alinlab / L2T-ww

Learning What and Where to Transfer (ICML 2019)
MIT License

Math details about "meta_backward" #8

Open MarSaKi opened 4 years ago

MarSaKi commented 4 years ago

Can you give more math details about "meta_backward"? I read the code and the Reverse-HG paper, but I still couldn't understand your code.

wangxiang1099 commented 4 years ago

me too

hankook commented 4 years ago

For more details, we first provide the correspondence between the math notation in the Reverse-HG paper (e.g., $x$) and our code (e.g., x). See Algorithm 1 and Equation 11 in the paper.

Our code does the following (a minimal sketch is given after the list):

  1. Compute $\Phi(s,\lambda)$ (see L69-L73 and Equation 11).
  2. Compute inner-product between $\alpha$ (i.e., alpha_groups[-1]) and $\Phi(s,\lambda)$, then store the value into X (see L74-L82).
  3. Compute X's gradient by X.backward() (see L83-L84). This is the same as a Hessian-vector multiplication. Then, $\alpha B$ is accumulated into $g$ as described in Algorithm 1, and $\alpha A$ is stored in p.grad, where p is a parameter tensor in $s$.
  4. Fix $\alpha$ (see L85-L93). This is only required when either weight decay (wd) or momentum (momentum) is not zero.
  5. After meta_backward, the gradient of $\lambda$ (i.e., $g$) is stored in the corresponding p.grad tensors. Thus, we can just call source_optimizer.step() to update $\lambda$.
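
Here is a minimal, self-contained sketch of the idea behind steps 1-3, not the repo's actual meta_backward. All names (inner_loss, outer_loss, lmbd, lr) are hypothetical, and wd=0, momentum=0 is assumed:

```python
import torch

s    = torch.randn(5, requires_grad=True)   # target-model parameters ($s$)
lmbd = torch.randn(5, requires_grad=True)   # meta-parameters ($\lambda$)
lr   = 0.1                                  # inner-loop learning rate

def inner_loss(s, lmbd):
    # stand-in for the training loss whose SGD step defines Phi(s, lambda)
    return ((s - lmbd) ** 2).sum()

def outer_loss(s):
    # stand-in for the meta-objective whose gradient seeds alpha
    return (s ** 2).sum()

# Step 1: one differentiable SGD step, s_next = Phi(s, lambda).
# create_graph=True keeps the graph so second derivatives are available.
g_inner = torch.autograd.grad(inner_loss(s, lmbd), s, create_graph=True)[0]
s_next  = s - lr * g_inner

# alpha is seeded with the gradient of the meta-objective at s_next.
alpha = torch.autograd.grad(outer_loss(s_next), s_next)[0]

# Step 2: inner product between alpha and Phi(s, lambda).
X = (alpha.detach() * s_next).sum()

# Step 3: one backward pass through X yields both Hessian-vector products:
# the gradient w.r.t. s is alpha*A (the alpha for the previous step), and
# the gradient w.r.t. lambda is alpha*B (accumulated into g).
alpha_A, alpha_B = torch.autograd.grad(X, [s, lmbd])

g = alpha_B          # contribution to the hypergradient of lambda
alpha = alpha_A      # would be passed back to the previous inner step
print(g)
```

In the repo this is done over all parameter groups of the inner optimizer, and the gradients land directly in p.grad so that source_optimizer.step() can be used as in step 5.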

I think our code is easier to understand when using vanilla SGD (i.e., wd=0 and momentum=0).
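
To make the vanilla-SGD case concrete, the Jacobians in Algorithm 1 simplify as below; the inner learning rate $\eta$ and the name $L_{\mathrm{train}}$ are illustrative notation, not taken from the repo:

```latex
% Sketch under the vanilla-SGD assumption (wd = momentum = 0).
\Phi(s,\lambda) = s - \eta\,\nabla_s L_{\mathrm{train}}(s,\lambda)

% Jacobians used in the reverse pass of Algorithm 1:
A = \frac{\partial \Phi}{\partial s} = I - \eta\,\nabla^2_{ss} L_{\mathrm{train}}(s,\lambda),
\qquad
B = \frac{\partial \Phi}{\partial \lambda} = -\eta\,\nabla^2_{\lambda s} L_{\mathrm{train}}(s,\lambda)

% Back-propagating X = \langle \alpha, \Phi(s,\lambda) \rangle therefore puts
% \alpha^\top A into the gradient of s (the next \alpha) and \alpha^\top B into
% the hypergradient g, without ever forming a Hessian explicitly.
```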