khanrc / pt.darts

PyTorch Implementation of DARTS: Differentiable Architecture Search
MIT License

questions on v_grads = torch.autograd.grad(loss,v_alphas+v_weights) #27

Open cholihao opened 4 years ago

cholihao commented 4 years ago

In architect.py, I'm confused about the following three lines of code:

    v_grads = torch.autograd.grad(loss, v_alphas + v_weights)
    dalpha = v_grads[:len(v_alphas)]
    dw = v_grads[len(v_alphas):]

Why is the gradient computed with respect to (v_alphas + v_weights), with dalpha then taken from v_grads[:len(v_alphas)]? Based on equation (7), I thought it should be computed with respect to v_alphas only. My other question: why can you get dalpha and dw directly from v_grads instead of calling autograd separately for each?
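One possible reading of these three lines, shown as a sketch on a toy problem. Equation (7) of DARTS is ∇_α L_val(w', α) − ξ ∇²_{α,w} L_train(w, α) ∇_{w'} L_val(w', α). That expression needs two gradients of the validation loss. The first is ∇_α L_val(w', α), which plays the role of dalpha. The second is ∇_{w'} L_val(w', α), which plays the role of dw, the vector that multiplies the mixed Hessian. torch.autograd.grad accepts a sequence of inputs and returns a tuple with one gradient per input, in the same order. A single backward pass over v_alphas + v_weights therefore yields both gradients, and slicing at len(v_alphas) separates them. Two separate calls give the same values, but the first call would need retain_graph=True. The following are assumptions and are not claimed to be the pt.darts implementation: the functions train_loss, val_loss, split_grads and darts_alpha_grad are hypothetical, the losses are toy stand-ins, the virtual step is plain SGD, and the Hessian-vector product uses a central finite difference with ε = 0.01/‖dw‖.

```python
import torch


# Hypothetical toy losses standing in for the supernet's training and
# validation losses: alpha plays the architecture parameters, w the weights.
def train_loss(w, alpha):
    return ((w * torch.sigmoid(alpha) - 1.0) ** 2).sum()


def val_loss(w, alpha):
    return ((w * torch.sigmoid(alpha) - 0.5) ** 2).sum()


def split_grads(loss, v_alphas, v_weights):
    # One backward pass w.r.t. the concatenated list. autograd.grad returns a
    # tuple with one gradient per input, in input order, so slicing splits it.
    v_grads = torch.autograd.grad(loss, v_alphas + v_weights)
    return v_grads[:len(v_alphas)], v_grads[len(v_alphas):]


def darts_alpha_grad(w, alpha, xi=0.1):
    # Virtual step w' = w - xi * grad_w L_train(w, alpha) (assumed plain SGD),
    # taken without a graph so w' and the alpha copy are fresh leaf tensors.
    g_w, = torch.autograd.grad(train_loss(w, alpha), w)
    v_weights = [(w - xi * g_w).detach().requires_grad_()]
    v_alphas = [alpha.detach().clone().requires_grad_()]

    loss = val_loss(v_weights[0], v_alphas[0])
    dalpha, dw = split_grads(loss, v_alphas, v_weights)
    # dalpha[0] = grad_alpha L_val(w', alpha): first term of equation (7)
    # dw[0]     = grad_w' L_val(w', alpha): the vector the Hessian multiplies

    # Central finite difference for grad^2_{alpha,w} L_train(w, alpha) @ dw[0]
    eps = 0.01 / dw[0].norm()
    w_pos = (w + eps * dw[0]).detach()
    w_neg = (w - eps * dw[0]).detach()
    dalpha_pos, = torch.autograd.grad(train_loss(w_pos, alpha), alpha)
    dalpha_neg, = torch.autograd.grad(train_loss(w_neg, alpha), alpha)
    hessian = (dalpha_pos - dalpha_neg) / (2 * eps)

    return dalpha[0] - xi * hessian


if __name__ == "__main__":
    torch.manual_seed(0)
    w = torch.randn(3, dtype=torch.float64, requires_grad=True)
    alpha = torch.randn(3, dtype=torch.float64, requires_grad=True)

    # Joint call vs. two separate calls. Each graph is built fresh, because
    # a call without retain_graph=True frees the graph it differentiates.
    a = alpha.detach().clone().requires_grad_()
    v = w.detach().clone().requires_grad_()
    joint_a, joint_w = split_grads(val_loss(v, a), [a], [v])
    loss = val_loss(v, a)
    sep_a, = torch.autograd.grad(loss, a, retain_graph=True)
    sep_w, = torch.autograd.grad(loss, v)
    print(torch.allclose(joint_a[0], sep_a), torch.allclose(joint_w[0], sep_w))  # True True

    # Reference: differentiate exactly through the unrolled virtual step.
    g_w, = torch.autograd.grad(train_loss(w, alpha), w, create_graph=True)
    exact, = torch.autograd.grad(val_loss(w - 0.1 * g_w, alpha), alpha)
    print(torch.allclose(darts_alpha_grad(w, alpha, xi=0.1), exact))  # True
```

The toy training loss is quadratic in w, so the central difference is exact here up to rounding. That is why it can be checked against differentiating through the unrolled step with create_graph=True. For a real network the finite difference is only an approximation. The check also shows why dw must be computed: without it, the second term of equation (7) cannot be formed.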