D-X-Y / AutoDL-Projects

Automated deep learning algorithms implemented in PyTorch.
MIT License

Questions about DARTS #99

Open buttercutter opened 3 years ago

buttercutter commented 3 years ago
  1. For the DARTS complexity analysis, does anyone have an idea how to derive the (k+1)*k/2 expression ? Why 2 input nodes ? How would the calculated value change if graph isomorphism were considered ? Why "2+3+4+5" learnable edges ? If a connection is absent, should the paper really add 1 for it, since it does not actually contribute to the learnable-edge configurations at all ?

  2. Why do the weights for normal cells and reduction cells need to be trained separately, as shown in Figures 4 and 5 below ?

  3. How should the nodes be arranged so that the NAS search will actually converge with minimum error ? Note: not all nodes are connected to every other node.

  4. Why is GDAS 10 times faster than DARTS ?

DARTS_normal_reduction_cells

DARTS_complexity_analysis

D-X-Y commented 3 years ago

Thanks for pointing out these questions.

(1). (k+1)k/2 is because for the k-th node, you have (k+1) preceding nodes. Selecting two of them has C(k+1, 2) possibilities. The 2 input nodes are pre-defined according to human experts' experience. If isomorphism is considered, you need another way to represent this DAG. Before pruning the fully-connected graph into the "2-input-nodes version", the k-th node has (k+1) preceding nodes and thus (k+1) incoming edges -> (1+1) + (2+1) + (3+1) + (4+1) = 14 learnable edges.
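This counting can be sketched in a few lines of Python (4 intermediate nodes, as in the DARTS cell):

```python
from math import comb

# Edge count in a DARTS cell: the k-th intermediate node (k = 1..4) can draw
# input from (k + 1) predecessors: the two cell inputs c_{k-2}, c_{k-1} plus
# the (k - 1) earlier intermediate nodes.
num_intermediate_nodes = 4

# Before pruning: every predecessor edge is learnable.
total_edges = sum(k + 1 for k in range(1, num_intermediate_nodes + 1))
print(total_edges)  # 2 + 3 + 4 + 5 = 14

# After search, each node keeps exactly two input edges; the number of ways
# to pick them for node k is C(k + 1, 2) = (k + 1) * k / 2.
choices = [comb(k + 1, 2) for k in range(1, num_intermediate_nodes + 1)]
print(choices)  # [1, 3, 6, 10]
```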

(2). We hypothesize that the normal cell and the reduction cell will have very different topologies.

(3). No theoretical guarantee.

(4). Because at each iteration, DARTS needs to compute the weighted sum, by the architecture parameters, of the outputs of every candidate operation -> O(N), whereas GDAS only needs to "sample" one candidate operation -> O(1).
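A toy numeric sketch of that complexity difference (scalar stand-ins for the candidate operations' feature maps; the values are made up):

```python
import random

candidate_outputs = [1.0, 2.0, 3.0, 4.0, 5.0]   # outputs of the N candidate ops
arch_weights      = [0.1, 0.2, 0.3, 0.2, 0.2]   # softmax(architecture logits)

# DARTS: every candidate op must be evaluated and then weighted-summed -> O(N).
darts_out = sum(w * o for w, o in zip(arch_weights, candidate_outputs))

# GDAS: sample ONE index from the architecture distribution and evaluate only
# that op -> O(1) op evaluations per edge.
idx = random.choices(range(len(arch_weights)), weights=arch_weights, k=1)[0]
gdas_out = candidate_outputs[idx]
```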

buttercutter commented 3 years ago

Why does the k-th node have (k+1) preceding nodes ?

D-X-Y commented 3 years ago

Because each cell also takes the outputs of the two previous cells as inputs. So for the first node in a cell, its preceding nodes are [last-cell-outputs, second-last-cell-outputs]. For the second node, they are: [last-cell-outputs, second-last-cell-outputs, first-node-outputs].

buttercutter commented 3 years ago

For the second node, what is the difference between last-cell-outputs and first-node-outputs ?

Solution:

Each intermediate state, 0-3, is connected to each previous intermediate state as well as 
the output of the previous two cells, c_{k-2} and c_{k-1} (after a preprocessing layer).

D-X-Y commented 3 years ago

The last-cell-outputs is the output of green box c_{k-1}. The first-node-outputs is the output of blue box 0.

buttercutter commented 3 years ago

If you add Gumbel-distributed noise to the logits and take the argmax, the Gumbel noise is exactly the right distribution such that this is the same as softmaxing the logits and sampling from the discrete distribution defined by those probabilities.

Someone told me the above, but I am not familiar with Gumbel and how it actually helps speed up GDAS relative to DARTS. I suppose it is the Gumbel-Max trick mentioned in the paper. I do not quite understand expressions (3) and (5) in the GDAS paper.
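For what it's worth, the Gumbel-Max trick described in that quote can be checked empirically in a few lines of pure Python (a sketch, not the repository's implementation; the logits are made-up values):

```python
import math
import random

def gumbel_max_sample(logits):
    # Gumbel-Max trick: argmax(logit_i + g_i) with g_i ~ Gumbel(0, 1)
    # is distributed exactly like sampling from softmax(logits).
    gumbels = [-math.log(-math.log(random.random())) for _ in logits]
    return max(range(len(logits)), key=lambda i: logits[i] + gumbels[i])

logits = [1.0, 0.0, -1.0]
exps = [math.exp(l) for l in logits]
softmax = [e / sum(exps) for e in exps]

random.seed(0)
n = 20000
counts = [0] * len(logits)
for _ in range(n):
    counts[gumbel_max_sample(logits)] += 1
freqs = [c / n for c in counts]  # should approach softmax(logits)
```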

D-X-Y commented 3 years ago

You could have a look at our code: https://github.com/D-X-Y/AutoDL-Projects/blob/main/lib/models/cell_searchs/search_model_gdas.py#L89

buttercutter commented 3 years ago

@D-X-Y Could you comment on this reply on your Gumbel-Max code implementation ?

@Unity05 was suggesting to use softargmax

Unity05 commented 3 years ago

Hi, I was just explaining that the temperature you're using uses the same basic idea as softargmax.

buttercutter commented 3 years ago

@D-X-Y

In your code, would you be able to describe how the logic of hardwts = one_h - probs.detach() + probs is used in the forward search function feature = cell.forward_gdas(feature, hardwts, index) ?

I mean, the computation logic for hardwts looks a bit strange.
Why does hardwts need to make use of both one_h and probs ? Why does one of the probs need detach() ?

Besides, why does the gumbel-max computation need a while loop ? I suppose you are using Gumbel(0, 1) ?

How exactly does gumbel-max transform equation (3) into equation (5) ?

buttercutter commented 3 years ago

For the question on hardwts , see the note section inside https://pytorch.org/docs/stable/nn.functional.html#gumbel-softmax

The main trick for hard is to do y_hard - y_soft.detach() + y_soft

It achieves two things: 
- makes the output value exactly one-hot (since we add then subtract y_soft value) 
- makes the gradient equal to y_soft gradient (since we strip all other gradients)

@D-X-Y By the way, why does PNASNet mention "Note that we learn a single cell type instead of distinguishing between Normal and Reduction cell." ?

buttercutter commented 3 years ago

Solution:

we do not distinguish between Normal and Reduction cells, 
but instead emulate a Reduction cell by using a Normal cell with stride 2

So, in this case, I suppose I could use only a single type of weights for both normal cells and reduction cells ?

As for Algorithm 1, how is A different from W ? Note: the explanation of the notation after equations (3) and (4) of the paper is very confusing to me.

D-X-Y commented 3 years ago

@D-X-Y

In your code, would you be able to describe how the logic of hardwts = one_h - probs.detach() + probs is used in the forward search function feature = cell.forward_gdas(feature, hardwts, index) ?

I mean, the computation logic for hardwts looks a bit strange. Why does hardwts need to make use of both one_h and probs ? Why does one of the probs need detach() ?

Besides, why does the gumbel-max computation need a while loop ? I suppose you are using Gumbel(0, 1) ?

How exactly does gumbel-max transform equation (3) into equation (5) ?

Sorry for the late reply, I'm a little bit busy these days.

hardwts = one_h - probs.detach() + probs aims to make hardwts have the same gradients as probs while still keeping the one-hot values -- one_h. The while loop is a trick I added to avoid very rare cases of NaN.
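The value side of that identity can be checked with plain numbers (pure-Python stand-ins for the tensors; the gradient side is described in the comments, since autograd is not involved here):

```python
probs = [0.1, 0.6, 0.3]   # softmax of the (Gumbel-perturbed) logits
one_h = [0.0, 1.0, 0.0]   # one-hot vector at argmax(probs)

# Forward value: one_h - probs + probs == one_h, so downstream layers see a
# hard one-hot selection ...
hardwts = [h - p + p for h, p in zip(one_h, probs)]

# ... while in the real code the subtracted term is probs.detach(), so
# autograd sees hardwts = constant + probs, i.e. the gradient of hardwts
# w.r.t. the logits equals that of probs (the straight-through estimator).
```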

D-X-Y commented 3 years ago

For the question on hardwts , see the note section inside https://pytorch.org/docs/stable/nn.functional.html#gumbel-softmax

The main trick for hard is to do y_hard - y_soft.detach() + y_soft

It achieves two things: 
- makes the output value exactly one-hot (since we add then subtract y_soft value) 
- makes the gradient equal to y_soft gradient (since we strip all other gradients)

@D-X-Y By the way, why does PNASNet mention "Note that we learn a single cell type instead of distinguishing between Normal and Reduction cell." ?

Yes, I borrowed the idea of how to implement gumbel from PyTorch, with a few modifications.

For PNAS, you may need to email their authors for the detailed reasons.

D-X-Y commented 3 years ago

Solution:

we do not distinguish between Normal and Reduction cells, 
but instead emulate a Reduction cell by using a Normal cell with stride 2

So, in this case, I suppose I could use only a single type of weights for both normal cells and reduction cells ?

As for Algorithm 1, how is A different from W ? Note: the explanation of the notation after equations (3) and (4) of the paper is very confusing to me.

Yes, in this case, the architecture weights for normal cells and reduction cells are shared. A is the architecture weights -- the logits assigned to each candidate operation. W is the weights of the supernet -- the weights of the convolution layers, etc.

buttercutter commented 3 years ago

@D-X-Y I am a bit confused about the difference between cell and node.

Edit: I think I got it now. A single cell contains 4 distinct nodes.

By the way, in Algorithm 1, why does GDAS update W before A ?

D-X-Y commented 3 years ago

@D-X-Y I am a bit confused about the difference between cell and node.

Edit: I think I got it now. A single cell contains 4 distinct nodes.

By the way, in Algorithm 1, why does GDAS update W before A ?

I feel it does not matter? Updating W, A, W, A, W, A or A, W, A, W, A, W would not make a big difference?

buttercutter commented 3 years ago

For GDAS, would https://networkx.org/documentation/stable/tutorial.html#multigraphs be suitable for both forward inference and backward propagation ?

D-X-Y commented 3 years ago

I'm not familiar with networkx and cannot comment on that.

buttercutter commented 3 years ago

@D-X-Y I am confused about how https://github.com/D-X-Y/AutoDL-Projects/blob/main/xautodl/models/cell_searchs/search_model_gdas.py implements multiple parallel connections between nodes.

buttercutter commented 3 years ago

@D-X-Y I am confused about how equation (7) is an approximation of equation (5), as described in the GDAS paper.

D-X-Y commented 3 years ago

@promach The difference between $h$ in Eq.(5) and Eq.(7) is that: if you run Eq.(5) infinitely many times and run Eq.(7) infinitely many times, their average results should be very close.
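A quick Monte-Carlo sketch of this claim (pure Python; the logits and temperature tau are made-up values, not taken from the paper): with a small temperature, averaging hard Eq.(5)-style argmax samples and soft Eq.(7)-style softmax samples over many Gumbel draws gives nearly the same result.

```python
import math
import random

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

logits, tau, n = [1.0, 0.0, -1.0], 0.05, 20000
random.seed(1)

hard_avg = [0.0] * 3  # average of Eq.(5)-style one-hot (argmax) samples
soft_avg = [0.0] * 3  # average of Eq.(7)-style softmax((logits + g) / tau)

for _ in range(n):
    g = [-math.log(-math.log(random.random())) for _ in logits]
    z = [l + gi for l, gi in zip(logits, g)]
    k = max(range(3), key=lambda i: z[i])
    hard_avg[k] += 1.0 / n
    for i, p in enumerate(softmax([zi / tau for zi in z])):
        soft_avg[i] += p / n
```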

buttercutter commented 3 years ago

@D-X-Y In normal backpropagation, there is only a single edge between two nodes.

However, in GDAS there are multiple parallel edges between two nodes.

So, how does one perform backpropagation for GDAS or, more generally, Neural Architecture Search (NAS) ?

buttercutter commented 3 years ago

For https://github.com/D-X-Y/AutoDL-Projects/issues/99#issuecomment-845789377 , how do you actually update both W and A simultaneously in a single epoch ?

Could you point me to the relevant code for the update portion ?
Did you use two def forward() functions for W and A since two disjoint sets are used ?

GDAS algorithm

D-X-Y commented 3 years ago

@promach , at a single iteration, we will first update W and then update A. Please see the codes here: https://github.com/D-X-Y/AutoDL-Projects/blob/main/exps/NAS-Bench-201-algos/GDAS.py#L49
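The per-iteration ordering can be sketched with a toy problem (scalar stand-ins for W and A, and made-up stand-in losses; the real code takes the W step on a training batch and the A step on a validation batch):

```python
# Toy alternating optimization: per iteration, first a W step on the training
# loss, then an A step on the validation loss, as in the linked GDAS.py.
W, A = 5.0, 5.0
lr = 0.1

def train_loss_grad_W(W, A):
    # d/dW of a stand-in training loss (W - 1)^2
    return 2 * (W - 1)

def val_loss_grad_A(W, A):
    # d/dA of a stand-in validation loss (A - 2)^2
    return 2 * (A - 2)

for step in range(200):
    W -= lr * train_loss_grad_W(W, A)   # step 1: update supernet weights W
    A -= lr * val_loss_grad_A(W, A)     # step 2: update architecture logits A
```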

buttercutter commented 3 years ago

If we update W first and only then update A, the question is: should I train the convolution kernel weights W based on the trained best edges for A ?

D-X-Y commented 3 years ago

I feel the order of W and A does not matter. If you look at multiple iterations, it will be W -> A -> W -> A -> W -> A -> W -> A -> .... Whether the first one is W or A would not make a big difference.

buttercutter commented 3 years ago

The issue lingering in my head is that if W is to be optimized FIRST, which exact A result should W be trained under ?

D-X-Y commented 3 years ago

What do you mean by exact A?

buttercutter commented 3 years ago

If W is trained FIRST, which A should the W training process use as W's architecture ?

D-X-Y commented 3 years ago

A is a set of variables indicating the architecture encoding. There is only one A and no other options?

D-X-Y commented 3 years ago

You could have a look at the codes here. Would you mind clarifying what you think the codes should be?

buttercutter commented 3 years ago

Let me rephrase my question: how do you define base_inputs and arch_inputs ?

It seems to be different from what the DARTS paper originally proposed. See equations (5) and (6) of the DARTS paper.

D-X-Y commented 3 years ago

base_inputs are a batch of samples from the training data, arch_inputs are a batch of samples from the validation data.

Yes, following the DARTS paper, I should switch the order of updating W and A.

buttercutter commented 3 years ago

During training of W, should I use a particular architecture found inside that particular epoch ? OR should I use the whole supernet ?

D-X-Y commented 3 years ago

It depends on the NAS algorithm. For DARTS, they use the whole supernet. For GDAS, we use an architecture candidate randomly sampled based on A.

buttercutter commented 3 years ago

For GDAS, we use an architecture candidate randomly sampled based on A.

The candidate is chosen using gumbel-argmax (equations (5) and (6) of the GDAS paper), instead of being chosen randomly. Please correct me if I am wrong.

D-X-Y commented 3 years ago

gumbel-argmax is a kind of random sampling? Because the $o_{k}$ is randomly sampled from Gumbel(0, 1).

buttercutter commented 3 years ago

For https://github.com/D-X-Y/AutoDL-Projects/issues/99#issuecomment-835802887 , there are two types of outputs from the blue node.

One type of output (multiple edges) connects to the inputs of the other blue nodes ?

Another type of output (a single edge) connects directly to the yellow node ?

buttercutter commented 3 years ago

It seems that both ENAS and PNAS just perform add and concat operations for the connection to the output node

buttercutter commented 3 years ago

@D-X-Y I implemented draft code on GDAS.

However, could you advise whether this edge-weight training epoch mechanism will actually work for GDAS ?

D-X-Y commented 3 years ago

For #99 (comment) , there are two types of outputs from the blue node.

One type of output (multiple edges) connects to the inputs of the other blue nodes ?

Another type of output (a single edge) connects directly to the yellow node ?

Yes, you are right~

D-X-Y commented 3 years ago

It seems that both ENAS and PNAS just perform add and concat operations for the connection to the output node

@promach Yes. DARTS also uses add for the intermediate nodes and concat for the final output node (https://github.com/D-X-Y/AutoDL-Projects/blob/main/xautodl/models/cell_searchs/search_cells.py#L251).
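A toy sketch of the two aggregations (plain Python lists standing in for feature maps; the numbers are made up):

```python
# Toy cell aggregation: intermediate nodes SUM their incoming edges, and the
# cell output CONCATenates the intermediate node states, as in DARTS.
def node_sum(edge_outputs):
    # elementwise add over all incoming edges of one intermediate node
    return [sum(vals) for vals in zip(*edge_outputs)]

def cell_output(node_states):
    # channel-wise concat of all intermediate node states
    return [c for state in node_states for c in state]

n0 = node_sum([[1, 2], [3, 4]])          # two incoming edges -> [4, 6]
n1 = node_sum([[1, 1], [0, 2], [2, 0]])  # three incoming edges -> [3, 3]
out = cell_output([n0, n1])              # concat -> [4, 6, 3, 3]
```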

D-X-Y commented 3 years ago

@D-X-Y I implemented draft code on GDAS.

However, could you advise whether this edge-weight training epoch mechanism will actually work for GDAS ?

I personally feel the implementation is incorrect. I haven't fully checked the code, but at least the input for every cell/node should not be the same forward_edge(train_inputs).

buttercutter commented 3 years ago

How do I code the forward pass function correctly for edge weight training ?

    # self-defined initial NAS architecture, for supernet architecture edge weight training
    def forward_edge(self, x):
        self.__freeze_f()
        self.__unfreeze_w()

        return self.weights

Note: This is for training step 2 inside Algorithm 1 of DARTS paper

D-X-Y commented 3 years ago

Why do we return self.weights instead of returning the result of applying the weights to x? The logic of freeze and unfreeze is correct, but I do not understand the return ...

buttercutter commented 3 years ago

I am not sure how to train the edge weights, hence the question about def forward_edge().

Besides, I also suspect that the forward pass function for the architecture weights (step 1 inside DARTS Algorithm 1) might be incorrect as well, because it only trains the neural network functions' internal weight parameters instead of the architecture weights.

Note: self.f(x) is something like nn.Linear() or nn.Conv2d()

    # for NN functions internal weights training
    def forward_f(self, x):
        self.__unfreeze_f()
        self.__freeze_w()

        # inheritance in python classes and SOLID principles
        # https://en.wikipedia.org/wiki/SOLID
        # https://blog.cleancoder.com/uncle-bob/2020/10/18/Solid-Relevance.html
        return self.f(x)

buttercutter commented 3 years ago

Sorry, I misinterpreted the purpose of the two forward pass functions.

forward_edge() is for architecture weights (step 1), while forward_f() is for NN function's internal weights (step 2).

However, I am still not sure how to code for def forward_edge(self, x)

buttercutter commented 3 years ago

@D-X-Y For an ordinary NN training operation, we have some feature map outputs.

However, for the edge weight (NAS) training operation, there are no feature map outputs. So, what should be fed into x for forward_edge(x) ?
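One possible answer, as a hedged sketch (the names below are illustrative, not from the repository): x is simply the ordinary feature map arriving from the edge's source node, and forward_edge applies every candidate op to that same x, mixing the results by the learnable edge weights, so gradients reach the edge weights through the task loss:

```python
# Sketch of a mixed edge: the input x IS the ordinary feature map coming from
# the edge's source node; the learnable edge weights only decide how the
# candidate ops' outputs are combined. Scalars stand in for real tensors.
candidate_ops = [
    lambda x: x,            # identity / skip-connect
    lambda x: 0.0,          # the "none" op
    lambda x: 2.0 * x,      # stand-in for a conv/linear op
]
edge_weights = [0.5, 0.1, 0.4]   # softmax(architecture logits) for this edge

def forward_edge(x):
    # Weighted sum of every candidate op applied to the SAME feature map x;
    # during step 1, gradients w.r.t. edge_weights flow through this sum.
    return sum(w * op(x) for w, op in zip(edge_weights, candidate_ops))
```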

buttercutter commented 2 years ago

Is using nn.Linear() to train edge weights feasible for GDAS on a small GPU ?

    # self-defined initial NAS architecture, for supernet architecture edge weight training
    def forward_edge(self, x):
        self.__freeze_f()
        self.__unfreeze_w()

        return self.linear(x)