mutiann closed this issue 1 month ago.
Hi!
Thank you so much for the kind words and for the questions! Hoping to be as informative as possible, I'll gladly reply in order:
The goal is to prune the network weights that contribute the least to the eigenspectrum of the Neural Tangent Kernel (NTK), so that the resulting smaller sub-network retains the trainability and performance of the original dense network. Computing the full NTK and diagonalizing it to obtain the eigenspectrum is not feasible for real-world networks and datasets, as the memory requirements scale quadratically with the number of data points and parameters. Instead, we can use the trace of the NTK as a proxy for its eigenspectrum, since the trace is the sum of the eigenvalues: $\mathrm{Tr}[\Theta(x,x)] = \sum_{i} \lambda_i$, where $\lambda_i$ are the eigenvalues of the NTK. Specifically, we use an upper bound of the trace as a cheap proxy for retaining the eigenspectrum of the NTK. The intuition is that the more weights you remove, the more this upper bound decreases; by removing only the weights that decrease the upper bound the least, we preserve the eigenspectrum of the NTK to a greater extent. However, the scale of the eigenspectrum is not guaranteed to be preserved: the more weights you remove, the more the eigenspectrum is altered, regardless of the specific pruning strategy.

As for the second part of the question, the motivation is given in the SNIP paper (Sec. 4.1, Equations 4 and 5). The idea is to approximate, with a derivative, the effect that masking one weight $\theta_j$ at a time has on a saliency function $R$. This derivative is taken with respect to the mask $c_j$, which is equivalent to taking it with respect to $\theta_j$ and then multiplying the gradient by $\theta_j$ (i.e., $\frac{\partial R}{\partial c_j} = \frac{\partial R}{\partial \theta_j} \odot \theta_j$). Since we work with $\theta_j^2$, the saliency score becomes $\frac{\partial R}{\partial \theta_j^2} \odot \theta_j^2$.
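As a hedged illustration of why the trace is the cheap part of the spectrum, here is a stdlib-only Python sketch on a hypothetical 2-layer linear network with scalar output (the network, weights, inputs, and function names are made up for illustration and are not the paper's code). It builds the empirical NTK $\Theta = JJ^\top$ explicitly and checks that its trace equals $\|J\|_F^2 = \sum_n \|\nabla_\theta f(x_n)\|^2$, which needs only one per-sample gradient at a time and neither the $N \times N$ kernel nor an eigendecomposition.

```python
# Hypothetical toy model (not the paper's code): a 2-layer linear network with scalar output
#   f(x) = sum_i w2[i] * sum_k w1[i][k] * x[k]
# Parameters are flattened as [w1[0][0], w1[0][1], w1[1][0], w1[1][1], w2[0], w2[1]].


def jacobian_row(w1, w2, x):
    """Gradient of f(x) with respect to all parameters, in the flattened order above."""
    hidden = [sum(w1[i][k] * x[k] for k in range(len(x))) for i in range(len(w2))]
    grad_w1 = [w2[i] * x[k] for i in range(len(w2)) for k in range(len(x))]
    grad_w2 = hidden  # df/dw2[i] is the pre-activation of hidden unit i
    return grad_w1 + grad_w2


def ntk_matrix(w1, w2, xs):
    """Explicit empirical NTK: Theta[n][m] = <grad f(x_n), grad f(x_m)>, an N x N matrix."""
    rows = [jacobian_row(w1, w2, x) for x in xs]
    return [[sum(a * b for a, b in zip(r, s)) for s in rows] for r in rows]


def ntk_trace_cheap(w1, w2, xs):
    """Tr[Theta] = ||J||_F^2 = sum_n ||grad f(x_n)||^2, without building the N x N kernel."""
    return sum(sum(g * g for g in jacobian_row(w1, w2, x)) for x in xs)


w1 = [[0.5, -1.0], [2.0, 0.25]]
w2 = [1.5, -0.5]
xs = [[1.0, 2.0], [-1.0, 0.5], [0.0, 3.0]]

ntk = ntk_matrix(w1, w2, xs)
full_trace = sum(ntk[n][n] for n in range(len(xs)))
print(full_trace, ntk_trace_cheap(w1, w2, xs))  # both 60.703125
```

The explicit kernel is only built here to check the identity; in practice only the per-sample gradient norms would be accumulated.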
I'm not quite sure why you would exponentiate the weights; my guess is that you are trying to connect the concept of "saliency" to a more probabilistic, softmax-like definition. For more details on what "saliency" means in Pruning-at-Initialization, see the SynFlow paper.
Indeed! The two forms are equal, $$\sum_{p=1}^P \sum_{j=1}^m \left(\frac{v_p(\theta)}{\theta_j}\right)^2 = \sum_{p=1}^P \sum_{j \in p} \left(\frac{v_p(\theta)}{\theta_j}\right)^2,$$ since the terms for which $p_j \neq 1$ are zero. Since later in the derivation we combine the two terms of the upper bound, summing over all $j$ is just a matter of preference that eases the notation a bit.
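A small numerical check of this equality, on a hypothetical toy network whose four input-output paths are listed as tuples of flat weight indices (the weights, paths, and helper names are illustrative assumptions, not the paper's setup). The finite-difference Jacobian entries $(J^v_\theta)_{p,j} = \partial v_p / \partial \theta_j$ are exactly zero for weights not on path $p$, so the sum over all $j$ matches the sum over $j \in p$ computed with $\partial v_p / \partial \theta_j = v_p / \theta_j$.

```python
# Hypothetical toy: flat weights and the input-output paths of a tiny 2-layer linear network.
THETA = [0.5, -1.0, 2.0, 0.25, 1.5, -0.5]  # w1 (4 weights) followed by w2 (2 weights)
PATHS = [(0, 4), (1, 4), (2, 5), (3, 5)]   # each path: (first-layer weight, second-layer weight)


def path_value(theta, p):
    """v_p = product of the weights on path p."""
    v = 1.0
    for k in p:
        v *= theta[k]
    return v


def jacobian_entry(theta, p, j, h=1e-4):
    """(J^v_theta)_{p,j} = d v_p / d theta_j by central finite differences."""
    plus = list(theta)
    minus = list(theta)
    plus[j] += h
    minus[j] -= h
    return (path_value(plus, p) - path_value(minus, p)) / (2 * h)


# Sum over all m weights of the squared Jacobian entries (off-path entries are zero) ...
sum_all = sum(jacobian_entry(THETA, p, j) ** 2 for p in PATHS for j in range(len(THETA)))
# ... versus the sum restricted to j in p, using d v_p / d theta_j = v_p / theta_j (theta_j != 0).
sum_on_path = sum((path_value(THETA, p) / THETA[j]) ** 2 for p in PATHS for j in p)
print(sum_all, sum_on_path)
```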
Hope this helps. Feel free to reach out at any time!
I see. Thanks for the kind explanation!
Regarding the first question, I'm sorry, but there is still something I'm not sure about. As shown in the second equation on p. 5, $$\sum_{j=1}^m \frac{\partial R}{\partial \theta_j^2} = \|J^f_v(X)\|_F^2 \, \|J^v_\theta\|_F^2$$ (let's denote it as $R'$), which is the upper bound of the trace in Eq. 6. So this derivative sum $R'$ is the actual quantity we want to preserve during pruning. If I understand correctly, why is $R'$ not the actual saliency function $R$ here, and why don't we prune according to $\frac{\partial R'}{\partial \theta_j}$ instead? Thanks in advance!
No problem, thank you for your question! Let me clarify: in the derivation on page 5, we indeed sum over all weights ($j=1$ to $m$) to obtain the upper bound $R'$. However, when scoring each individual weight for pruning, using only the derivative with respect to the weight (i.e., $\frac{\partial R'}{\partial \theta_j^2}$) would not be correct. This is because weights have different magnitudes, which would skew the saliency score $S$. If we instead computed the saliency score with respect to a multiplicative mask $c$ (a vector of all 1s), there would be no scale issue, since all mask entries have the same value, i.e., $\frac{\partial R}{\partial c_j} = \frac{\partial R}{\partial \theta_j} \odot \theta_j$.
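To make the mask argument concrete, here is a sketch with a hypothetical toy saliency function $R = \sum_p v_p^2$ on a tiny path-structured network (this is not the paper's actual $R$; weights, paths, and helper names are illustrative). In this sketch the mask multiplies $\theta_j^2$ directly (an assumption), and the finite-difference derivative with respect to $c_j$ at $c = 1$ matches $\frac{\partial R}{\partial \theta_j^2}\,\theta_j^2$ for every weight.

```python
# Hypothetical toy (not the paper's R): flat weights and the input-output paths of a tiny
# 2-layer linear network, each path given as a tuple of weight indices.
THETA = [0.5, -1.0, 2.0, 0.25, 1.5, -0.5]  # w1 (4 weights) followed by w2 (2 weights)
PATHS = [(0, 4), (1, 4), (2, 5), (3, 5)]


def r_of_squares(u, mask):
    """Toy R = sum_p prod_{k in p} mask_k * u_k with u_k = theta_k^2 (R = sum_p v_p^2 at mask = 1)."""
    total = 0.0
    for p in PATHS:
        prod = 1.0
        for k in p:
            prod *= mask[k] * u[k]
        total += prod
    return total


def grad_wrt_squares(u, j):
    """Analytic dR/du_j: sum over paths through weight j of the product of the other u_k."""
    total = 0.0
    for p in PATHS:
        if j in p:
            prod = 1.0
            for k in p:
                if k != j:
                    prod *= u[k]
            total += prod
    return total


def mask_saliency(u, j, h=1e-3):
    """dR/dc_j at c = 1 by central differences (exact up to rounding, since R is linear in c_j)."""
    plus = [1.0] * len(u)
    minus = [1.0] * len(u)
    plus[j] += h
    minus[j] -= h
    return (r_of_squares(u, plus) - r_of_squares(u, minus)) / (2 * h)


u = [t * t for t in THETA]
for j in range(len(THETA)):
    via_mask = mask_saliency(u, j)
    via_weights = grad_wrt_squares(u, j) * u[j]  # dR/d(theta_j^2) * theta_j^2
    print(j, round(via_mask, 6), round(via_weights, 6))
```

If the mask multiplies $\theta_j$ instead, then $(c_j\theta_j)^2 = c_j^2\theta_j^2$ and the derivative at $c = 1$ picks up a constant factor of 2, which does not change the ranking of the weights.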
I hope it is a bit clearer now.
Thanks for the kind explanation! I see that we should consider the effect of the weights' magnitudes, but I still have some concerns, as I suspect that $\partial R/\partial\theta_j^2$ already accounts for the weight magnitude. I still fail to understand how $S_{PX}$ is derived, and I'm not sure I expressed my question clearly.
To clarify the background: when using a saliency score, the goal is to preserve some function $R$, as described in the SynFlow paper. To avoid confusion, let's denote this SynFlow $R$ as $F$ instead, since the symbol $R$ is already used in the PX paper. Then $\Delta F := F(\theta_j=\theta_j)-F(\theta_j=0)=\frac{\partial F}{\partial \theta_j}\cdot\theta_j+\mathcal{O}(\theta_j^2)$. Thus, we use $S(\theta) = \frac{\partial F}{\partial \theta_j} \cdot \theta_j$ as the score, i.e., the first-order approximation of the change in $F$ when setting $\theta_j=0$.
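A quick numerical illustration of this first-order approximation, using a hypothetical smooth function $F$ of a single weight $t = \theta_j$ (chosen only for illustration, unrelated to any specific pruning criterion). The gap between the score $\frac{\partial F}{\partial \theta_j}\theta_j$ and the exact $\Delta F$, divided by $\theta_j^2$, stays bounded and approaches $F''(0)/2$ as $\theta_j \to 0$, i.e., the remainder is quadratic in $\theta_j$.

```python
import math


def F(t):
    """Hypothetical smooth quantity to preserve, as a function of a single weight t = theta_j."""
    return math.exp(t) + 0.5 * t ** 3


def dF(t):
    """Analytic derivative of F."""
    return math.exp(t) + 1.5 * t ** 2


def remainder_ratio(t):
    """(first-order score - exact Delta F) / t^2, where Delta F = F(t) - F(0)."""
    exact = F(t) - F(0.0)
    score = dF(t) * t
    return (score - exact) / t ** 2


# The ratio shrinks toward F''(0) / 2 = 0.5 as t -> 0.
for t in [0.4, 0.2, 0.1, 0.05]:
    print(t, F(t) - F(0.0), dF(t) * t, round(remainder_ratio(t), 4))
```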
There are many choices of $F$, e.g., the loss function $L$ (in SNIP) or $\Delta L$ (in GraSP). In the PX paper, if I understood correctly, the goal is to preserve the upper bound of the NTK trace, i.e., $F = \sum_j \partial R/\partial\theta_j^2$ (denoted $R'$ in the comments above). Then we can directly obtain the difference caused by removing $\theta_j$, i.e., the term involving $\theta_j$: $\Delta F = F(\theta_j=\theta_j)-F(\theta_j=0)=\frac{\partial R}{\partial \theta_j^2}$. Alternatively, we can still apply a first-order approximation to $\Delta F$ in the same way, obtaining the score $S(\theta) = \frac{\partial F}{\partial\theta_j}\cdot\theta_j=\frac{\partial^2R}{\partial \theta_j \partial\theta_j^2}\cdot\theta_j$. Both of them take the magnitude of $\theta_j$ into account. However, neither of them equals $S_{PX} = \frac{\partial R}{\partial\theta_j^2}\cdot\theta_j^2$, even though $S_{PX} = \Delta F \cdot \theta_j^2$.
Alternatively, we can compute the scores w.r.t. a multiplicative mask $m$. Let $R(\theta \odot m)$ denote the function $R$ with $\theta$ masked by $m$. Then $\frac{\partial R(\theta \odot m)}{\partial\theta_j^2}=\frac{\partial R(\theta \odot m)}{\partial(m_j\theta_j)^2}\frac{\partial (m_j\theta_j)^2}{\partial\theta_j^2}=m_j\cdot\frac{\partial R(\theta)}{\partial\theta_j^2}$ and $F(\theta \odot m) =\sum_j\frac{\partial R(\theta \odot m)}{\partial\theta_j^2}=\sum_{\{j : m_j=1\}} \frac{\partial R}{\partial\theta_j^2}$, so switching $m_j$ from $1$ to $0$ results in a difference of $\Delta F=\frac{\partial R}{\partial \theta_j^2}$, the same as above.
To sum up, the gap between the equations before Eq. 9 and the definition of $S_{PX}$ in Eq. 9 confuses me, and it would be great if it could be clarified further. Thanks in advance!
Hi again!
Sorry for the initial confusion; now I clearly understand your concern, and thank you for taking the time to explain it so thoroughly.
Considering the network as a graph made of input-output paths, the objective is to assign scores that reflect the contribution of each specific weight based on all the paths it belongs to. Specifically, given a weight $\theta_{ij}$ in some layer $h$, connecting the $j$-th neuron of the previous layer ($h-1$) to the $i$-th neuron of the following layer ($h+1$), as sketched in the supplementary (Sec. B), we can rewrite the derivation of $S_{PX}$ as:
This means that the score assigned to $\theta_{ij}$ would not properly reflect the path values $v_p^2$, as it lacks the factor $\theta_{ij}^2$, which we reintroduce through the multiplication in $S_{PX}$. Alternatively, if we mask $\theta_{ij}$ with a 0-1 mask $m_{ij}$ and differentiate with respect to the mask, we don't need to multiply by $\theta_{ij}^2$.
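As a hedged sketch of this argument, assume $R$ is a weighted sum of squared path values, $R = \sum_p g_p v_p^2$ with $v_p = \prod_{\theta_k \in p} \theta_k$ and coefficients $g_p$ that do not depend on $\theta_{ij}$ (an illustrative assumption, not necessarily the exact form of $R$ in the paper). For $\theta_{ij} \neq 0$:

$$\frac{\partial R}{\partial \theta_{ij}^2} = \sum_{p \ni \theta_{ij}} g_p \prod_{\theta_k \in p,\ \theta_k \neq \theta_{ij}} \theta_k^2 = \sum_{p \ni \theta_{ij}} g_p \frac{v_p^2}{\theta_{ij}^2}, \qquad \theta_{ij}^2 \, \frac{\partial R}{\partial \theta_{ij}^2} = \sum_{p \ni \theta_{ij}} g_p \, v_p^2.$$

Under this assumption, the plain derivative weights each path through $\theta_{ij}$ by $v_p^2/\theta_{ij}^2$, while multiplying by $\theta_{ij}^2$ recovers the full squared path values $v_p^2$ of the paths that $\theta_{ij}$ belongs to.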
I hope this is clearer now.
Sorry for the late reply and thank you so much for the explanation!
Thank you so much for this wonderful paper; it is really inspiring and valuable to me! While reading it, I encountered a few issues that confuse me a bit, and I wonder if they could be clarified.
It seems that the goal is to preserve the large eigenvalues of the NTK after pruning by preserving its trace, which in turn is done by preserving the upper bound of the trace given in Eq. 6. To my understanding, the scale of the eigenvalues is actually not guaranteed to be preserved. Furthermore, each parameter $\theta_j$ contributes to the upper bound by $\frac{\partial R(x, \theta, a)}{\partial \theta^2_j}$, so the parameters with the smallest contribution should be pruned first. However, the actual saliency score (Eq. 9) further multiplies this contribution by $\theta^2_j$, which confuses me, and I'm curious about the reason. In fact, if the purpose is to obtain a standard saliency-like form, we could consider $$\frac{\partial R}{\partial \theta^2} = \frac{\partial R}{\partial \exp (\theta^2)} \odot \exp (\theta^2).$$
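A quick numerical sanity check of this chain-rule identity, for a hypothetical scalar $R$ written as a function of $s = \theta^2$ (the specific $R$ and the helper names are made up for illustration). Differentiating with respect to $\theta^2$ directly, or with respect to $\exp(\theta^2)$ and then multiplying by $\exp(\theta^2)$, gives the same value.

```python
import math


def R_of_s(s):
    """Hypothetical R as a function of s = theta^2."""
    return s ** 2 + 3.0 * s


def R_of_w(w):
    """The same R rewritten in terms of w = exp(s), i.e., s = log(w)."""
    return R_of_s(math.log(w))


def central_diff(f, x, h=1e-6):
    """Central finite-difference approximation of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)


theta = 0.7
s = theta ** 2
lhs = central_diff(R_of_s, s)                           # dR / d(theta^2)
rhs = central_diff(R_of_w, math.exp(s)) * math.exp(s)   # dR / d exp(theta^2) * exp(theta^2)
print(lhs, rhs)  # both close to 2 * 0.49 + 3 = 3.98
```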
I suspect there is an issue with Eq. 8 (from [16]). In [16], $(J^v_\theta)_{p,j}$ is non-zero only if $p_j=1$. Hence, I think that in Eq. 8 the sum should be $\sum_{j \in p}$ instead of $\sum_{j=1}^m$. This also affects the later equations, but fortunately it does not affect the correctness of the computation, since the gradient of $h$ used to compute this term should go to zero in this case.