iamwangyabin / SPrompts_eval

Evaluation code for S-Prompts

Model loading mismatch between S-Prompts model and your pre-trained model #1

Closed: laitifranz closed this issue 9 months ago

laitifranz commented 10 months ago

Hi,

I've been exploring the S-Prompts repository and am impressed with your work! However, I've encountered an issue when trying to load my trained model. Here's a detailed description:

Steps to Reproduce

  1. Trained a model from scratch with the S-Prompts repository on the CDDB benchmark, modifying only the config files to set the correct paths. I will refer to this as "my model".
  2. Successfully loaded and used the pre-trained model provided in this repository.
  3. Attempted to evaluate my model with this repository, which failed because the module (state-dict key) names did not match.

Expected Behavior

I expected my model to load successfully and be evaluable just like the pre-trained model.

Actual Behavior

There seems to be a mismatch in module names between my model and the pre-trained model, which prevents the weights from loading. To track down the problem, I printed the list of mismatched keys:

Key mismatch: classifier_pool.0.ctx vs prompt_pool.0.ctx
Key mismatch: classifier_pool.0.token_prefix vs prompt_pool.0.token_prefix
Key mismatch: classifier_pool.0.token_suffix vs prompt_pool.0.token_suffix
Key mismatch: classifier_pool.1.ctx vs prompt_pool.1.ctx
Key mismatch: classifier_pool.1.token_prefix vs prompt_pool.1.token_prefix
Key mismatch: classifier_pool.1.token_suffix vs prompt_pool.1.token_suffix
Key mismatch: classifier_pool.2.ctx vs prompt_pool.2.ctx
Key mismatch: classifier_pool.2.token_prefix vs prompt_pool.2.token_prefix
Key mismatch: classifier_pool.2.token_suffix vs prompt_pool.2.token_suffix
Key mismatch: classifier_pool.3.ctx vs prompt_pool.3.ctx
Key mismatch: classifier_pool.3.token_prefix vs prompt_pool.3.token_prefix
Key mismatch: classifier_pool.3.token_suffix vs prompt_pool.3.token_suffix
Key mismatch: classifier_pool.4.ctx vs prompt_pool.4.ctx
Key mismatch: classifier_pool.4.token_prefix vs prompt_pool.4.token_prefix
Key mismatch: classifier_pool.4.token_suffix vs prompt_pool.4.token_suffix
Key mismatch: classifier_pool.5.ctx vs prompt_pool.5.ctx
Key mismatch: classifier_pool.5.token_prefix vs prompt_pool.5.token_prefix
Key mismatch: classifier_pool.5.token_suffix vs prompt_pool.5.token_suffix
Key mismatch: classifier_pool.6.ctx vs prompt_pool.6.ctx
Key mismatch: classifier_pool.6.token_prefix vs prompt_pool.6.token_prefix
Key mismatch: classifier_pool.6.token_suffix vs prompt_pool.6.token_suffix
Key mismatch: prompt_pool.0.weight vs instance_prompt.0.weight
Key mismatch: prompt_pool.1.weight vs instance_prompt.1.weight
Key mismatch: prompt_pool.2.weight vs instance_prompt.2.weight
Key mismatch: prompt_pool.3.weight vs instance_prompt.3.weight
Key mismatch: prompt_pool.4.weight vs instance_prompt.4.weight
Key mismatch: prompt_pool.5.weight vs instance_prompt.5.weight
Key mismatch: prompt_pool.6.weight vs instance_prompt.6.weight
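A list in this format can be produced by walking both state dicts in parallel and printing every position where the key names differ. The sketch below is one way to do it, assuming both checkpoints have already been loaded as state dicts (for example with `torch.load`) that store their entries in the same order; `report_key_mismatches` is a hypothetical helper, not part of either repository. It only needs the keys, so plain dictionaries are enough to try it out.

```python
from typing import List, Mapping


def report_key_mismatches(mine: Mapping[str, object],
                          reference: Mapping[str, object]) -> List[str]:
    """Compare two state dicts position by position.

    Assumption: both mappings keep their keys in registration order, as the
    OrderedDict returned by nn.Module.state_dict() does.
    """
    report = []
    # A differing entry count means the architectures differ, not just the names.
    if len(mine) != len(reference):
        report.append(f"Entry count differs: {len(mine)} vs {len(reference)}")
    # zip stops at the shorter mapping, so only shared positions are compared.
    for my_key, ref_key in zip(mine, reference):
        if my_key != ref_key:
            report.append(f"Key mismatch: {my_key} vs {ref_key}")
    return report


if __name__ == "__main__":
    # Toy state dicts whose keys mimic entries from the list above.
    mine = {"classifier_pool.0.ctx": None, "prompt_pool.0.weight": None}
    reference = {"prompt_pool.0.ctx": None, "instance_prompt.0.weight": None}
    for line in report_key_mismatches(mine, reference):
        print(line)
```

Comparing by position only makes sense when both models register their submodules in the same order; comparing the key sets (for example `set(mine) ^ set(reference)`) is a sturdier check when that is not guaranteed.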

Analysis

Both models have the same number of parameters, but the layer keys do not match. I traced the module names in the pre-trained model back to the code in the S-Prompts repository and found potential discrepancies: it seems that two modules have been renamed.

In more detail, comparing the pre-trained model structure with the training repository code, I found:

  ...
  (prompt_pool): ModuleList(
    (0-6): 7 x PromptLearner()
  )

https://github.com/iamwangyabin/S-Prompts/blob/14b5902da6b3bafde7087b15704658becfb73491/models/slinet.py#L24-L27 and

  (instance_prompt): ModuleList(
    (0-6): 7 x Linear(in_features=768, out_features=10, bias=False)
  )

https://github.com/iamwangyabin/S-Prompts/blob/14b5902da6b3bafde7087b15704658becfb73491/models/slinet.py#L44-L47

where a clear difference in layer names can be observed. It seems that your pre-trained model was built from a different version of the model classes than the one in the original S-Prompts repo.
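If the two checkpoints differ only in these attribute names (as the equal parameter counts suggest), one possible workaround is to rename the keys of my checkpoint before loading it into the evaluation model. The sketch below assumes the left-hand names in the mismatch list come from my checkpoint and the right-hand names are what the evaluation code expects, which is consistent with the module printout above; the mapping `PREFIX_MAP` and the function `rename_prefixes` are hypothetical and have not been tested against the real checkpoints. Each key is matched against the original prefixes in a single pass, so `prompt_pool.*` becomes `instance_prompt.*` without also catching the freshly renamed `classifier_pool.*` entries.

```python
from collections import OrderedDict
from typing import Dict, Mapping

# Hypothetical mapping inferred from the mismatch list:
# key = prefix in my checkpoint, value = prefix expected by SPrompts_eval.
PREFIX_MAP = {
    "classifier_pool.": "prompt_pool.",
    "prompt_pool.": "instance_prompt.",
}


def rename_prefixes(state_dict: Mapping[str, object],
                    prefix_map: Mapping[str, str] = PREFIX_MAP) -> Dict[str, object]:
    """Return a new state dict with key prefixes renamed in a single pass.

    Each key is tested against the original prefixes only, so a chained
    rename (classifier_pool -> prompt_pool -> instance_prompt) never happens.
    """
    renamed: Dict[str, object] = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        for old, new in prefix_map.items():
            if key.startswith(old):
                new_key = new + key[len(old):]
                break  # apply at most one rename per key
        # Refuse to silently overwrite an entry that already uses the new name.
        if new_key in renamed:
            raise KeyError(f"Renaming produced a duplicate key: {new_key}")
        renamed[new_key] = value
    return renamed
```

With PyTorch, the renamed dictionary could then be passed to `model.load_state_dict(...)`; keeping the default `strict=True` makes PyTorch raise an error if any key is still missing or unexpected, which is a quick way to confirm that the names were the only difference. This assumes the saved file is a bare state dict rather than a wrapper dictionary around one.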

Questions

Thank you to anyone who can help me!