google-research / l2p

Learning to Prompt (L2P) for Continual Learning @ CVPR22 and DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning @ ECCV22
https://arxiv.org/pdf/2112.08654.pdf
Apache License 2.0

Questions about the reproducibility of the code and the results of the paper #25

Dicer-Zz closed this issue 1 year ago.

Dicer-Zz commented 2 years ago

I have serious doubts about the reproducibility of both the code and the paper's results, based on the issues in this repo:

  1. People using the same V100 GPU as the authors report that they cannot run the code and hit OOM errors, with no reply so far (#1, #20).
  2. My own reproduction does not achieve the prompt-selection statistics shown in the left panel of Figure 3 of the paper, and I do not understand how catastrophic forgetting could be avoided given the statistics I get instead. Even without the optional diversified prompt-selection method, I cannot obtain the paper's statistics; the same problem is reported in #18 and #24, and #24 comes with detailed statistics logs. The histogram records there show that only four prompts are ever selected and that all tasks share them. I think this will inevitably cause catastrophic forgetting.
  3. The use of a pre-trained ViT may cause information leakage (#11).
  4. The provided requirement.txt does not install a working runtime environment directly, and even when it does, the code runs only on the CPU (#1).

And from my own experience:

  1. The code is very hard to run on my RTX 3090 GPU. Even after a lot of effort, the program hangs at training step 5 of the first task without reporting any error.
  2. I have not seen anyone in the issues report a successful reproduction of the results.

I sincerely hope the authors will answer the questions above.

Dicer-Zz commented 2 years ago

@zizhaozhang @KingSpencer

KingSpencer commented 2 years ago

Hi,

Thanks for your interest in our work and your questions!

For the first four questions, I believe the answers are already in their respective threads:

  1. @miss-rain I found in the official documentation that when JAX executes its first command, it pre-allocates 90% of the available GPU memory. As the documentation describes, I can either disable pre-allocation or reduce the pre-allocation fraction to run the ViT-Base model. With that, I was able to run the code on my GPUs: 8 RTX 3090s and 8 A5000s.

  2. Solved by reinstalling cuDNN.

  3. Great insight! We have not actually tried the experiments you suggested, but they are definitely worth trying. Regarding "information leakage": we do assume a "well-pretrained" model, and we use the same pretrained model for all competitors, so the comparison is fair. I would also like to highlight that prompting leverages knowledge the model has already learned, and tries to "instruct" the model to selectively use that knowledge for upcoming tasks. Since large-scale pretrained models are prevalent these days, leveraging them is quite natural. On the other hand, in the extreme case where the pretrained model is totally off (e.g. trained on a completely different dataset, which we would not do in practice), L2P will probably fail if the backbone is frozen. Thus, how and when to adapt the backbone is an interesting direction for future work.

  4. Same as 1.
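For readers hitting the OOM errors from point 1: the JAX GPU memory documentation describes two environment variables for this, `XLA_PYTHON_CLIENT_PREALLOCATE=false` to disable pre-allocation and `XLA_PYTHON_CLIENT_MEM_FRACTION` to shrink the pre-allocated fraction. The minimal sketch below sets one of them from Python. The helper name `configure_jax_gpu_memory` is hypothetical (it is not part of this repo), and since the variables are read when JAX initializes its GPU backend, the safest place to set them is before `import jax`. The default fraction has also changed across JAX versions, so check the docs for your installed version.

```python
import os


def configure_jax_gpu_memory(mem_fraction=None):
    """Hypothetical helper: set XLA's GPU memory env vars before `import jax`.

    mem_fraction=None -> disable pre-allocation (memory is allocated on demand).
    mem_fraction=0.5  -> keep pre-allocation, but reserve only 50% of GPU memory.
    """
    if mem_fraction is None:
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    else:
        if not 0.0 < mem_fraction <= 1.0:
            raise ValueError("mem_fraction must be in (0, 1]")
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


# Usage: configure first, then import JAX so the backend sees the settings.
configure_jax_gpu_memory()  # or e.g. configure_jax_gpu_memory(0.5)
# import jax
```

Equivalently, prefix the usual training command with `XLA_PYTHON_CLIENT_PREALLOCATE=false` in the shell; the Python route is only convenient when you cannot change how the script is launched.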

For your own questions:

  1. Please check carefully whether any of the solutions above helps. There may be an environment mismatch.
  2. People normally do not report in the issues when they have successfully reproduced a codebase. As far as I know, some researchers have already reproduced the results with negligible differences, and we have discussed the results with them privately by email.

Hope this helps! In the worst case, if you still cannot reproduce the results after trying all of this, feel free to send me your detailed configuration and I will try to find time to help.

Best, Zifeng

Dicer-Zz commented 2 years ago

Thank you for the timely reply! My confusion centres on the second of the four questions:

My own reproduction of the code does not achieve the results shown in the left panel of Figure 3 in the paper, and I do not understand why catastrophic forgetting does not occur when such statistical results occur. And even without using the Optionally diversifying prompt-selection method, I can't get this statistic, same as in #18 and #24. This issue comes with detailed statistics logs.

I'm not concerned with why the loss became NaN. I'm more interested in how to reproduce the prompt-selection statistics shown on the left-hand side of Figure 3 of the paper. As mentioned above, my own reproduction, the question raised in #18, and the detailed logs in #24 all show the same pattern: the top-k prompts are selected essentially at random while training the first task, and their keys are then optimized, which causes the same top-k prompts to be selected for all subsequent tasks. This inevitably leads to catastrophic forgetting.
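The feedback loop described above can be made concrete with a small sketch of L2P-style query-key prompt selection as the paper describes it: each image's query feature from the frozen backbone is compared with every learnable key by cosine similarity, the top-k keys choose the prompts, and a surrogate term pulls the selected keys toward the query. This is a NumPy illustration under those assumptions, not the repo's JAX code, and all function names are hypothetical. Because only the selected keys are pulled, keys that win early on the first task move closer to typical queries and tend to keep winning, which is the pattern in the #24 histograms.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale vectors to unit length so dot products become cosine similarities."""
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)


def select_prompts(queries, keys, top_k):
    """Per-sample top-k prompt selection by cosine similarity.

    queries: (batch, dim) features from the frozen backbone, one per image
    keys:    (pool_size, dim) learnable prompt keys
    returns: (batch, top_k) prompt ids, most similar first
    """
    sim = l2_normalize(queries) @ l2_normalize(keys).T  # (batch, pool_size)
    return np.argsort(-sim, axis=1)[:, :top_k]


def prompt_usage(idx, pool_size):
    """Count how often each prompt id was selected (what the histograms show)."""
    return np.bincount(idx.ravel(), minlength=pool_size)


def key_pull_loss(queries, keys, idx):
    """Mean cosine distance 1 - cos(query, selected key).

    The term depends only on the selected keys, so minimizing it moves only
    those keys toward the queries; unselected keys are left untouched.
    """
    q = l2_normalize(queries)[:, None, :]  # (batch, 1, dim)
    k = l2_normalize(keys)[idx]            # (batch, top_k, dim)
    return float(np.mean(1.0 - np.sum(q * k, axis=-1)))


rng = np.random.default_rng(0)
queries = rng.normal(size=(8, 16))  # stand-in for backbone features
keys = rng.normal(size=(10, 16))    # pool of 10 prompt keys
idx = select_prompts(queries, keys, top_k=5)
print(idx.shape)                    # (8, 5)
print(prompt_usage(idx, 10).sum())  # 40
print(key_pull_loss(queries, keys, idx))
```

Logging `prompt_usage` per task is a cheap way to check whether selection has collapsed onto a few ids before looking at accuracy.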

Question from #18:

According to the paper and my understanding, the prompts in the prompt pool should be selected evenly for each task. However, when I look at the prompt indices and the TensorBoard histogram, it seems that only a few prompts are learned: in my reproduction with the official code, only the 3rd, 6th, 7th, and 9th prompts are used.

Logs from #24:

I0920 13:10:04.959301 140038901188352 logging_writer.py:53] [50] Histogram for 'histogram_0' = {[0, 0.3): 992, [0.3, 0.6): 0, [0.6, 0.9): 0, [0.9, 1.2): 992, [1.2, 1.5): 0, [1.5, 1.8): 0, [1.8, 2.1): 992, [2.1, 2.4): 0, [2.4, 2.7): 0, [2.7, 3]: 992}
I0920 13:10:04.961827 140038901188352 logging_writer.py:53] [50] Histogram for 'histogram_1' = {[0, 0.3): 992, [0.3, 0.6): 0, [0.6, 0.9): 0, [0.9, 1.2): 992, [1.2, 1.5): 0, [1.5, 1.8): 0, [1.8, 2.1): 992, [2.1, 2.4): 0, [2.4, 2.7): 0, [2.7, 3]: 992}
...... (histograms 2-8 deleted for readability)
I0920 13:10:04.987567 140038901188352 logging_writer.py:53] [50] Histogram for 'histogram_8' = {[0, 0.3): 992, [0.3, 0.6): 0, [0.6, 0.9): 0, [0.9, 1.2): 992, [1.2, 1.5): 0, [1.5, 1.8): 0, [1.8, 2.1): 992, [2.1, 2.4): 0, [2.4, 2.7): 0, [2.7, 3]: 992}
I0920 13:10:04.992606 140038901188352 logging_writer.py:53] [50] Histogram for 'histogram_9' = {[0, 0.3): 992, [0.3, 0.6): 0, [0.6, 0.9): 0, [0.9, 1.2): 992, [1.2, 1.5): 0, [1.5, 1.8): 0, [1.8, 2.1): 992, [2.1, 2.4): 0, [2.4, 2.7): 0, [2.7, 3]: 992}

I look forward to hearing from you again.

KingSpencer commented 2 years ago

Hi,

Actually, this is a great catch! I checked the configs carefully, and I think there is a typo in the configuration file cifar100_l2p.py, line 109: config.prompt_pool_param.batchwise_prompt = True. This setting encourages prompt selection to become more concentrated and eventually "collapse" to a limited set of prompts. Changing it to False should let prompt selection vary more freely. I will update the codebase accordingly.
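To illustrate why a batch-wise flag can concentrate selection, here is a minimal NumPy sketch under one stated assumption: that batch-wise selection replaces each sample's own top-k ids with the k ids chosen most often across the batch, so every sample in the batch shares the same prompts. This is a reading of the behaviour described above, not a copy of the repo's implementation, and `batchwise_majority` is a hypothetical name. Under that assumption, only k prompts (and their keys) are trained per batch, so a few early winners can take every batch; with the flag set to False, each sample keeps its own top-k.

```python
import numpy as np


def batchwise_majority(idx, pool_size, top_k):
    """Hypothetical batch-wise selection: every sample gets the batch's most frequent ids.

    idx:     (batch, top_k) per-sample prompt ids from query-key matching
    returns: (batch, top_k) ids, identical for every row
    """
    counts = np.bincount(idx.ravel(), minlength=pool_size)
    # Stable sort on negated counts: most frequent first, ties go to the lower id.
    majority = np.argsort(-counts, kind="stable")[:top_k]
    return np.tile(majority, (idx.shape[0], 1))


per_sample = np.array([[0, 1], [0, 2], [3, 0], [1, 0]])  # batch of 4, top_k=2
shared = batchwise_majority(per_sample, pool_size=5, top_k=2)
print(shared.tolist())  # [[0, 1], [0, 1], [0, 1], [0, 1]]
print(len(np.unique(per_sample)), "distinct prompts per-sample vs",
      len(np.unique(shared)), "batch-wise")
```

In the config file, the fix described above amounts to setting config.prompt_pool_param.batchwise_prompt = False.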

Best, Zifeng

jamessealesmith commented 2 years ago

Btw, I have been able to reproduce the results with my own implementation. I can vouch for the paper's reproducibility :)

zizhaozhang commented 2 years ago

@jamessealesmith Is that in PyTorch? Do you plan to release the implementation? If so, I think it would help the community a lot.

jamessealesmith commented 2 years ago

Yes, the implementation is in PyTorch! We plan to release it before the end of 2022.

JH-LEE-KR commented 2 years ago

I have already implemented L2P and DualPrompt in PyTorch and released them in my repositories (l2p-pytorch, dualprompt-pytorch) about 2-3 months ago. I hope we can discuss them freely :)

KingSpencer commented 2 years ago

@Lee-JH-KR Awesome work, Jaeho! I will go through your implementation and add your PyTorch repos to the README of this repo soon :) Thanks again for your efforts.