Sorry for the confusion. That file is DARTS, not our algorithm; we did not release the search code of GDAS. The main difference between GDAS and DARTS is that we use Gumbel-softmax with an acceleration trick that allows only one candidate CNN to be used during the forward pass, while gradients can still be back-propagated to the architecture parameters.
What do you mean by "allows only one candidate CNN to be used during the forward pass, while gradients can still be back-propagated to the architecture parameters"?
For example, we have four paths (path1, path2, path3, path4; each path corresponds to one architecture parameter) from x_node to y_node, and each path can represent a candidate CNN. During the forward pass, we only compute one path to get y_node, but GDAS can still back-propagate gradients to the architecture parameters of all four paths.
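Conceptually, here is a minimal PyTorch sketch of that idea (this is not our released search code; the names and the toy ops are made up for illustration): only the selected path is evaluated in the forward pass, yet all four architecture parameters receive gradients.

```python
import torch
import torch.nn.functional as F

# Toy example: 4 candidate paths from x_node to y_node, one architecture
# parameter (logit) per path. Names and ops are illustrative only.
arch_param = torch.zeros(4, requires_grad=True)
tau = 1.0

# Gumbel-softmax sampling: -log(Exp(1)) gives Gumbel(0, 1) noise.
gumbels = -torch.empty_like(arch_param).exponential_().log()
logits = (arch_param.log_softmax(dim=-1) + gumbels) / tau
probs = F.softmax(logits, dim=-1)                  # soft (differentiable) weights
index = probs.argmax(dim=-1)                       # selected path
one_hot = F.one_hot(index, num_classes=4).float()  # hard one-hot weights
# Straight-through weights: equal to one_hot in value, to probs in gradient.
hardwts = one_hot - probs.detach() + probs

x = torch.randn(2, 8)
paths = [lambda t, s=i: t * (s + 1.0) for i in range(4)]  # stand-ins for 4 candidate CNNs
idx = int(index)

# Forward pass: only paths[idx] is evaluated.
y = hardwts[idx] * paths[idx](x)
y.sum().backward()
print(arch_param.grad)  # all 4 entries receive (approximate) gradients
```

Because the softmax couples every logit, the gradient of the single selected term reaches all four architecture parameters.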
Alright, how exactly is the acceleration trick mentioned previously implemented in actual code?
You are welcome to read Sec. 3.2 in https://arxiv.org/pdf/1910.04465.pdf.
Do you mean the Gumbel-Max trick? If yes, how exactly does it allow only one candidate CNN to be used during the forward pass?
Solution:
Because the one-hot/argmax version (Equation 5 in the paper) is used in the forward pass, only the selected path has to be computed. And because the softmaxed version (Equation 7 in the paper) is used in the backward pass, all of the other paths' architecture parameters still receive nonzero (approximated) gradients.
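As a rough sketch of how the two equations combine on a single edge (the class and argument names here are hypothetical, not the released code), the forward pass runs only the argmax candidate while the output stays differentiable with respect to every architecture parameter:

```python
import torch.nn as nn

class GumbelMixedOp(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, ops):
        super().__init__()
        self._ops = nn.ModuleList(ops)

    def forward(self, x, hardwts, index):
        # `hardwts` is assumed to be a straight-through Gumbel-softmax weight
        # vector: one-hot in value (Eq. 5) but carrying the softmax gradient
        # (Eq. 7). Only the selected candidate op is actually evaluated; the
        # multiplication by hardwts[index] keeps the output differentiable
        # w.r.t. the architecture parameters.
        return hardwts[index] * self._ops[index](x)
```

In contrast, a DARTS-style MixedOp sums w * op(x) over all candidates, so every op has to be evaluated in the forward pass.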
> In the forward procedure, we only need to calculate the function F_{argmax(h̃_{i,j})}. During the backward procedure, we only back-propagate the gradient generated at argmax(h̃_{i,j}).
@D-X-Y
Regarding the quoted text above from the GDAS paper, I have a few questions:
1. Isn't the argmax operation non-differentiable in PyTorch?
2. ... training (for W training) and validation (for A training) datasets?
Thanks for sharing the code.
I have a question about an implementation difference from DARTS. The training code looks very similar to DARTS (https://github.com/quark0/darts).
As you mentioned in the paper, "2. Instead of using the whole DAG, GDAS samples one sub-graph at one training iteration, accelerating the searching procedure. Besides, the sampling in GDAS is learnable and contributes to finding a better cell."
But in the forward function of MixedOp, the output is just the weighted sum of all ops, the same as in DARTS:
```python
def forward(self, x, weights):
    return sum(w * op(x) for w, op in zip(weights, self._ops))
```
So, can you point out the code that "samples one sub-graph at one training iteration"? Thanks.