iurada / px-ntk-pruning

Official repository of our work "Finding Lottery Tickets in Vision Models via Data-driven Spectral Foresight Pruning" accepted at CVPR 2024
https://iurada.github.io/PX

Question: Squaring Data Before Passing to jvf_model in PX Pruning Method #6

Nonac opened this issue 1 week ago

Nonac commented 1 week ago

First, I’d like to say I really appreciate the work you’ve done here! The pruning method is well-structured and provides a solid framework for model optimization. As I was going through the code, I noticed something in the handling of jvf_model during the forward pass, and I wanted to ask about your design decision.

The paper mentions that the input data should be squared before being passed into the jvf_model. Specifically, the text states:

“Instead, it takes the squared data as input, the parameters are all one, and the activations status is an exact copy of that of f.”
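To make sure I'm reading this correctly, here is a rough sketch of my own (not the repository's code, just a toy ReLU MLP) of how I understand the three ingredients of that description: all-one parameters, the squared data as input, and the ReLU activation status copied from the original network f:

```python
import copy
import torch
import torch.nn as nn

# Toy network standing in for f (the repo targets ResNets; this is only
# to make the three ingredients of the quoted description concrete).
f = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
x = torch.randn(2, 8)

# jvf: same architecture, all parameters set to one.
jvf = copy.deepcopy(f)
with torch.no_grad():
    for p in jvf.parameters():
        p.fill_(1.0)

# Record the activation status (ReLU on/off masks) of the original f.
masks = {}
def save_mask(name):
    def hook(module, inp, out):
        masks[name] = (out > 0).float()
    return hook
handles = [m.register_forward_hook(save_mask(n))
           for n, m in f.named_modules() if isinstance(m, nn.ReLU)]
_ = f(x)
for h in handles:
    h.remove()

# Replace each ReLU in jvf with the mask cached from f.
class FixedMask(nn.Module):
    def __init__(self, mask):
        super().__init__()
        self.register_buffer('mask', mask)
    def forward(self, z):
        return z * self.mask

for n, m in f.named_modules():
    if isinstance(m, nn.ReLU):
        setattr(jvf, n, FixedMask(masks[n]))  # works here because f is a flat Sequential

# jvf takes the squared data as input.
z1 = jvf(x.pow(2))
```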

In the current code, the data is used directly without squaring it:

https://github.com/iurada/px-ntk-pruning/blob/82ce2bf5cedb476c063a2308ead03a453f7361fd/lib/pruners.py#L463

I was wondering if there was a particular reason for not squaring the data in this part of the implementation. If we were to square the data before the forward pass, like this:

```python
data_squared = data.pow(2)
z1 = jvf_model(data_squared)
```

what would be the impact on the model’s performance or pruning results? Since this is a data-driven method, I would expect that any change to the data would have a significant impact on the overall performance. I’m curious to hear your thoughts on this.

Thanks again for your great work, and I look forward to your insights!

iurada commented 6 days ago

Thank you for your interest in our work!

After reviewing the code, I realized this discrepancy likely arose while cleaning and preparing the repository for upload (the original code did use $x^2$, together with an additional step clarified below).

The focus of our work was on ReLU networks with batch normalization layers (mainly ResNets). I recall that incorporating $x^2$ in such networks can be problematic, as it complicates compatibility with the batch normalization statistics: one would need to update those statistics with $x^2$ before proceeding. However, since we apply the cached activations from the original network (which are non-negative due to the use of ReLUs), the results we observed were effectively the same as if we had used $x^2$.
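
For reference, a minimal sketch of that extra step (my rough illustration, not the code we actually used; `jvf_model` and `data` simply mirror the names from the snippet above) could look like this:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def refresh_bn_stats_with_squared_input(jvf_model: nn.Module, data: torch.Tensor):
    # Re-estimate the BatchNorm running statistics on x^2 so that the
    # subsequent jvf forward pass sees statistics consistent with its
    # squared input.
    for m in jvf_model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.reset_running_stats()
            m.train()
    jvf_model(data.pow(2))  # one pass (or a few batches) to accumulate stats
    jvf_model.eval()

# Usage, following the snippet from the question:
# refresh_bn_stats_with_squared_input(jvf_model, data)
# z1 = jvf_model(data.pow(2))
```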