khanrc / pt.darts

PyTorch Implementation of DARTS: Differentiable Architecture Search
MIT License

Missing Softmax for Genotype selection #16

Open ghost opened 5 years ago

ghost commented 5 years ago

Hi @khanrc,

first of all, thanks for the recent and understandable implementation of DARTS. I really enjoyed using this code.

After some exploration and alteration of the code, I stumbled upon the implementation of the parse(alpha, k, reduction=True) function, where the most heavily weighted operations are selected. While wondering why the "none" operation is omitted, I compared the code to the original implementation. The function there does not compare the raw weights but rather the softmaxed ones. https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L154

This seems to be quite crucial, since the softmax puts more focus on the "none" operation. (For further explanation, see this comment in the OpenReview discussion of the paper.)

Maybe you need to calculate the row-wise softmax of the alpha values before calling the function, or compute it inside the function before searching for the top-k. Please correct me if I'm wrong.
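
If it helps, here is a minimal sketch of what I mean, assuming alpha is stored as a list of per-node weight tensors (the helper name softmax_alpha is just illustrative, not part of your code):

import torch
import torch.nn.functional as F

# illustrative helper: apply a row-wise softmax to each alpha tensor
# before the top-k selection, so edges are compared on normalized
# weights as in the original DARTS implementation
def softmax_alpha(alpha):
    # alpha is assumed to be a list of tensors, one per intermediate
    # node, each of shape (n_input_edges, n_ops)
    return [F.softmax(a, dim=-1) for a in alpha]

# e.g. gene_normal = parse(softmax_alpha(alpha_normal), k=2)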

Thanks in advance, Lucas

ghost commented 5 years ago

I constructed a small example of a case where the connected nodes differ if you apply the softmax:

import torch
import torch.nn as nn

# construct softmax operation
softi = nn.Softmax(-1)

# List of example operations to choose from
ops = ['Conv', 'Pool', 'None']

# introduce a tensor (corresponds to the alpha weights)
# The tensor has a row for each possible input hidden state
# and a column for every possible operation.
# The last weight is by definition for the 'none' operation
normal = torch.tensor([[0.31, 0.28, 0.44], [0.30, 0.3, 0.4], [0.19, 0.51, 0.3]])

# softmax the introduced tensor
softmaxed = softi(normal)

# print the tensors
print('Normal weight tensor:')
print(normal)
print('Softmaxed weight tensor:')
print(softmaxed)
print('-----------------------------------------\n')

# find the best two edges of the normal tensor (omit the 'none' operation)
normal_max_edgewise_values, normal_max_edgewise_indices = torch.topk(normal[:, :-1], 1)
normal_max_inputwise_values, normal_max_inputwise_indices = torch.topk(normal_max_edgewise_values.view(-1), 2)

# find the best two edges of the softmaxed tensor (omit the 'none' operation)
softmaxed_max_edgewise_values, softmaxed_max_edgewise_indices = torch.topk(softmaxed[:, :-1], 1)
softmaxed_max_inputwise_values, softmaxed_max_inputwise_indices = torch.topk(softmaxed_max_edgewise_values.view(-1), 2)

# print the elements each method chose
print('From the normal weight tensor we chose operations: ')
for i in normal_max_inputwise_indices:
    print(f'\t[{i}, {normal_max_edgewise_indices[i].item()}]')
    print('\tThis corresponds to: ' + ops[normal_max_edgewise_indices[i].item()] + f' on hidden state {i}.')
print('\nFrom the softmaxed weight tensor we chose operations: ')
for i in softmaxed_max_inputwise_indices:
    print(f'\t[{i}, {softmaxed_max_edgewise_indices[i].item()}]')
    print('\tThis corresponds to: ' + ops[softmaxed_max_edgewise_indices[i].item()] + f' on hidden state {i}.')

# as one can see from the output, the selected hidden states differ between the two methods

The output of this is:

Normal weight tensor:
tensor([[0.3100, 0.2800, 0.4400],
        [0.3000, 0.3000, 0.4000],
        [0.1900, 0.5100, 0.3000]])
Softmaxed weight tensor:
tensor([[0.3216, 0.3121, 0.3663],
        [0.3220, 0.3220, 0.3559],
        [0.2863, 0.3942, 0.3195]])
-----------------------------------------

From the normal weight tensor we chose operations: 
    [2, 1]
    This corresponds to: Pool on hidden state 2.
    [0, 0]
    This corresponds to: Conv on hidden state 0.

From the softmaxed weight tensor we chose operations: 
    [2, 1]
    This corresponds to: Pool on hidden state 2.
    [1, 1]
    This corresponds to: Pool on hidden state 1.

Process finished with exit code 0

As you can see from the output, the selection process selects two different operations on two different previous hidden states. Due to the increased weight of the 'none' operation in the last column of row 0 (i.e. hidden state 0) and the softmax, node 0 isn't connected anymore.

This mainly matters in cases where the rows of alpha don't sum to one, but as far as I can tell there is no restriction forcing the optimization to satisfy this constraint.
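
To make the suggested fix concrete, here is a self-contained sketch of the selection for a single node with the softmax applied inside; the function and variable names are illustrative, not the actual pt.darts code:

import torch
import torch.nn.functional as F

ops = ['Conv', 'Pool', 'None']  # 'None' is again assumed to be the last entry

def select_edges(alpha_node, k=2):
    # row-wise softmax first, as in the original DARTS implementation
    weights = F.softmax(alpha_node, dim=-1)
    # best operation per input edge, excluding the 'none' column
    best_vals, best_ops = torch.topk(weights[:, :-1], 1)
    # keep the k input edges with the strongest best operation
    _, best_edges = torch.topk(best_vals.view(-1), k)
    return [(ops[best_ops[e].item()], e.item()) for e in best_edges]

alpha_node = torch.tensor([[0.31, 0.28, 0.44],
                           [0.30, 0.30, 0.40],
                           [0.19, 0.51, 0.30]])
# keeps input edges 2 and 1 for these example weights
print(select_edges(alpha_node))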

khanrc commented 5 years ago

Thanks for pointing this out! I agree with you. I don't think this has a big impact on the final performance, but it is clearly a logical bug. I will fix it soon. Thanks again for finding this bug :)

ghost commented 5 years ago

I really appreciate your fast reply. I don't think it has a big impact on performance either, because my example is obviously an absolute borderline case. But I wanted to check with you, since I was not sure I got it right :)