@zizhaozhang @KingSpencer
Hi,
Thanks for your interest in our work and your questions!
For the first 4 questions, I believe the answers already exist in their respective threads:
@miss-rain I found in the official documentation that when JAX executes its first command, it pre-allocates 90% of the available GPU memory. As described in the documentation, I can either disable the pre-allocation or reduce the pre-allocation fraction to run the ViT-Base model. With this change, I was able to run on my GPUs (8× RTX 3090 and 8× A5000).
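For anyone hitting the same OOM issue, the two options above come from the JAX GPU memory allocation documentation and can be set via environment variables before JAX first touches the GPU (the 0.5 fraction below is just an example value):

```python
import os

# Must be set before the first JAX operation allocates GPU memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # turn off the 90% pre-allocation
# Or keep pre-allocation but cap its share of GPU memory, e.g. at half:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax
print(jax.devices())
```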
Solved by re-installing cuDNN.
Great insight! Actually we have not tried your suggested experiments, but it is definitely something worth trying. Regarding the "information leakage", I think we do make the assumption that we have a "well-pretrained" model, and we use the same pretrained model for all competitors, so the comparison is fair. Another thing I would like to highlight is that the idea of prompting is to leverage the knowledge already learned by the model and to "instruct" the model to selectively use that knowledge for incoming tasks. Since large-scale pretrained models are prevalent these days, leveraging them is quite natural. On the other hand, in the extreme case where the pretrained model is totally off (e.g. trained on a completely different dataset, though we would not do this in practice), L2P will probably fail if the backbone is frozen. Thus, how and when to adapt the model backbone will be an interesting future direction.
Same as 1.
For your own questions:
Hope all these help! Worst case, if you cannot reproduce the results after all these trials, feel free to send me your detailed configuration and I will try to find some time to help out.
Best, Zifeng
Thank you for your timely reply! My confusion centres on the second of the four questions:
My own reproduction of the code does not achieve the results shown in the left panel of Figure 3 of the paper, and I do not understand why catastrophic forgetting does not occur given the statistics I actually observe. Even without using the "Optionally diversifying prompt-selection" strategy, I cannot obtain these statistics, the same as in #18 and #24; the latter comes with detailed [statistics logs].
I'm not concerned with why the loss became NaN. I'm more interested in how to reproduce the prompt-selection statistics shown on the left-hand side of Figure 3 of the paper, because, as mentioned above, my own reproduction, the question posed in #18, and the detailed logs given in #24 all show the same behaviour: the top-k prompts are first selected essentially at random while training the first task, their keys are then optimized, and the same top-k prompts keep being selected for all subsequent tasks. This inevitably leads to catastrophic forgetting.
question from #18:
According to the contents of the paper and my understanding, the prompts in the prompt pool should be selected evenly across tasks. However, when I look at the prompt indices and the TensorBoard histogram, it seems that only a few prompts are learned. Only the 3rd, 6th, 7th, and 9th prompts are used when reproducing the result with the official code.
logs from #24:
I0920 13:10:04.959301 140038901188352 logging_writer.py:53] [50] Histogram for 'histogram_0' = {[0, 0.3): 992, [0.3, 0.6): 0, [0.6, 0.9): 0, [0.9, 1.2): 992, [1.2, 1.5): 0, [1.5, 1.8): 0, [1.8, 2.1): 992, [2.1, 2.4): 0, [2.4, 2.7): 0, [2.7, 3]: 992}
I0920 13:10:04.961827 140038901188352 logging_writer.py:53] [50] Histogram for 'histogram_1' = {[0, 0.3): 992, [0.3, 0.6): 0, [0.6, 0.9): 0, [0.9, 1.2): 992, [1.2, 1.5): 0, [1.5, 1.8): 0, [1.8, 2.1): 992, [2.1, 2.4): 0, [2.4, 2.7): 0, [2.7, 3]: 992}
...... (histograms 2-8 deleted for readability)
I0920 13:10:04.987567 140038901188352 logging_writer.py:53] [50] Histogram for 'histogram_8' = {[0, 0.3): 992, [0.3, 0.6): 0, [0.6, 0.9): 0, [0.9, 1.2): 992, [1.2, 1.5): 0, [1.5, 1.8): 0, [1.8, 2.1): 992, [2.1, 2.4): 0, [2.4, 2.7): 0, [2.7, 3]: 992}
I0920 13:10:04.992606 140038901188352 logging_writer.py:53] [50] Histogram for 'histogram_9' = {[0, 0.3): 992, [0.3, 0.6): 0, [0.6, 0.9): 0, [0.9, 1.2): 992, [1.2, 1.5): 0, [1.5, 1.8): 0, [1.8, 2.1): 992, [2.1, 2.4): 0, [2.4, 2.7): 0, [2.7, 3]: 992}
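(For reference, here is a tiny, hypothetical helper, not part of the repo, that turns the selected prompt indices into per-prompt usage fractions, which makes it easy to check whether selection spreads over the pool or collapses to a few indices as the histograms above suggest:)

```python
import jax.numpy as jnp

def prompt_usage(selected_idx, pool_size=10):
    """selected_idx: [B, top_k] integer ids of the prompts chosen for a batch."""
    counts = jnp.bincount(selected_idx.reshape(-1), length=pool_size)
    return counts / counts.sum()   # fraction of selections going to each prompt id
```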
Finally, I look forward to hearing from you again.
Hi,
Actually it's a great catch! I checked the configs carefully and think there may be a typo in the configuration file cifar100_l2p.py, line 109: config.prompt_pool_param.batchwise_prompt = True. This setting encourages prompt selection to be more concentrated and to finally "collapse" to a limited set of prompts. Changing it to False lets prompt selection happen in a freer way. I will update the codebase accordingly.
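To see why a batch-wise setting concentrates selection, here is a minimal sketch (not the repo's actual code; the function and variable names are illustrative) of per-example vs. batch-wise top-k selection, assuming cosine-similarity matching between query features and prompt keys as in the paper:

```python
import jax
import jax.numpy as jnp

def select_prompts(query, prompt_keys, top_k=5, batchwise=False):
    """query: [B, D] image features; prompt_keys: [M, D] learnable prompt keys."""
    q = query / jnp.linalg.norm(query, axis=-1, keepdims=True)
    k = prompt_keys / jnp.linalg.norm(prompt_keys, axis=-1, keepdims=True)
    sim = q @ k.T                              # [B, M] cosine similarities
    _, idx = jax.lax.top_k(sim, top_k)         # per-example top-k ids, shape [B, top_k]
    if batchwise:
        # Batch-wise variant: keep only the top_k prompt ids chosen most often
        # across the whole batch and share them with every example. This
        # concentrates the training signal on a few prompts and can make the
        # pool "collapse" to a small subset.
        counts = jnp.bincount(idx.reshape(-1), length=prompt_keys.shape[0])
        _, shared = jax.lax.top_k(counts, top_k)
        idx = jnp.broadcast_to(shared, (query.shape[0], top_k))
    return idx
```

With batchwise=False every example keeps its own nearest keys, so different inputs can pull different prompts from the pool; with batchwise=True the whole batch reinforces the same few ids, which matches the collapsed histograms reported above.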
Best, Zifeng
Btw - I have been able to reproduce the results with my own implementation. I will vouch for the paper's reproducibility :)
@jamessealesmith Is that in PyTorch? Do you plan to release the implementation? If so, I think it would help the community a lot.
Yes, the implementation is in PyTorch! We plan to release it before the end of 2022.
I have already implemented L2P and DualPrompt in PyTorch. They were released in my repositories (l2p-pytorch, dualprompt-pytorch) about 2-3 months ago. I hope we can discuss them freely :)
@Lee-JH-KR Awesome work, Jaeho! I will go through your implementation and add your pytorch repos in the README of this repo soon:) Thanks again for your efforts.
I sincerely question the reproducibility of the code and the results of the paper, as raised in this repo's issues:
And, for myself:
I very sincerely hope that the author will answer the above questions.