mlfoundations / open_flamingo

An open-source framework for training large multimodal models.

Will it be possible to train the whole model using the 7B LLM #274

Closed ElegantLin closed 1 year ago

ElegantLin commented 1 year ago

Dear authors,

Thanks for your great work. I wonder whether it is possible to fine-tune the whole model with a 7B LLM on one 80GB GPU, using settings like FSDP, bf16, etc.

Thanks!

anas-awadalla commented 1 year ago

Hello! You wouldn't be able to take advantage of FSDP (sharding parameters, etc.) since you are only using a single GPU. You could use FSDP to offload parameters to CPU, but we don't support this, so you would have to modify the code to do so. If by 'fine-tune the whole model' you mean the cross-attention weights (which are the only trainable parameters for Flamingo), then the entire model will fit on an 80GB GPU :).
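For reference, CPU offloading with FSDP in plain PyTorch looks roughly like the sketch below. This is not something open_flamingo supports out of the box, so the wrapping point and setup are illustrative only, not our training script's actual structure.

import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")  # FSDP requires an initialized process group

# Wrap the (already constructed) model so parameters live on CPU
# and are moved to GPU only when needed during forward/backward.
model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))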

ElegantLin commented 1 year ago

Thanks for your quick reply. I am sorry that I did not make myself clear.

The largest GPU I have is 80GB, and I can use more than one GPU. I am trying to fine-tune all the parameters (7B), but even with bf16 it runs out of memory during backpropagation. I have checked this issue. I think I can fine-tune all the parameters if I switch to a 3B LLM. Could you please give me some suggestions for fine-tuning all the parameters with the 7B LLM?

Thanks!

anas-awadalla commented 1 year ago

Got it! Since you do have multiple GPUs, you should use FSDP. I wouldn't train in pure bf16, as we didn't have success with that; go with amp_bf16 instead. If you run into issues with that, please share the command you are running.
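For concreteness, amp_bf16 keeps the master weights in fp32 and autocasts only the forward/backward computation to bf16. A minimal sketch of one training step (model, optimizer, and batch here are placeholders, not our exact interface):

import torch

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(**batch).loss  # forward computed in bf16; weights stay fp32
loss.backward()                 # unlike fp16 amp, bf16 needs no GradScaler
optimizer.step()
optimizer.zero_grad(set_to_none=True)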

ElegantLin commented 1 year ago

Thanks for your quick reply. As in #232, I ran into the same dimension-mismatch issue. I also noticed that Gao started PR #261 to fix it, but it is a big refactor of the code. I could use that code, but is there a smaller set of modifications that would let me use FSDP on the current main branch, since main is probably more stable? Or would you rather suggest I use the mllm branch?

Thanks for your suggestions!

ElegantLin commented 1 year ago

Hi @anas-awadalla, thanks for your quick reply. I am still trying to fine-tune all the parameters. This time I tried the 3B model, since I think it fits on my 80GB GPU, and simply set model.requires_grad_(True). When I trained it on one 80GB GPU, however, I got the following RuntimeError:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel; (2) making sure all forward function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable). Parameter indices which did not receive grad for rank 0: 2 294 295
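
(For reference, option (1) from the message is a one-line change wherever DDP wraps the model; a generic sketch, since the actual wrapping site in the training script may differ:)

from torch.nn.parallel import DistributedDataParallel as DDP

# Generic sketch: tell DDP to tolerate parameters that receive no gradient.
# This hides the symptom at the cost of an extra graph traversal per step.
model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)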

Following #253, I commented out the with torch.no_grad() context. The branch I used is the current main branch. I think maybe some parameters are missing from the update. Should I add another loss on the vision encoder?

Really appreciate your patience and help!

anas-awadalla commented 1 year ago

Hmm, yeah, this error indicates that there is a trainable component not being used. One thing that comes to mind: maybe you have some samples without images? Also, do you intend for the vision encoder to be unfrozen?

Can you print the name of the trainable parameters in the model and share them here?

ElegantLin commented 1 year ago

Hi @anas-awadalla, I really appreciate your help. I set this line and modified the model's parameter-gradient settings as below. I also added a for-loop that prints the index and name of each trainable parameter.

# Make every parameter trainable.
model.requires_grad_(True)
# assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

# Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
# model.perceiver.requires_grad_(True)
# model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
# if not freeze_lm_embeddings:
#     model.lang_encoder.get_input_embeddings().requires_grad_(True)

# Print the index and name of every trainable parameter.
i = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        print(i, name)
        i += 1

I also commented out the with torch.no_grad() line in the function _encode_vision_x. The parameter indices that did not receive grad are 0, 294, and 295, which should correspond to vision_encoder.class_embedding, vision_encoder.ln_post.weight, and vision_encoder.ln_post.bias.
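(A quick way to map those indices back to names, counting only trainable parameters the same way the loop above does:)

# Map DDP's reported indices back to names; the indices count only
# trainable parameters, matching the loop above.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
for idx in (0, 294, 295):
    print(idx, trainable[idx])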

Could you please help me?

Thanks!

The original trainable parameters are as follows.

0 vision_encoder.class_embedding
1 vision_encoder.positional_embedding
2 vision_encoder.proj
3 vision_encoder.conv1.weight
4 vision_encoder.ln_pre.weight
5 vision_encoder.ln_pre.bias
6 vision_encoder.transformer.resblocks.0.ln_1.weight
7 vision_encoder.transformer.resblocks.0.ln_1.bias
8 vision_encoder.transformer.resblocks.0.attn.in_proj_weight
9 vision_encoder.transformer.resblocks.0.attn.in_proj_bias
10 vision_encoder.transformer.resblocks.0.attn.out_proj.weight
11 vision_encoder.transformer.resblocks.0.attn.out_proj.bias
12 vision_encoder.transformer.resblocks.0.ln_2.weight
13 vision_encoder.transformer.resblocks.0.ln_2.bias
14 vision_encoder.transformer.resblocks.0.mlp.c_fc.weight
15 vision_encoder.transformer.resblocks.0.mlp.c_fc.bias
16 vision_encoder.transformer.resblocks.0.mlp.c_proj.weight
17 vision_encoder.transformer.resblocks.0.mlp.c_proj.bias
18 vision_encoder.transformer.resblocks.1.ln_1.weight
19 vision_encoder.transformer.resblocks.1.ln_1.bias
20 vision_encoder.transformer.resblocks.1.attn.in_proj_weight
21 vision_encoder.transformer.resblocks.1.attn.in_proj_bias
22 vision_encoder.transformer.resblocks.1.attn.out_proj.weight
23 vision_encoder.transformer.resblocks.1.attn.out_proj.bias
24 vision_encoder.transformer.resblocks.1.ln_2.weight
25 vision_encoder.transformer.resblocks.1.ln_2.bias
26 vision_encoder.transformer.resblocks.1.mlp.c_fc.weight
27 vision_encoder.transformer.resblocks.1.mlp.c_fc.bias
28 vision_encoder.transformer.resblocks.1.mlp.c_proj.weight
29 vision_encoder.transformer.resblocks.1.mlp.c_proj.bias
30 vision_encoder.transformer.resblocks.2.ln_1.weight
31 vision_encoder.transformer.resblocks.2.ln_1.bias
32 vision_encoder.transformer.resblocks.2.attn.in_proj_weight
33 vision_encoder.transformer.resblocks.2.attn.in_proj_bias
34 vision_encoder.transformer.resblocks.2.attn.out_proj.weight
35 vision_encoder.transformer.resblocks.2.attn.out_proj.bias
36 vision_encoder.transformer.resblocks.2.ln_2.weight
37 vision_encoder.transformer.resblocks.2.ln_2.bias
38 vision_encoder.transformer.resblocks.2.mlp.c_fc.weight
39 vision_encoder.transformer.resblocks.2.mlp.c_fc.bias
40 vision_encoder.transformer.resblocks.2.mlp.c_proj.weight
41 vision_encoder.transformer.resblocks.2.mlp.c_proj.bias
42 vision_encoder.transformer.resblocks.3.ln_1.weight
43 vision_encoder.transformer.resblocks.3.ln_1.bias
44 vision_encoder.transformer.resblocks.3.attn.in_proj_weight
45 vision_encoder.transformer.resblocks.3.attn.in_proj_bias
46 vision_encoder.transformer.resblocks.3.attn.out_proj.weight
47 vision_encoder.transformer.resblocks.3.attn.out_proj.bias
48 vision_encoder.transformer.resblocks.3.ln_2.weight
49 vision_encoder.transformer.resblocks.3.ln_2.bias
50 vision_encoder.transformer.resblocks.3.mlp.c_fc.weight
51 vision_encoder.transformer.resblocks.3.mlp.c_fc.bias
52 vision_encoder.transformer.resblocks.3.mlp.c_proj.weight
53 vision_encoder.transformer.resblocks.3.mlp.c_proj.bias
54 vision_encoder.transformer.resblocks.4.ln_1.weight
55 vision_encoder.transformer.resblocks.4.ln_1.bias
56 vision_encoder.transformer.resblocks.4.attn.in_proj_weight
57 vision_encoder.transformer.resblocks.4.attn.in_proj_bias
58 vision_encoder.transformer.resblocks.4.attn.out_proj.weight
59 vision_encoder.transformer.resblocks.4.attn.out_proj.bias
60 vision_encoder.transformer.resblocks.4.ln_2.weight
61 vision_encoder.transformer.resblocks.4.ln_2.bias
62 vision_encoder.transformer.resblocks.4.mlp.c_fc.weight
63 vision_encoder.transformer.resblocks.4.mlp.c_fc.bias
64 vision_encoder.transformer.resblocks.4.mlp.c_proj.weight
65 vision_encoder.transformer.resblocks.4.mlp.c_proj.bias
66 vision_encoder.transformer.resblocks.5.ln_1.weight
67 vision_encoder.transformer.resblocks.5.ln_1.bias
68 vision_encoder.transformer.resblocks.5.attn.in_proj_weight
69 vision_encoder.transformer.resblocks.5.attn.in_proj_bias
70 vision_encoder.transformer.resblocks.5.attn.out_proj.weight
71 vision_encoder.transformer.resblocks.5.attn.out_proj.bias
72 vision_encoder.transformer.resblocks.5.ln_2.weight
73 vision_encoder.transformer.resblocks.5.ln_2.bias
74 vision_encoder.transformer.resblocks.5.mlp.c_fc.weight
75 vision_encoder.transformer.resblocks.5.mlp.c_fc.bias
76 vision_encoder.transformer.resblocks.5.mlp.c_proj.weight
77 vision_encoder.transformer.resblocks.5.mlp.c_proj.bias
78 vision_encoder.transformer.resblocks.6.ln_1.weight
79 vision_encoder.transformer.resblocks.6.ln_1.bias
80 vision_encoder.transformer.resblocks.6.attn.in_proj_weight
81 vision_encoder.transformer.resblocks.6.attn.in_proj_bias
82 vision_encoder.transformer.resblocks.6.attn.out_proj.weight
83 vision_encoder.transformer.resblocks.6.attn.out_proj.bias
84 vision_encoder.transformer.resblocks.6.ln_2.weight
85 vision_encoder.transformer.resblocks.6.ln_2.bias
86 vision_encoder.transformer.resblocks.6.mlp.c_fc.weight
87 vision_encoder.transformer.resblocks.6.mlp.c_fc.bias
88 vision_encoder.transformer.resblocks.6.mlp.c_proj.weight
89 vision_encoder.transformer.resblocks.6.mlp.c_proj.bias
90 vision_encoder.transformer.resblocks.7.ln_1.weight
91 vision_encoder.transformer.resblocks.7.ln_1.bias
92 vision_encoder.transformer.resblocks.7.attn.in_proj_weight
93 vision_encoder.transformer.resblocks.7.attn.in_proj_bias
94 vision_encoder.transformer.resblocks.7.attn.out_proj.weight
95 vision_encoder.transformer.resblocks.7.attn.out_proj.bias
96 vision_encoder.transformer.resblocks.7.ln_2.weight
97 vision_encoder.transformer.resblocks.7.ln_2.bias
98 vision_encoder.transformer.resblocks.7.mlp.c_fc.weight
99 vision_encoder.transformer.resblocks.7.mlp.c_fc.bias
100 vision_encoder.transformer.resblocks.7.mlp.c_proj.weight
101 vision_encoder.transformer.resblocks.7.mlp.c_proj.bias
102 vision_encoder.transformer.resblocks.8.ln_1.weight
103 vision_encoder.transformer.resblocks.8.ln_1.bias
104 vision_encoder.transformer.resblocks.8.attn.in_proj_weight
105 vision_encoder.transformer.resblocks.8.attn.in_proj_bias
106 vision_encoder.transformer.resblocks.8.attn.out_proj.weight
107 vision_encoder.transformer.resblocks.8.attn.out_proj.bias
108 vision_encoder.transformer.resblocks.8.ln_2.weight
109 vision_encoder.transformer.resblocks.8.ln_2.bias
110 vision_encoder.transformer.resblocks.8.mlp.c_fc.weight
111 vision_encoder.transformer.resblocks.8.mlp.c_fc.bias
112 vision_encoder.transformer.resblocks.8.mlp.c_proj.weight
113 vision_encoder.transformer.resblocks.8.mlp.c_proj.bias
114 vision_encoder.transformer.resblocks.9.ln_1.weight
115 vision_encoder.transformer.resblocks.9.ln_1.bias
116 vision_encoder.transformer.resblocks.9.attn.in_proj_weight
117 vision_encoder.transformer.resblocks.9.attn.in_proj_bias
118 vision_encoder.transformer.resblocks.9.attn.out_proj.weight
119 vision_encoder.transformer.resblocks.9.attn.out_proj.bias
120 vision_encoder.transformer.resblocks.9.ln_2.weight
121 vision_encoder.transformer.resblocks.9.ln_2.bias
122 vision_encoder.transformer.resblocks.9.mlp.c_fc.weight
123 vision_encoder.transformer.resblocks.9.mlp.c_fc.bias
124 vision_encoder.transformer.resblocks.9.mlp.c_proj.weight
125 vision_encoder.transformer.resblocks.9.mlp.c_proj.bias
126 vision_encoder.transformer.resblocks.10.ln_1.weight
127 vision_encoder.transformer.resblocks.10.ln_1.bias
128 vision_encoder.transformer.resblocks.10.attn.in_proj_weight
129 vision_encoder.transformer.resblocks.10.attn.in_proj_bias
130 vision_encoder.transformer.resblocks.10.attn.out_proj.weight
131 vision_encoder.transformer.resblocks.10.attn.out_proj.bias
132 vision_encoder.transformer.resblocks.10.ln_2.weight
133 vision_encoder.transformer.resblocks.10.ln_2.bias
134 vision_encoder.transformer.resblocks.10.mlp.c_fc.weight
135 vision_encoder.transformer.resblocks.10.mlp.c_fc.bias
136 vision_encoder.transformer.resblocks.10.mlp.c_proj.weight
137 vision_encoder.transformer.resblocks.10.mlp.c_proj.bias
138 vision_encoder.transformer.resblocks.11.ln_1.weight
139 vision_encoder.transformer.resblocks.11.ln_1.bias
140 vision_encoder.transformer.resblocks.11.attn.in_proj_weight
141 vision_encoder.transformer.resblocks.11.attn.in_proj_bias
142 vision_encoder.transformer.resblocks.11.attn.out_proj.weight
143 vision_encoder.transformer.resblocks.11.attn.out_proj.bias
144 vision_encoder.transformer.resblocks.11.ln_2.weight
145 vision_encoder.transformer.resblocks.11.ln_2.bias
146 vision_encoder.transformer.resblocks.11.mlp.c_fc.weight
147 vision_encoder.transformer.resblocks.11.mlp.c_fc.bias
148 vision_encoder.transformer.resblocks.11.mlp.c_proj.weight
149 vision_encoder.transformer.resblocks.11.mlp.c_proj.bias
150 vision_encoder.transformer.resblocks.12.ln_1.weight
151 vision_encoder.transformer.resblocks.12.ln_1.bias
152 vision_encoder.transformer.resblocks.12.attn.in_proj_weight
153 vision_encoder.transformer.resblocks.12.attn.in_proj_bias
154 vision_encoder.transformer.resblocks.12.attn.out_proj.weight
155 vision_encoder.transformer.resblocks.12.attn.out_proj.bias
156 vision_encoder.transformer.resblocks.12.ln_2.weight
157 vision_encoder.transformer.resblocks.12.ln_2.bias
158 vision_encoder.transformer.resblocks.12.mlp.c_fc.weight
159 vision_encoder.transformer.resblocks.12.mlp.c_fc.bias
160 vision_encoder.transformer.resblocks.12.mlp.c_proj.weight
161 vision_encoder.transformer.resblocks.12.mlp.c_proj.bias
162 vision_encoder.transformer.resblocks.13.ln_1.weight
163 vision_encoder.transformer.resblocks.13.ln_1.bias
164 vision_encoder.transformer.resblocks.13.attn.in_proj_weight
165 vision_encoder.transformer.resblocks.13.attn.in_proj_bias
166 vision_encoder.transformer.resblocks.13.attn.out_proj.weight
167 vision_encoder.transformer.resblocks.13.attn.out_proj.bias
168 vision_encoder.transformer.resblocks.13.ln_2.weight
169 vision_encoder.transformer.resblocks.13.ln_2.bias
170 vision_encoder.transformer.resblocks.13.mlp.c_fc.weight
171 vision_encoder.transformer.resblocks.13.mlp.c_fc.bias
172 vision_encoder.transformer.resblocks.13.mlp.c_proj.weight
173 vision_encoder.transformer.resblocks.13.mlp.c_proj.bias
174 vision_encoder.transformer.resblocks.14.ln_1.weight
175 vision_encoder.transformer.resblocks.14.ln_1.bias
176 vision_encoder.transformer.resblocks.14.attn.in_proj_weight
177 vision_encoder.transformer.resblocks.14.attn.in_proj_bias
178 vision_encoder.transformer.resblocks.14.attn.out_proj.weight
179 vision_encoder.transformer.resblocks.14.attn.out_proj.bias
180 vision_encoder.transformer.resblocks.14.ln_2.weight
181 vision_encoder.transformer.resblocks.14.ln_2.bias
182 vision_encoder.transformer.resblocks.14.mlp.c_fc.weight
183 vision_encoder.transformer.resblocks.14.mlp.c_fc.bias
184 vision_encoder.transformer.resblocks.14.mlp.c_proj.weight
185 vision_encoder.transformer.resblocks.14.mlp.c_proj.bias
186 vision_encoder.transformer.resblocks.15.ln_1.weight
187 vision_encoder.transformer.resblocks.15.ln_1.bias
188 vision_encoder.transformer.resblocks.15.attn.in_proj_weight
189 vision_encoder.transformer.resblocks.15.attn.in_proj_bias
190 vision_encoder.transformer.resblocks.15.attn.out_proj.weight
191 vision_encoder.transformer.resblocks.15.attn.out_proj.bias
192 vision_encoder.transformer.resblocks.15.ln_2.weight
193 vision_encoder.transformer.resblocks.15.ln_2.bias
194 vision_encoder.transformer.resblocks.15.mlp.c_fc.weight
195 vision_encoder.transformer.resblocks.15.mlp.c_fc.bias
196 vision_encoder.transformer.resblocks.15.mlp.c_proj.weight
197 vision_encoder.transformer.resblocks.15.mlp.c_proj.bias
198 vision_encoder.transformer.resblocks.16.ln_1.weight
199 vision_encoder.transformer.resblocks.16.ln_1.bias
200 vision_encoder.transformer.resblocks.16.attn.in_proj_weight
201 vision_encoder.transformer.resblocks.16.attn.in_proj_bias
202 vision_encoder.transformer.resblocks.16.attn.out_proj.weight
203 vision_encoder.transformer.resblocks.16.attn.out_proj.bias
204 vision_encoder.transformer.resblocks.16.ln_2.weight
205 vision_encoder.transformer.resblocks.16.ln_2.bias
206 vision_encoder.transformer.resblocks.16.mlp.c_fc.weight
207 vision_encoder.transformer.resblocks.16.mlp.c_fc.bias
208 vision_encoder.transformer.resblocks.16.mlp.c_proj.weight
209 vision_encoder.transformer.resblocks.16.mlp.c_proj.bias
210 vision_encoder.transformer.resblocks.17.ln_1.weight
211 vision_encoder.transformer.resblocks.17.ln_1.bias
212 vision_encoder.transformer.resblocks.17.attn.in_proj_weight
213 vision_encoder.transformer.resblocks.17.attn.in_proj_bias
214 vision_encoder.transformer.resblocks.17.attn.out_proj.weight
215 vision_encoder.transformer.resblocks.17.attn.out_proj.bias
216 vision_encoder.transformer.resblocks.17.ln_2.weight
217 vision_encoder.transformer.resblocks.17.ln_2.bias
218 vision_encoder.transformer.resblocks.17.mlp.c_fc.weight
219 vision_encoder.transformer.resblocks.17.mlp.c_fc.bias
220 vision_encoder.transformer.resblocks.17.mlp.c_proj.weight
221 vision_encoder.transformer.resblocks.17.mlp.c_proj.bias
222 vision_encoder.transformer.resblocks.18.ln_1.weight
223 vision_encoder.transformer.resblocks.18.ln_1.bias
224 vision_encoder.transformer.resblocks.18.attn.in_proj_weight
225 vision_encoder.transformer.resblocks.18.attn.in_proj_bias
226 vision_encoder.transformer.resblocks.18.attn.out_proj.weight
227 vision_encoder.transformer.resblocks.18.attn.out_proj.bias
228 vision_encoder.transformer.resblocks.18.ln_2.weight
229 vision_encoder.transformer.resblocks.18.ln_2.bias
230 vision_encoder.transformer.resblocks.18.mlp.c_fc.weight
231 vision_encoder.transformer.resblocks.18.mlp.c_fc.bias
232 vision_encoder.transformer.resblocks.18.mlp.c_proj.weight
233 vision_encoder.transformer.resblocks.18.mlp.c_proj.bias
234 vision_encoder.transformer.resblocks.19.ln_1.weight
235 vision_encoder.transformer.resblocks.19.ln_1.bias
236 vision_encoder.transformer.resblocks.19.attn.in_proj_weight
237 vision_encoder.transformer.resblocks.19.attn.in_proj_bias
238 vision_encoder.transformer.resblocks.19.attn.out_proj.weight
239 vision_encoder.transformer.resblocks.19.attn.out_proj.bias
240 vision_encoder.transformer.resblocks.19.ln_2.weight
241 vision_encoder.transformer.resblocks.19.ln_2.bias
242 vision_encoder.transformer.resblocks.19.mlp.c_fc.weight
243 vision_encoder.transformer.resblocks.19.mlp.c_fc.bias
244 vision_encoder.transformer.resblocks.19.mlp.c_proj.weight
245 vision_encoder.transformer.resblocks.19.mlp.c_proj.bias
246 vision_encoder.transformer.resblocks.20.ln_1.weight
247 vision_encoder.transformer.resblocks.20.ln_1.bias
248 vision_encoder.transformer.resblocks.20.attn.in_proj_weight
249 vision_encoder.transformer.resblocks.20.attn.in_proj_bias
250 vision_encoder.transformer.resblocks.20.attn.out_proj.weight
251 vision_encoder.transformer.resblocks.20.attn.out_proj.bias
252 vision_encoder.transformer.resblocks.20.ln_2.weight
253 vision_encoder.transformer.resblocks.20.ln_2.bias
254 vision_encoder.transformer.resblocks.20.mlp.c_fc.weight
255 vision_encoder.transformer.resblocks.20.mlp.c_fc.bias
256 vision_encoder.transformer.resblocks.20.mlp.c_proj.weight
257 vision_encoder.transformer.resblocks.20.mlp.c_proj.bias
258 vision_encoder.transformer.resblocks.21.ln_1.weight
259 vision_encoder.transformer.resblocks.21.ln_1.bias
260 vision_encoder.transformer.resblocks.21.attn.in_proj_weight
261 vision_encoder.transformer.resblocks.21.attn.in_proj_bias
262 vision_encoder.transformer.resblocks.21.attn.out_proj.weight
263 vision_encoder.transformer.resblocks.21.attn.out_proj.bias
264 vision_encoder.transformer.resblocks.21.ln_2.weight
265 vision_encoder.transformer.resblocks.21.ln_2.bias
266 vision_encoder.transformer.resblocks.21.mlp.c_fc.weight
267 vision_encoder.transformer.resblocks.21.mlp.c_fc.bias
268 vision_encoder.transformer.resblocks.21.mlp.c_proj.weight
269 vision_encoder.transformer.resblocks.21.mlp.c_proj.bias
270 vision_encoder.transformer.resblocks.22.ln_1.weight
271 vision_encoder.transformer.resblocks.22.ln_1.bias
272 vision_encoder.transformer.resblocks.22.attn.in_proj_weight
273 vision_encoder.transformer.resblocks.22.attn.in_proj_bias
274 vision_encoder.transformer.resblocks.22.attn.out_proj.weight
275 vision_encoder.transformer.resblocks.22.attn.out_proj.bias
276 vision_encoder.transformer.resblocks.22.ln_2.weight
277 vision_encoder.transformer.resblocks.22.ln_2.bias
278 vision_encoder.transformer.resblocks.22.mlp.c_fc.weight
279 vision_encoder.transformer.resblocks.22.mlp.c_fc.bias
280 vision_encoder.transformer.resblocks.22.mlp.c_proj.weight
281 vision_encoder.transformer.resblocks.22.mlp.c_proj.bias
282 vision_encoder.transformer.resblocks.23.ln_1.weight
283 vision_encoder.transformer.resblocks.23.ln_1.bias
284 vision_encoder.transformer.resblocks.23.attn.in_proj_weight
285 vision_encoder.transformer.resblocks.23.attn.in_proj_bias
286 vision_encoder.transformer.resblocks.23.attn.out_proj.weight
287 vision_encoder.transformer.resblocks.23.attn.out_proj.bias
288 vision_encoder.transformer.resblocks.23.ln_2.weight
289 vision_encoder.transformer.resblocks.23.ln_2.bias
290 vision_encoder.transformer.resblocks.23.mlp.c_fc.weight
291 vision_encoder.transformer.resblocks.23.mlp.c_fc.bias
292 vision_encoder.transformer.resblocks.23.mlp.c_proj.weight
293 vision_encoder.transformer.resblocks.23.mlp.c_proj.bias
294 vision_encoder.ln_post.weight
295 vision_encoder.ln_post.bias
296 perceiver.latents

anas-awadalla commented 1 year ago

This is super useful! I am looking at the vision encoder's forward pass in OpenCLIP. We use the tokens and not the pooled output, so I think we are running into a case where the ln_post parameters are never used in the forward pass.

I am still unsure why the cls embedding is not being used; it clearly is here. Which CLIP model are you using? Maybe that has something to do with it. In any case, it probably isn't harmful to freeze the cls embedding as well, although that is a bit hacky.
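If it helps, the hacky freeze would look something like this (a sketch, with the parameter names taken from the list above):

# Sketch: freeze the parameters DDP reports as unused so the reducer
# stops complaining; names taken from the printed list above.
for name, param in model.named_parameters():
    if name in (
        "vision_encoder.class_embedding",
        "vision_encoder.ln_post.weight",
        "vision_encoder.ln_post.bias",
    ):
        param.requires_grad_(False)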

ElegantLin commented 1 year ago

Sorry that I forgot to reply!

Thanks a lot! Very helpful!