quark0 / darts

Differentiable architecture search for convolutional and recurrent networks
https://arxiv.org/abs/1806.09055
Apache License 2.0

Two questions about code implementation #121

Closed TianQi-777 closed 5 years ago

TianQi-777 commented 5 years ago

Hello, this is a very good project. I am a beginner in CS and very interested in this project, and I have read all of the code. However, I have two questions about it.

Question 1: About the formula for the second-order approximation in the paper and its implementation in the code.

In the paper, w+ = w + ε▽w'L_val(w',α) and w- = w - ε▽w'L_val(w',α). How is ▽w'L_val(w',α) expressed in the code? That is, where is this value computed? As far as I can tell, it is never calculated in the code.
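For intuition, here is a minimal numeric sketch (my own toy example, not the repo's code) of the finite-difference identity behind w+ and w-: for a loss f, the central difference (▽f(w + εv) - ▽f(w - εv)) / (2ε) approximates the Hessian-vector product Hv, which is the quantity the paper's second-order term needs. For a quadratic f the identity is exact:

```python
# Toy check of the central-difference Hessian-vector product used in the
# second-order approximation: (grad f(w + eps*v) - grad f(w - eps*v)) / (2*eps) ~= H v.
# Here f(w) = 0.5*(3*w0**2 + 5*w1**2), so H = diag(3, 5) and the identity is exact.

def grad_f(w):
    # Analytic gradient of f(w) = 0.5*(3*w0^2 + 5*w1^2)
    return [3.0 * w[0], 5.0 * w[1]]

def hvp_finite_diff(w, v, eps=1e-2):
    # v plays the role of grad_{w'} L_val(w', alpha) in the paper
    w_plus = [wi + eps * vi for wi, vi in zip(w, v)]
    w_minus = [wi - eps * vi for wi, vi in zip(w, v)]
    gp, gm = grad_f(w_plus), grad_f(w_minus)
    return [(p - m) / (2.0 * eps) for p, m in zip(gp, gm)]

w = [1.0, -2.0]
v = [0.5, 1.0]
print(hvp_finite_diff(w, v))  # -> approximately [1.5, 5.0], i.e. diag(3, 5) @ v
```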

So far I have two possible interpretations:

  1. In the code, /cnn/architect.py: my first guess is that vector = [v.grad.data for v in unrolled_model.parameters()] [Line 49] corresponds to ▽w'L_val(w',α). However, I think v.grad.data at that point only holds the gradients from the manual update in function _compute_unrolled_model [Line 20] (the "virtual gradient step" in the paper), and I have not found any code that differentiates the validation loss with respect to w'.

  2. In the code, /cnn/architect.py: my other guess is that ▽w'L_val(w',α) is computed by unrolled_loss.backward() [Line 47]. But the optimizer for unrolled_model is constructed over model.arch_parameters() [Line 17], so I don't think this step computes the derivative of L_val(w',α) with respect to w'. In addition, the paper says that w is held fixed while optimizing L_val(w',α), so how is ▽w'L_val(w',α) obtained in the code? Is my understanding of PyTorch wrong?
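A minimal PyTorch sketch (a toy example, not from the repo) of the autograd behaviour at issue in guess 2: loss.backward() populates .grad for every leaf tensor with requires_grad=True, regardless of which parameters the optimizer was constructed with. The optimizer only determines which tensors get updated on step(), not which get gradients:

```python
import torch

w = torch.tensor([2.0], requires_grad=True)      # stands in for the weights w'
alpha = torch.tensor([3.0], requires_grad=True)  # stands in for the arch params

opt = torch.optim.SGD([alpha], lr=0.1)           # optimizer only tracks alpha

loss = (w * alpha).sum()                         # toy "validation loss"
loss.backward()

print(w.grad)      # tensor([3.]) -- populated even though opt never sees w
print(alpha.grad)  # tensor([2.])
```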

Question 2: In /cnn/architect.py [Line 76], why not use ε = r = 1e-2 directly instead of ε = r/_concat(vector).norm()? Dividing by _concat(vector).norm() seems redundant to me.
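A small sketch (my own toy example, not the repo's code) of what the normalization buys: the perturbation applied to the weights is ε·v, so choosing ε = r/||v|| keeps the perturbation's size equal to r no matter how large or small the gradient v happens to be:

```python
import math

def perturbation_norm(v, r=1e-2):
    # eps = r / ||v||, as in architect.py; returns ||eps * v||,
    # the actual size of the step applied to the weights
    norm_v = math.sqrt(sum(x * x for x in v))
    eps = r / norm_v
    return math.sqrt(sum((eps * x) ** 2 for x in v))

print(perturbation_norm([1e-4, 2e-4]))  # ~0.01, tiny gradient
print(perturbation_norm([30.0, 40.0]))  # ~0.01, huge gradient
```

With a fixed ε = r instead, a tiny gradient would give a vanishing perturbation and a huge gradient an enormous one, making the finite-difference approximation unstable.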

I look forward to your reply, thank you very much.

TianQi-777 commented 5 years ago

I have solved both problems:

  1. ▽w'L_val(w',α) is indeed computed by unrolled_loss.backward() [Line 47]; I had misunderstood PyTorch's autograd mechanism.
  2. The footnote of the paper already explains this point.