D-X-Y / AutoDL-Projects

Automated deep learning algorithms implemented in PyTorch.
MIT License

questions about forward propagation and backward propagation #33

Closed wwwanghao closed 4 years ago

wwwanghao commented 4 years ago

Thank you for making the search code publicly accessible; it is really good material for NAS researchers. After reading the paper and reviewing the search code, there are a few points that I may not clearly understand.

In the paper, you mentioned: "since Eq. (3) needs to sample from a discrete probability distribution, we cannot back-propagate gradients. To allow back-propagation, we use the Gumbel-Max trick." And in the acceleration part: "During the backward procedure, we only back-propagate the gradient generated at the argmax."
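For context, here is a minimal PyTorch sketch of the Gumbel-softmax sampling step described above; the names (`arch_param`, `tau`) and the temperature value are illustrative assumptions, not taken from the repository.

```python
import torch

# Illustrative sketch of Gumbel-softmax sampling for one edge (assumed names, not the repo's code).
arch_param = torch.randn(4, requires_grad=True)   # architecture logits a1..a4
tau = 10.0                                        # softmax temperature (annealed during search)

# Sample Gumbel(0, 1) noise and perturb the log-probabilities.
gumbels = -torch.empty_like(arch_param).exponential_().log()
logits = (arch_param.log_softmax(dim=-1) + gumbels) / tau
probs = logits.softmax(dim=-1)   # differentiable soft weights p1..p4
index = probs.argmax(dim=-1)     # discrete choice via the Gumbel-Max trick; argmax itself is not differentiable
```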

First, suppose F1, F2, F3, F4 are the candidate functions between two nodes, the corresponding architecture parameters are a1, a2, a3, a4, and their Gumbel-softmax weights are p1, p2, p3, p4. During the forward pass, we sample the index with the maximum probability and obtain a one-hot vector with the code `hardwts = one_h - probs.detach() + probs`. Assume the one-hot architecture weights are w1, w2, w3, w4 and the argmax index is 2. The forward code is as follows: `weigsum = sum( weights[_ie] * edge(nodes[j]) if _ie == argmaxs else weights[_ie] for _ie, edge in enumerate(self.edges[node_str]) )`. According to this code, the forward result looks like: weightsum = w1 + w2*F2 + w3 + w4.
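To make this concrete, here is a self-contained sketch of the straight-through one-hot forward on one edge group; the operations, names, and shapes are stand-ins for illustration, not the repository's actual cell code (which also handles the node bookkeeping and Gumbel noise).

```python
import torch
import torch.nn as nn

# Self-contained sketch of the straight-through one-hot forward (assumed toy setup).
edges = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])   # stand-ins for F1..F4
arch_param = torch.randn(4, requires_grad=True)
x = torch.randn(2, 8)

probs = torch.softmax(arch_param, dim=-1)     # soft weights p1..p4 (Gumbel noise omitted for brevity)
index = int(probs.argmax())
one_h = torch.zeros_like(probs)
one_h[index] = 1.0
hardwts = one_h - probs.detach() + probs      # equals one_h in the forward, carries probs' gradient in the backward

# Only the argmax edge runs its operation; the other (zero-valued) weights are added as scalars,
# so the forward value is exactly F_argmax(x) while every w_i stays in the autograd graph.
weigsum = sum(
    hardwts[ie] * edge(x) if ie == index else hardwts[ie]
    for ie, edge in enumerate(edges)
)
weigsum.sum().backward()
print(arch_param.grad)                        # gradients reach all four architecture logits
```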

q1: Is the code the acceleration version or not? According to the acceleration part in your paper, we only need to back-propagate through the argmax, which means we only back-propagate through F2 and ignore F1, F3, F4? And for the architecture parameters, do we only back-propagate to w2 and ignore w1, w3, w4? From the code, it seems w1, w3, w4 also receive gradients.

q2: Only the argmax branch is forwarded, but from the code, w1, w3, w4 are also added to the weightsum. Even though their values are zero, so the result is the same as computing only w2*F2, what is the purpose of adding w1, w3, w4? What would happen if we computed the weightsum as just w2*F2?

q3: If the soft Gumbel-softmax weights are applied rather than the one-hot ones, can we still compute the weightsum as w1 + w2*F2 + w3 + w4? I think maybe not, because w and w*F can be of different orders of magnitude.

These questions confused me a lot; it would be really helpful if you could kindly give me some suggestions. Thank you!

D-X-Y commented 4 years ago

@wwwanghao Sorry for the late reply. Q1. This is the acceleration version. Yes, gradients to F1/F3/F4 are ignored. We did not ignore w1, w3, w4, as keeping them does not increase the computational cost.

Q2. The forward results are the same. However, during BP, there are still gradients on w1, w3, and w4.
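To illustrate this point, a toy check (made-up tensors and names, not code from the repository) shows that the forward value is unchanged by the extra terms, while only the full sum yields gradients on the non-argmax weights:

```python
import torch

# Toy comparison (assumed setup): same forward value, different gradients on the w_i.
torch.manual_seed(0)
arch_param = torch.randn(4, requires_grad=True)
feats = [torch.randn(3) for _ in range(4)]           # stand-ins for F1(x)..F4(x)

probs = torch.softmax(arch_param, dim=-1)
index = int(probs.argmax())
one_h = torch.zeros_like(probs)
one_h[index] = 1.0
hardwts = one_h - probs.detach() + probs

full = sum(hardwts[i] * feats[i] if i == index else hardwts[i] for i in range(4))
argmax_only = hardwts[index] * feats[index]
print(torch.allclose(full, argmax_only))             # True: the non-argmax weights are exactly zero

g_full = torch.autograd.grad(full.sum(), hardwts, retain_graph=True)[0]
g_only = torch.autograd.grad(argmax_only.sum(), hardwts)[0]
print(g_full)   # every position of hardwts receives a gradient
print(g_only)   # only the argmax position receives a gradient
```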

Q3. No, because the forward results are different.
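For completeness, a toy comparison (assumed shapes and names) of the two forwards under soft weights: only the full weighted sum of operation outputs is a proper mixture, while w1 + w2*F2 + w3 + w4 adds bare scalars to a feature map and gives a different value.

```python
import torch
import torch.nn as nn

# Toy comparison with soft weights (assumed setup, not the repository's code).
edges = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(2, 8)
probs = torch.softmax(torch.randn(4), dim=-1)   # soft weights p1..p4

soft_mix = sum(probs[i] * edge(x) for i, edge in enumerate(edges))   # p1*F1(x) + p2*F2(x) + p3*F3(x) + p4*F4(x)
partial = probs[0] + probs[1] * edges[1](x) + probs[2] + probs[3]    # w1 + w2*F2(x) + w3 + w4 with soft weights
print(torch.allclose(soft_mix, partial))                             # False in general: the forward results differ
```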

wwwanghao commented 4 years ago

Thank you for your kind answer; that's really helpful. Thank you very much.

D-X-Y commented 4 years ago

@wwwanghao You are welcome.