VITA-Group / TENAS

[ICLR 2021] "Neural Architecture Search on ImageNet in Four GPU Hours: A Theoretically Inspired Perspective" by Wuyang Chen, Xinyu Gong, Zhangyang Wang
MIT License

NTK calculation incorrect for networks with multiple outputs? #11

Closed awe2 closed 3 years ago

awe2 commented 3 years ago

Howdy!

In: https://github.com/VITA-Group/TENAS/blob/main/lib/procedures/ntk.py

on line 45:

logit[_idx:_idx+1].backward(torch.ones_like(logit[_idx:_idx+1]), retain_graph=True)

I am confused about your calculation of the NTK, and believe that you may be misusing the first argument of the torch.Tensor.backward() function.

E.g., when playing with the codebase using a very small 8-parameter network with 2 outputs:

class small(torch.nn.Module):
    """Two stacked 2x2 linear layers: 8 weights total, 2 outputs."""
    def __init__(self):
        super(small, self).__init__()
        self.d1 = torch.nn.Linear(2, 2, bias=False)
        self.d2 = torch.nn.Linear(2, 2, bias=False)

    def forward(self, x):
        x = self.d1(x)
        x = self.d2(x)
        return x

For this explanation I have modified the backward call to:

gradient = torch.ones_like(logit[_idx:_idx+1])
gradient[0,0] = a
gradient[0,1] = b
logit[_idx:_idx+1].backward(gradient, retain_graph=True)

where by J I mean your 'grads' list for a single network:

e.g.: lines 45 & 46:

grads = [torch.stack(_grads, 0) for _grads in grads]
ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in grads]
print('J: ',grads)

for

gradient[0,0] = 0
gradient[0,1] = 1

J: [tensor([[-0.6255, -0.5019, 0.1758, 0.1411, 0.0000, 0.0000, -0.0727, -0.4643], [ 0.9368, -0.0947, -0.2633, 0.0266, 0.0000, 0.0000, 0.0955, -0.0812]])]

=======

for

gradient[0,0] = 1
gradient[0,1] = 0

J: [tensor([[ 0.1540, 0.1236, -0.6473, -0.5194, -0.0727, -0.4643, 0.0000, 0.0000], [-0.2307, 0.0233, 0.9694, -0.0980, 0.0955, -0.0812, 0.0000, 0.0000]])]

=======

for

gradient[0,0] = 1
gradient[0,1] = 1

J: [tensor([[-0.4715, -0.3783, -0.4715, -0.3783, -0.0727, -0.4643, -0.0727, -0.4643], [ 0.7061, -0.0714, 0.7062, -0.0714, 0.0955, -0.0812, 0.0955, -0.0812]])]

"""

So you can verify that your code adds the two per-output gradient rows together to get the last result.
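This is just the usual vector-Jacobian-product behaviour of backward()'s first argument; here is a minimal standalone sketch (my own toy example, not your repo's code) showing the same effect:

import torch

# For a non-scalar output y, y.backward(v) accumulates the vector-Jacobian
# product v^T J into .grad, not a per-output Jacobian.
x = torch.tensor([[1.0, 2.0]])
lin = torch.nn.Linear(2, 2, bias=False)

lin(x).backward(torch.ones(1, 2))             # v = [1, 1]: the two output rows of J are summed
print(lin.weight.grad)                        # tensor([[1., 2.], [1., 2.]])

lin.weight.grad = None
lin(x).backward(torch.tensor([[1.0, 0.0]]))   # v = [1, 0]: only output 0 contributes
print(lin.weight.grad)                        # tensor([[1., 2.], [0., 0.]])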

The problem is that the Jacobian should have shape number_samples x (number_outputs x number_weights); see your own paper, page 2, where the Jacobian's components are indexed by the subscript i, the i-th output of the model.

If I am right, then any network with multiple outputs would have its NTK values incorrectly calculated, and its time and memory footprint would be systematically reduced because these per-output gradients are pooled together.
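As a hedged sketch of what I mean (my own code, not a patch against ntk.py), backpropagating each output dimension separately gives a Jacobian of shape [num_samples * num_outputs, num_weights] and an NTK of shape [N*C, N*C]:

import torch

def per_output_ntk(network, inputs):
    # Sketch only: one Jacobian row per (sample, output-dim) pair.
    params = tuple(network.parameters())
    rows = []
    for x in inputs:
        logit = network(x.unsqueeze(0))            # shape [1, C]
        for c in range(logit.shape[1]):
            grads = torch.autograd.grad(logit[0, c], params, retain_graph=True)
            rows.append(torch.cat([g.flatten() for g in grads]))
    J = torch.stack(rows, 0)                       # [N*C, num_weights]
    return J @ J.t()                               # [N*C, N*C]

net = torch.nn.Sequential(torch.nn.Linear(2, 2, bias=False),
                          torch.nn.Linear(2, 2, bias=False))
print(per_output_ntk(net, torch.randn(3, 2)).shape)   # torch.Size([6, 6])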

chenwydj commented 3 years ago

Hi @awe2,

Thanks for your question and interest in our work!

Yes, I understand your concern, and you are right: what I am doing is equivalent to summing over the output dimensions of the logit. In other words, I treat the output of the network as the sum of the logits, instead of as a multi-dimensional output. This is much faster than back-propagating through each output dimension separately, and it also works well. It is definitely possible to expand the NTK into [num_samples * num_out_dim] x [num_samples * num_out_dim].
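The equivalence is easy to check in isolation (a minimal sketch, not the repo code):

import torch

# Backpropagating a ones vector through the logits gives the same parameter
# gradients as differentiating logit.sum().
net = torch.nn.Linear(2, 2, bias=False)
x = torch.randn(1, 2)

net(x).backward(torch.ones(1, 2))
g_ones = net.weight.grad.clone()

net.weight.grad = None
net(x).sum().backward()
print(torch.allclose(g_ones, net.weight.grad))   # True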

Hope that helps!

awe2 commented 3 years ago

It is certainly faster, but I can calculate the NTK by hand for my simple network, and the results aren't the same. I'm not a math whiz; do you expect this transformation to leave the value of the condition number unchanged?
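For concreteness, here is the kind of comparison I have in mind (a rough, hedged sketch of my own, using lambda_max / lambda_min of each kernel as the condition number):

import torch

def flat_grad(scalar, params):
    # Flattened gradient of a scalar w.r.t. all parameters.
    grads = torch.autograd.grad(scalar, params, retain_graph=True)
    return torch.cat([g.flatten() for g in grads])

def condition_number(kernel):
    eig = torch.linalg.eigvalsh(kernel)            # eigenvalues in ascending order
    return (eig[-1] / eig[0]).item()

net = torch.nn.Sequential(torch.nn.Linear(2, 8, bias=False), torch.nn.ReLU(),
                          torch.nn.Linear(8, 2, bias=False))
params = tuple(net.parameters())
inputs = torch.randn(4, 2)

sum_rows, full_rows = [], []
for x in inputs:
    logit = net(x.unsqueeze(0))                                  # shape [1, C]
    sum_rows.append(flat_grad(logit.sum(), params))              # summed-logit gradient
    full_rows += [flat_grad(logit[0, c], params) for c in range(logit.shape[1])]

J_sum, J_full = torch.stack(sum_rows), torch.stack(full_rows)
print(condition_number(J_sum @ J_sum.t()))                       # [N, N] kernel
print(condition_number(J_full @ J_full.t()))                     # [N*C, N*C] kernel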

awe2 commented 3 years ago

If the output of the network is the sum of the logits instead of the logits themselves, would the neural architectures you are searching over have the same response? I.e., I see immediate application value in searching over architectures for image classification, where we are (naively) concerned with networks that output logits, but I don't see the value of a search over networks that predict the sum of the logits.

I'm relatively new to the field; is there something I am overlooking?

chenwydj commented 3 years ago

Thanks for your questions.

  1. I am not saying that "sum up the logits then backpropagate" and "backpropagate each output dim" will give you the same condition number; I strongly suspect they will be different. The core difference between the two approaches in the NAS setting (i.e., ranking architectures) is which one gives the better correlation: treating the sum of the logits as the function's single 1-D output, or treating the neural network as a function with multi-dimensional output. Our paper does not answer that question; we only show that the NTK of networks treated as having a 1-D output correlates strongly with their classification accuracy.
  2. Different architectures will give different responses even if I treat them as 1-D-output functions, as demonstrated in Fig. 1 and our experiments. Again, I agree with you that calculating the NTK for multi-dimensional outputs is doable.

awe2 commented 3 years ago

Ah, perfect, I think I understand. Thanks!

j0hngou commented 1 year ago

According to this paper, the pseudo-NTK (sum of logits) converges to the true empirical NTK at initialization for any network with a wide enough final layer.