bfshi / ARML_Auxiliary_Task_Reweighting

Code for our paper "Auxiliary Task Reweighting for Minimum-data Learning" (NeurIPS 2020)
https://sites.google.com/view/auxiliary-task-reweighting/home

Difference between implementation and paper's algorithm #1

Open: ChangLee0903 opened this issue 2 years ago

ChangLee0903 commented 2 years ago

Hi @bfshi,

I really appreciate your work. However, I noticed that there might be some differences between your implementation and the algorithm proposed in the paper. My questions about the implementation details are listed as follows:

  1. Equation A11 in your paper is an expectation (over samples of theta), but in Algorithm 1 alpha is updated using only one sample (theta_t). How should I understand this? Can theta_t be interpreted as an estimate of that expectation?
  2. In the paper, the task weights are updated every time theta_t is obtained. However, in the code the task weights are updated only once every 20 iterations. How does this affect performance? How do the results change with different update periods?
  3. In lines 208-209 of your code, there is an additional step for updating alpha that does not seem to be described in the paper. Could you explain it?

BTW, there seems to be a typo in the fourth line of Eq. A45: the Z'(alpha) term is missing.

Thanks, Chi-Chang Lee.

bfshi commented 2 years ago

Hi Chi-Chang! Thanks for the interest in our work. To answer your questions:

  1. We interpret theta_t in each iteration as a sample from the joint distribution p_J obtained by Langevin dynamics. In other words, each update of theta yields a sample from p_J. In our code, we update the task weights each time we get a sample theta_t (i.e., we use a single sample from p_J to estimate the expectation). You can also collect several samples over multiple iterations and then update the task weights based on all of them. A practical way to do this is to accumulate the gradients w.r.t. the task weights in each iteration (instead of storing theta_t itself: getting the gradient from theta_t requires backprop, but the computational graph is freed after each iteration), and then update the task weights every n iterations using these accumulated gradients.
  2. Yes, we chose the 20-iteration update period for efficiency. You can also update the weights more frequently, but it may affect training stability (you can try it out!).
  3. Lines 208-209 normalize the task weights onto the simplex (see Sec. 2.4 in the paper).
  4. Since the Z'(alpha) term is independent of theta, its gradient with respect to theta is zero, so the last two lines of A45 are equivalent.
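As a generic illustration of point 4 (assuming, without reproducing A45 itself, that Z'(alpha) enters as a theta-independent normalizer of p_J, and writing f for a hypothetical unnormalized density), the normalizer drops out of any gradient taken with respect to theta:

```latex
\log p_J(\theta) = \log f(\theta; \alpha) - \log Z'(\alpha)
\;\Longrightarrow\;
\nabla_\theta \log p_J(\theta)
  = \nabla_\theta \log f(\theta; \alpha)
    - \underbrace{\nabla_\theta \log Z'(\alpha)}_{=\,0}
  = \nabla_\theta \log f(\theta; \alpha)
```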

Hope these have addressed your concerns! Feel free to follow up if you have any further questions!
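The update schedule from answers 1-3 (accumulate the task-weight gradient from each theta_t sample, update every n iterations, then renormalize) can be sketched as below. This is a minimal, framework-free sketch under stated assumptions, not the repository's code: `AccumulatedWeightUpdater` and `normalize_to_simplex` are hypothetical names, the per-iteration gradient is assumed to be computed elsewhere (e.g. by backprop) and passed in as a list, and the clip-and-rescale normalization with target sum `total` is a simple stand-in for whatever lines 208-209 do (it is not an exact Euclidean projection; see Sec. 2.4 of the paper for the actual constraint).

```python
from typing import List


def normalize_to_simplex(weights: List[float], total: float) -> List[float]:
    """Clip weights to be non-negative and rescale them to sum to `total`.

    Hypothetical helper: clip-and-rescale, not an exact simplex projection.
    """
    clipped = [max(w, 0.0) for w in weights]
    s = sum(clipped)
    if s == 0.0:
        # Degenerate case: every weight was clipped, so fall back to uniform.
        return [total / len(weights)] * len(weights)
    return [total * w / s for w in clipped]


class AccumulatedWeightUpdater:
    """Accumulate gradients w.r.t. the task weights over `period` iterations,
    then apply one averaged update (hypothetical class)."""

    def __init__(self, weights: List[float], lr: float, period: int, total: float):
        self.weights = list(weights)
        self.lr = lr
        self.period = period
        self.total = total
        self._grad_sum = [0.0] * len(weights)
        self._count = 0

    def step(self, grad: List[float]) -> None:
        # Keep only the gradient from this theta_t sample, so the graph that
        # produced it can be freed right after this call.
        for i, g in enumerate(grad):
            self._grad_sum[i] += g
        self._count += 1
        if self._count == self.period:
            # Average over the collected samples (a multi-sample estimate).
            mean = [g / self._count for g in self._grad_sum]
            # Gradient-descent step on the weights (sign convention assumed).
            updated = [w - self.lr * g for w, g in zip(self.weights, mean)]
            # Renormalize so the weights stay non-negative with a fixed sum.
            self.weights = normalize_to_simplex(updated, self.total)
            self._grad_sum = [0.0] * len(self.weights)
            self._count = 0
```

With `period=1` this reduces to updating the weights after every sample theta_t; `period=20` corresponds to the 20-iteration schedule, with the 20 gradients averaged instead of only the last one being used.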

ChangLee0903 commented 2 years ago

Many thanks for your answer!

Just two questions left:

  1. I noticed that you use Adam to train your model. Is this consistent with sampling from p_J(theta)? The momentum term seems to make the sampling scheme different.
  2. Do all the experiments use the same training recipe? If not, could you share the recipe for each?

best, Chi-Chang Lee

bfshi commented 2 years ago

Hi Chi-Chang,

  1. Yes, the original SGLD uses SGD for sampling. Adam is slightly different because of the momentum, but you only need to adjust the noise injected into the gradient. The noise term is already adjusted for Adam in the code.
  2. The experimental settings are reported in Appendix B. Please have a look at that part.
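To illustrate point 1, the sketch below contrasts a plain SGLD step with an RMSprop-preconditioned SGLD step (pSGLD, with its small correction term dropped, as is common in practice). pSGLD is swapped in here only as a well-known example of how a preconditioner must also rescale the injected noise; it is an assumption for illustration and not necessarily the Adam adjustment used in the repository. `U(theta) = -log p(theta)` is the energy, and the function names are hypothetical.

```python
import math
import random
from typing import List, Tuple


def sgld_step(theta: List[float], grad_u: List[float], lr: float,
              rng: random.Random) -> List[float]:
    """One SGLD step: theta <- theta - lr * grad_U + sqrt(2 * lr) * N(0, 1)."""
    return [t - lr * g + math.sqrt(2.0 * lr) * rng.gauss(0.0, 1.0)
            for t, g in zip(theta, grad_u)]


def psgld_step(theta: List[float], grad_u: List[float], v: List[float],
               lr: float, rng: random.Random, beta: float = 0.99,
               eps: float = 1e-8) -> Tuple[List[float], List[float]]:
    """One RMSprop-preconditioned SGLD step; returns (new_theta, new_v)."""
    # Running average of squared gradients, as in RMSprop.
    new_v = [beta * vi + (1.0 - beta) * g * g for vi, g in zip(v, grad_u)]
    new_theta = []
    for t, g, vi in zip(theta, grad_u, new_v):
        precond = 1.0 / (eps + math.sqrt(vi))  # diagonal preconditioner G
        # The drift and the noise variance are both scaled by G; scaling only
        # the drift would change the distribution being sampled.
        noise = math.sqrt(2.0 * lr * precond) * rng.gauss(0.0, 1.0)
        new_theta.append(t - lr * precond * g + noise)
    return new_theta, new_v
```

For example, running `sgld_step` with `grad_u = theta` (i.e. `U = theta^2 / 2`) and a small `lr` produces, after burn-in, samples whose mean and variance are close to those of a standard normal.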

best, Baifeng

ChangLee0903 commented 2 years ago

Hi @bfshi,

I just noticed that the argument "--lagrange" is missing from the command in your README. As I understand it, the actual ARML algorithm should be run with this option enabled. I have three questions about this:

  1. Have you compared the results with and without noise injection?
  2. Do the baselines (other methods) also use noise injection?
  3. Were all the ARML results in the paper's tables obtained with noise injection?

By the way, the noise injection is turned off after 50000 steps. Is this for training stability? Do all the ARML results follow this setup?
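The setup described above (noise injected only when the option is enabled, and switched off after 50000 steps) can be sketched as a simple gate on the noise scale. This is a hypothetical helper: `use_noise` stands in for whether `--lagrange` is passed, and the `sqrt(2 * lr)` scale assumes a plain SGLD-style update rather than the repository's Adam-adjusted noise.

```python
import math


def langevin_noise_std(step: int, lr: float, use_noise: bool,
                       cutoff: int = 50_000) -> float:
    """Std of the Gaussian noise added to the update at `step` (hypothetical).

    Noise is injected only if `use_noise` is set and `step < cutoff`; after
    that the update reduces to the plain optimizer step.
    """
    if not use_noise or step >= cutoff:
        return 0.0
    return math.sqrt(2.0 * lr)
```

Gating on the step count makes the later phase of training deterministic gradient descent, which is one plausible reason for the cutoff, but whether that is the motivation here is exactly what the question asks.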