kytimmylai / NoisyNN-PyTorch

Non-official NoisyNN implementation
Apache License 2.0

Does the implementation work? #1

Open faruknane opened 7 months ago

faruknane commented 7 months ago

Hi, does the implementation work? Is it tested yet?

kytimmylai commented 7 months ago

Hello,

The model should work like a common classifier; you can run "python model.py" to check it. It's based on the formula provided in the paper, but I can't guarantee it achieves the same performance as in the paper without the source code or further information.

faruknane commented 7 months ago

@kytimmylai thank you for the quick response. I was wondering if you had a chance to test the model accuracy by retraining on a dataset?

Also, I wonder whether this method works on models that are not pretrained. That is, should we only enable the noise injection after the model has fully learned the data once?

Here is a quote from the paper (NoisyNN): "The third constraint is to make the trained classifier get enough information about a specific image and correctly predict the corresponding label. For example, for an image X1 perturbed by another image X2, the classifier obtained dominant information from X1 so that it can predict the label Y1. However, if the perturbed image X2 is dominant, the classifier can hardly predict the correct label Y1 and is more likely to predict as Y2."

kytimmylai commented 7 months ago

I may try to reproduce the experiment, but I won't train it from scratch.

MarkWijkhuizen commented 7 months ago

I tried to reproduce the results on ResNet-34 with a subset of 100K ImageNet samples, 100 from each class. Training was performed for 100 epochs with a batch size of 100, the Adam optimizer with a starting learning rate of 5e-3 and cosine decay, and common data augmentations. The models with and without the linear transform noise (with optimal Q) gave roughly the same results (~40% top-1 accuracy), leading me to conclude that I could not reproduce the paper's results.

Given the performance jump from 66.8% on vanilla ResNet-34 to 80.0% using the linear transform noise, you would expect an unambiguous increase in performance.

kytimmylai commented 7 months ago

@MarkWijkhuizen I'm not sure whether you used pretraining or not. ViT comes with official pretrained weights on ImageNet21k, available at https://github.com/google-research/vision_transformer. As mentioned above, we should use pretrained weights to match the paper's claim. However, I cannot ascertain which pretrained weights were used for ResNet.

MarkWijkhuizen commented 7 months ago

@kytimmylai thanks for pointing out the ImageNet21K pretraining.

Tables 1 and 2 do not state any ImageNet21K pretraining; however, Table 3 does specify ImageNet21K pretraining with the exact same ViT-B top-1 accuracy of 83.33%. I therefore assume the ResNet results were also obtained using ImageNet21K pretraining.

I cannot find ImageNet21K weights for ResNet, but I could find them for EfficientNet. I will apply the same strategy with ImageNet21K-initialized weights and keep you posted.

kytimmylai commented 7 months ago

@MarkWijkhuizen Thanks, that's great. I appreciate your contribution.

faruknane commented 7 months ago

@MarkWijkhuizen thank you for the help! Also, do you think the code is missing anything compared to the paper?

ChengShiest commented 7 months ago

Some of the code is misaligned with the previous version, "Explore Positive Noise in Deep Learning", which is the version of NoisyNN submitted to NeurIPS 2023. During the rebuttal, the authors responded to Reviewer fTiY that:

2.2 When is the noise injection? The very beginning of the training and testing.

So the linear transform noise is actually also injected during testing. Considering that the ImageNet evaluation is done without shuffling, all samples within a batch belong to the same category. That appears to be why only the linear noise works, and only in the last layer.

kytimmylai commented 7 months ago

@ChengShiest I appreciate your reply and the details provided. However, I would like further clarification on your recent statement, specifically, "All samples within a batch belong to the same category." Could you please elaborate on this sentence?

TranThanh96 commented 7 months ago

Why do we need to add noise in the testing phase? I don't get it. And if we bring this into production, do we still need to add noise?

ChengShiest commented 7 months ago

@kytimmylai You can refer to detectron2's implementation. When training is false, the shape of the first batch is

image: torch.Size([32, 3, 224, 224]) label: torch.Size([32])

the label itself is

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

So the linear noise effectively seems to take a mean of features over the batch samples.
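
For illustration only (a sketch of the concern, not the authors' code): if the linear noise mixes each sample's features with another sample from the same batch, then on an unshuffled ImageNet val loader, where a batch is typically a single class, the model is effectively fed a same-class feature ensemble at eval time.

import tensorflow as tf

def batch_mix_noise(x, alpha=0.5):
    # x: [B, H, W, C] features; alpha is a hypothetical mixing weight
    # tf.roll pairs every sample with its batch neighbour, which on an
    # unshuffled ImageNet val set is almost always an image of the same class
    x_other = tf.roll(x, shift=1, axis=0)
    return (1.0 - alpha) * x + alpha * x_other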

ChengShiest commented 7 months ago

@TranThanh96 First of all, inference should be batch-independent, so I am not sure whether the high performance comes from positive noise or just from some trick.

MarkWijkhuizen commented 7 months ago

When training EfficientNetV2-B0 with ImageNet-21K pretrained weights for 100 epochs on a subset of 100K images, the models provide roughly identical results.

vanilla:  val_top1acc: 0.5766 - val_top5acc: 0.7875
noisy nn: val_top1acc: 0.5796 - val_top5acc: 0.7806

The implementation of the positive noise injection layer is shown below.

Without any scripts to reproduce the results or public pretrained weights, I am skeptical about the authors' claim of an 86M-parameter ViT-B model with 95% ImageNet-1K top-1 accuracy.

# https://github.com/kytimmylai/NoisyNN-PyTorch/blob/main/model.py#L11 translated to Tensorflow
import tensorflow as tf

def optimal_quality_matrix(k):
    """
    Optimal quality matrix Q, described in eq. (19), so that eps = QX, where X is the input.
    1_(kxk) from the paper is implemented with tf.ones.
    """
    return tf.linalg.diag(tf.ones(shape=k)) * -k/(k+1) + tf.ones(shape=[k, k]) / (k+1)

class NoisyNNLayer(tf.keras.layers.Layer):
    def __init__(self, k, **kwargs):
        super(NoisyNNLayer, self).__init__(**kwargs)
        self.k = k

    def build(self, input_shape):
        # fixed, non-trainable quality matrix Q of shape (k, k)
        self.Q = self.add_weight(
            shape=(self.k, self.k),
            initializer=tf.keras.initializers.Constant(optimal_quality_matrix(self.k)),
            trainable=False,
            name='optimal_quality_matrix',
        )
        super(NoisyNNLayer, self).build(input_shape)

    def call(self, x, training):
        if training:
            # NHWC -> NCHW so that Q acts on the height axis (k == H)
            x = tf.transpose(x, perm=[0, 3, 1, 2])
            x = self.Q @ x + x  # x + QX, eq. (12)
            x = tf.transpose(x, perm=[0, 2, 3, 1])
        return x
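
For reference, a minimal sketch of how such a layer could be wired into a Keras backbone. The backbone choice and placement are my assumptions (the paper injects the noise at an intermediate stage); k must match the feature-map height the layer sees.

# assumes the NoisyNNLayer definition above; backbone and head are arbitrary
backbone = tf.keras.applications.EfficientNetV2B0(include_top=False, weights=None)
features = backbone.output                           # [B, 7, 7, 1280] for 224x224 inputs
noisy = NoisyNNLayer(k=features.shape[1])(features)  # k = feature-map height
pooled = tf.keras.layers.GlobalAveragePooling2D()(noisy)
outputs = tf.keras.layers.Dense(1000, activation='softmax')(pooled)
model = tf.keras.Model(backbone.input, outputs)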

kytimmylai commented 7 months ago

@MarkWijkhuizen Thank you for your contribution. The rebuttal response adds information beyond the paper, so I've made a new version. Could you please re-evaluate this implementation? It just removes the conditional statement, so the noise is also applied at inference.

class NoisyNNLayer(tf.keras.layers.Layer):
    """Same layer as above, except x + QX is applied during both training and inference."""
    def __init__(self, k, **kwargs):
        super(NoisyNNLayer, self).__init__(**kwargs)
        self.k = k

    def build(self, input_shape):
        self.Q = self.add_weight(
            shape=(self.k, self.k),
            initializer=tf.keras.initializers.Constant(optimal_quality_matrix(self.k)),
            trainable=False,
            name='optimal_quality_matrix',
        )
        super(NoisyNNLayer, self).build(input_shape)

    def call(self, x, training=None):
        # no training/inference branch: the noise term is always injected
        x = tf.transpose(x, perm=[0, 3, 1, 2])
        x = self.Q @ x + x
        x = tf.transpose(x, perm=[0, 2, 3, 1])
        return x

helvince commented 7 months ago

Hello everyone,

I had previously spent time on NoisyNN myself. Great to see the efforts on this repo. The paper and its claims are intriguing: effectively, they claim to apply one simple trick to boost a small ViT's performance to SOTA. Bold, to say the least, and, as the NeurIPS reviewers did, it should be evaluated cautiously. Nevertheless, the authors have convinced some of the reviewers and represent three different American universities (scandalous if this were to be fraudulent).

I too began to experiment based on their paper and the OpenReview discussion, but came to the conclusion that some crucial details were missing. I contacted the first author two weeks ago, who responded that he would share the code if the co-authors gave permission. As it has been a while, I am going to send a follow-up mail to hopefully get some insights.

From an implementation standpoint, I think the code here is a good guesstimate. You could make it more efficient: instead of a matrix multiplication you can use a weighted permutation (for the linear variant) or a weighted sum (for the optimal variant). Though this is nitpicking and does not improve/change the accuracy.
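
As a rough sketch of that weighted-sum simplification for the optimal variant (my own algebra, assuming the NHWC layer posted above): since Q = -k/(k+1) * I + 1/(k+1) * 1_(kxk), the transform X + QX collapses to (X + the sum of X along the k-axis) / (k+1), so no k-by-k matmul is needed.

import tensorflow as tf

def optimal_noise_fast(x, k):
    # x: [B, H, W, C] with k == H, the axis the quality matrix acts on;
    # algebraically identical to x + Q @ x with the optimal Q above
    return (x + tf.reduce_sum(x, axis=1, keepdims=True)) / (k + 1)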

I too noticed the authors' comment that noise is also included during inference. For the linear quality matrix this sounds invalid (you should not take another sample from the validation/test data, and which arbitrary sample from the training data would you use...). For the optimal quality matrix it could be possible (either as a running mean, as done in batch norm, or by computing a sum over the complete training data with finalized, fixed weights).
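
Purely as a sketch of that running-mean option (nothing in the paper specifies this; the momentum, shapes, and layout are my assumptions): keep a batch-norm-style EMA of the summed feature during training and add that frozen statistic at inference instead of the in-batch term.

import tensorflow as tf

class RunningOptimalNoise(tf.keras.layers.Layer):
    def __init__(self, k, momentum=0.99, **kwargs):
        super().__init__(**kwargs)
        self.k = k
        self.momentum = momentum

    def build(self, input_shape):
        # running estimate of the per-position sum along the k-axis (NHWC input)
        self.running_sum = self.add_weight(
            shape=(1, 1, input_shape[2], input_shape[3]),
            initializer='zeros', trainable=False, name='running_sum')
        super().build(input_shape)

    def call(self, x, training=None):
        summed = tf.reduce_sum(x, axis=1, keepdims=True)  # [B, 1, W, C]
        if training:
            batch_stat = tf.reduce_mean(summed, axis=0, keepdims=True)
            self.running_sum.assign(
                self.momentum * self.running_sum + (1.0 - self.momentum) * batch_stat)
            return (x + summed) / (self.k + 1)
        # at inference, replace the in-batch sum with the frozen statistic
        return (x + self.running_sum) / (self.k + 1)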

Additionally, I looked into the derivation of the optimal quality matrix and noticed a possible improvement. Whether their solution was deliberately restricted, empirically calibrated, or shaped by an additional, unmentioned constraint remains to be seen.

I'll return with an update if I hear back from the authors and will keep an eye on everyone's work here ;)

kytimmylai commented 7 months ago

@helvince Your contribution is much appreciated, and I share your anticipation for the release of the original code. It would be valuable for further analysis and collaboration. I implemented this code because the official code wasn't published, and in such cases it usually never appears.

The optimization perspective is interesting. I will investigate it if the result is true 👍

romanvelichkin commented 7 months ago

While other SOTA solutions continue to push the number of parameters into the billions, such a simple solution with such a huge jump in accuracy looks like magic. Thank you guys for all the effort you put into reproducing this work.

helvince commented 6 months ago

I did not receive another response from the authors. Their best accuracy score has been removed from Papers with Code, and only the model with linear noise is still included (84.8% accuracy, 272nd).

Another open review from ICLR has been made public: https://openreview.net/forum?id=zIrpuifCJW. The corresponding paper submission still claims 95% accuracy and has some new content.

The authors' GitHub page does not mention any of these developments. It has seen some activity recently, though it still does not contain any code.

kytimmylai commented 6 months ago

@helvince Thank you for the follow-up. I also noticed they withdrew the 95% result on Papers with Code. I may not proceed with the reproduction until more responses or concrete evidence support the validity of their results. If they substantiate their result, i.e., publish their code (which it looks like they will), then I won't need to reproduce it either.

Thanks to everyone in this issue for your opinions. It's been a great experience.

tsunghan-wu commented 2 months ago

Hi, thank you @kytimmylai for sharing the reproduction code. I apologize for bringing this thread up again, but it has been a couple of months since the last update and I was wondering whether there has been any progress on this issue. It seems that the first author is still promoting this paper on their homepage (suggesting that the paper is less likely to be "an accident"), and many people are intrigued by the impressive 95% ImageNet accuracy, but unfortunately, nobody has been able to reproduce these results...

dbolya commented 2 months ago

Hello, I am also interested in this. But honestly, my confidence that the original authors aren't either lying or relying on a serious bug in their eval code is below zero at this point. I think the aim of reproduction should be to figure out which bug / leak of validation labels the original authors used (potentially unknowingly) to get their results.

The addition of the noise "during eval" is the most suspect to me. Especially since their method relies on pulling content from other images as noise (i.e., prime suspect for a validation leak). And as @ChengShiest said, the ImageNet val set is sorted by class label. If you give the model 3 adjacent images in the val set, they'll all likely be the same class. Now if you ensemble predictions from these 3 images you'll be much more likely to get the correct prediction as they all have the same label. It could be that their noise is doing that, or something else could be wrong. I'm interested in hearing everyone's thoughts.
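
As a toy illustration of that ensembling effect (hypothetical, not from the paper): averaging predictions over a few adjacent samples of a class-sorted val set amounts to test-time ensembling that exploits the leaked label structure.

import tensorflow as tf

def ensemble_adjacent_logits(logits, n=3):
    # logits: [B, num_classes] from a class-sorted (unshuffled) val loader;
    # averaging each sample with its n-1 neighbours mixes same-label predictions
    stacked = tf.stack([tf.roll(logits, shift=i, axis=0) for i in range(n)], axis=0)
    return tf.reduce_mean(stacked, axis=0)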

Is anyone able to open this onedrive link that the authors shared for their ICLR rebuttal? I'm not able to open it: https://1drv.ms/f/s!At0Jf7x3GkSgaumH6XgghGvuWYY?e=jAXqk5

kytimmylai commented 2 months ago

Hi @tsunghan-wu , thank you for following up on this issue. As @dbolya mentioned, the repository is intended to validate their method. I would appreciate it if the author could point out any errors I may have overlooked.

We are focusing on the best result, which is in Section 4.4, Optimal Quality Matrix. It aims to maximize the entropy change in equation (12). Line 62 of noisy_vit.py corresponds to "X + QX" in equation (12), and the matrix is obtained from equation (19). They claim that the entropy change is determined by the number of data samples, but k should be the shape of the feature map in the 3rd stage for X + QX to be well-defined. If that's the case, however, k is unrelated to the number of data samples. What is the relationship between a 2D untrainable matrix and the number of data samples?
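
As a quick sanity check of that point (using the TensorFlow translation posted above, not the authors' code), the matrix is fully determined by the feature-map size k and contains nothing related to the number of data samples:

import tensorflow as tf

# optimal_quality_matrix(k) as defined earlier in this thread
Q = optimal_quality_matrix(3)  # k = 3, picked arbitrarily for illustration
print(Q)
# [[-0.5   0.25  0.25]
#  [ 0.25 -0.5   0.25]
#  [ 0.25  0.25 -0.5 ]]  <- constant for a given k, regardless of dataset size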

Note that the equation numbers are based on the order in version 2 here. Please correct me if I have misunderstood something.

btw, I am also unable to access the link above.