Nonac opened this issue 1 week ago
Reply from the repository author:

Thank you for your interest in our work!

After reviewing the code, I believe the issue likely arose while cleaning and preparing the repository for upload (the original code did use $x^2$, together with an additional step clarified below).

The focus of our work was on ReLU networks with batch normalization layers (mainly ResNets). I recall that incorporating $x^2$ in such networks can be problematic, as it complicates compatibility with the batch normalization statistics: specifically, one would need to update those statistics with $x^2$ before proceeding. However, since we apply cached activations from the original network (which are non-negative due to the use of ReLUs), the results we observed were effectively the same as if we had used $x^2$.
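To make the mismatch concrete, here is a minimal PyTorch sketch (illustrative only, not code from our repository; layer sizes and names are made up): a BatchNorm layer whose running statistics were collected on $x$ no longer normalizes correctly when fed $x^2$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# BatchNorm layer; momentum=None makes the running statistics an exact
# cumulative average, so after one batch they match that batch's stats.
bn = nn.BatchNorm1d(num_features=8, momentum=None)

# Stand-in for cached ReLU activations: non-negative by construction.
x = torch.relu(torch.randn(256, 8))

bn.train()
_ = bn(x)   # populate running_mean / running_var from x
bn.eval()   # freeze the statistics, as in a pretrained network

out_x = bn(x)        # normalized with matching statistics
out_x2 = bn(x ** 2)  # x^2 does not match the cached statistics

# out_x is roughly zero-mean / unit-variance; out_x2 is visibly shifted.
print(out_x.mean().item(), out_x.std().item())
print(out_x2.mean().item(), out_x2.std().item())
```

Re-estimating these running statistics for $x^2$ at every batch normalization layer is the additional step mentioned above.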
Original question from Nonac:

First, I’d like to say I really appreciate the work you’ve done here! The pruning method is well structured and provides a solid framework for model optimization. While going through the code, I noticed something in the handling of `jvf_model` during the forward pass, and I wanted to ask about your design decision.
In the paper, it is mentioned that the input data should be squared ($x^2$) before being passed into `jvf_model`. In the current code, however, the data appears to be used directly, without squaring:
https://github.com/iurada/px-ntk-pruning/blob/82ce2bf5cedb476c063a2308ead03a453f7361fd/lib/pruners.py#L463
I was wondering if there was a particular reason for not squaring the data in this part of the implementation. If we were to square the data before the forward pass, like this:
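(a hypothetical one-line sketch; the actual variable and attribute names at the linked line may differ)

```python
# Hypothetical sketch of the proposed change; names are illustrative
# and may not match those used in pruners.py.
output = self.jvf_model(data ** 2)  # square the input first
# instead of the current (unsquared) call:
# output = self.jvf_model(data)
```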
what would be the impact on the model’s performance or the pruning results? Since this is a data-driven method, I would expect any change to the input data to have a significant effect on the overall results. I’m curious to hear your thoughts on this.
Thanks again for your great work, and I look forward to your insights!