mlfoundations / open_flamingo

An open-source framework for training large multimodal models.
MIT License

Mismatch input type and weight type when training with precision fp16 #260

Open hungvo304ml opened 1 year ago

hungvo304ml commented 1 year ago

Hi, thanks for making this project public.

I am trying to run training with fp16 and get the following error:

RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

With fp32 the code runs without this error, but I run out of GPU memory (OOM).

Traceback for error when using fp16:

Traceback (most recent call last):
  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/train/train.py", line 484, in <module>
    main()
  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/train/train.py", line 465, in main
    train_one_epoch(
  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/train/train_utils.py", line 111, in train_one_epoch
    loss_laion = model(
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo.py", line 108, in forward
    self._encode_vision_x(vision_x=vision_x)
  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo.py", line 195, in _encode_vision_x
    vision_x = self.vision_encoder(vision_x)[1]
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/open_clip/transformer.py", line 469, in forward
    x = self.conv1(x)  # shape = [*, width, grid, grid]
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
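The traceback ends in the vision encoder's first convolution (`self.conv1(x)` in open_clip), where half-precision images meet float32 weights; PyTorch convolutions do not promote mixed dtypes. The toy sketch below reproduces that on CPU with a small `nn.Conv2d` standing in for the real patch embedding, then fixes it with a hypothetical helper `match_param_dtype` that casts the input to the module's weight dtype. Casting the whole model consistently (`model.half()`) or keeping fp32 weights and running under `torch.autocast` are the other usual ways to make the dtypes agree.

```python
import torch
import torch.nn as nn

# Toy stand-in for open_clip's patch-embedding `conv1`; weights default to float32.
conv1 = nn.Conv2d(3, 8, kernel_size=4, stride=4)

# Images cast to half precision, as a pure-fp16 training setup would do.
images = torch.randn(2, 3, 16, 16).half()

raised = False
try:
    conv1(images)  # half input + float32 weight -> RuntimeError
except RuntimeError as err:
    raised = True
    print("dtype mismatch:", err)


def match_param_dtype(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cast `x` to the dtype of the module's first parameter."""
    return x.to(next(module.parameters()).dtype)


out = conv1(match_param_dtype(conv1, images))
print(out.dtype, tuple(out.shape))  # torch.float32 (2, 8, 4, 4)
```

Casting the input up to fp32 keeps memory use at fp32 levels for this layer, so it removes the error but not the OOM; that trade-off is what the reply below addresses.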

Environment

I am using Python 3.9.17 with V100 GPUs.

open-clip-torch          2.16.0
torch                    2.0.1
torchvision              0.15.2
transformers             4.28.1
anas-awadalla commented 1 year ago

Thanks for bringing this up! I will take a closer look later today. I do want to point out that we haven't gotten good performance with pure fp16 training. You will likely get better results by keeping fp32 and using FSDP to shard the model state across your GPUs instead of reducing the precision.
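A rough way to see why sharding can beat halving precision is to count bytes of model state per GPU. This sketch assumes Adam, where fp32 training keeps a weight, a gradient, and two moment estimates per parameter (16 bytes), versus 8 bytes if all four were kept in fp16; it also assumes FSDP full sharding splits that state evenly and it ignores activations. The 9-billion-trainable-parameter model and the 8-GPU setup are hypothetical.

```python
def model_state_gb(num_params: float, bytes_per_param: int, num_shards: int = 1) -> float:
    """Rough per-GPU memory for weights + grads + optimizer state, in GB (1e9 bytes).

    Ignores activations, buffers, and communication scratch space.
    """
    return num_params * bytes_per_param / num_shards / 1e9


# Assumed bytes per parameter when training with Adam.
FP32_ADAM = 4 + 4 + 4 + 4  # fp32 weight, grad, and two Adam moments
FP16_ADAM = 2 + 2 + 2 + 2  # the same four tensors kept in fp16 ("pure fp16")

n = 9e9  # hypothetical number of trainable parameters
print(model_state_gb(n, FP32_ADAM))                # 144.0 (fp32, unsharded)
print(model_state_gb(n, FP16_ADAM))                # 72.0  (pure fp16, unsharded)
print(model_state_gb(n, FP32_ADAM, num_shards=8))  # 18.0  (fp32, FSDP over 8 GPUs)
```

Under these assumptions, 8-way sharding in fp32 saves far more memory than switching to fp16, without fp16's numerical issues.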

hungvo304ml commented 1 year ago

Thanks for clarifying. FSDP would be ideal, but I have problems training with it: I am using MPT-1B, which does not have the get_output_embeddings and set_output_embeddings methods. I see a major refactor is in progress and look forward to using it.
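For a model that lacks these accessors, one workaround is to add them in the style of Hugging Face's `PreTrainedModel`. The sketch below is hypothetical: `ToyTiedLM` is a small stand-in that, like the MPT code in the later tracebacks (`logits = F.linear(x, self.transformer.wte.weight, None)`), reuses the input embedding weight as its output projection. Because the head is tied, the "output embeddings" here are simply the `nn.Embedding` module, which differs from models that expose a separate `nn.Linear` head; code that expects an `nn.Linear`, or that wraps this module separately (for example for FSDP), may not handle it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyTiedLM(nn.Module):
    """Hypothetical stand-in for an MPT-style LM whose logits reuse `transformer.wte.weight`."""

    def __init__(self, vocab_size: int = 10, d_model: int = 4):
        super().__init__()
        self.transformer = nn.Module()
        self.transformer.wte = nn.Embedding(vocab_size, d_model)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.transformer.wte(input_ids)              # (batch, seq, d_model)
        return F.linear(x, self.transformer.wte.weight)  # tied head -> (batch, seq, vocab)

    # Accessors named after the transformers convention. The head is tied,
    # so input and output embeddings are the same module.
    def get_input_embeddings(self) -> nn.Module:
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings: nn.Module) -> None:
        self.transformer.wte = new_embeddings

    def get_output_embeddings(self) -> nn.Module:
        return self.transformer.wte

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.transformer.wte = new_embeddings


model = ToyTiedLM()
logits = model(torch.tensor([[1, 2, 3]]))
print(tuple(logits.shape))                                            # (1, 3, 10)
print(model.get_output_embeddings() is model.get_input_embeddings())  # True
```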

anas-awadalla commented 1 year ago

Got it. If you want to try FSDP before the refactor is merged, there is this version of MPT that I use for testing.

hungvo304ml commented 1 year ago

Great, thanks for pointing this out. I will try FSDP with this model.

hungvo304ml commented 1 year ago

I tried FSDP with "mpt-1b-redpajama-200b-hf-style" and it got past the above error.

However, I now get another error: the shape of the input embedding weight (self.transformer.wte.weight) has been altered. I believe it should be a 2-D tensor of shape (:, 2048), but it is a 1-D tensor of shape (25743360), which causes a size mismatch when computing the logits. Details below:

  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo.py", line 111, in forward
    output = self.lang_encoder(
  File "/home/hqvo2/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hqvo2/Projects/MultiMEDal_multimodal_medical/libs/open_flamingo/open_flamingo/src/flamingo_lm.py", line 157, in forward
    return super().forward(**kwargs)  # Call the other parent's forward method
  File "/home/hqvo2/.cache/huggingface/modules/transformers_modules/anas-awadalla/mpt-1b-redpajama-200b-hf-style/f40a2c7f92621be8b12a01ac9214d3ed4ef50f60/mosaic_gpt.py", line 379, in forward
    logits = F.linear(x, self.transformer.wte.weight, None)
RuntimeError: size mismatch, got 8, 8x2048,25743360
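The error text matches a matrix-vector product: with a 1-D weight, `F.linear` ends up computing `x @ weight` as a matrix times a vector, which only works if the vector length equals the hidden size (2048 here). A toy reproduction with small, made-up sizes; `wte_1d` is a hypothetical 1-D slice (a quarter of the flattened weight) standing in for the 1-D tensor in the report.

```python
import torch
import torch.nn.functional as F

hidden = torch.randn(8, 16)     # stand-in for (tokens, d_model); d_model is 2048 in the report
wte_2d = torch.randn(10, 16)    # (vocab_size, d_model): the shape a tied LM head expects
wte_1d = wte_2d.flatten()[:40]  # hypothetical 1-D slice, like the (25743360) tensor

logits = F.linear(hidden, wte_2d)  # (8, 16) @ (16, 10) -> (8, 10)
print(tuple(logits.shape))         # (8, 10)

try:
    F.linear(hidden, wte_1d)  # 1-D weight: matrix-vector product, 40 != 16
    failed = False
except RuntimeError as err:
    failed = True
    print(err)  # a "size mismatch" error; exact wording varies by PyTorch version
```

One plausible cause, not confirmed in this thread: FSDP stores wrapped parameters as flattened 1-D shards, and the tied projection reads `self.transformer.wte.weight` outside the embedding module's own forward, where FSDP may not have gathered it back into its 2-D shape.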
alyakin314 commented 2 months ago

Did you resolve this? I get a very similar error when using FSDP with OpenFlamingo-9B:

  File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/open_flamingo/src/flamingo.py", line 111, in forward
    output = self.lang_encoder(
  File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/gpfs/data/oermannlab/users/alyaka01/.conda/envs/cns-flamingo/lib/python3.9/site-packages/open_flamingo/src/flamingo_lm.py", line 157, in forward
    return super().forward(**kwargs)  # Call the other parent's forward method
  File "/gpfs/data/oermannlab/users/alyaka01/.cache/huggingface/modules/transformers_modules/anas-awadalla/mpt-7b/b772e556c8e8a17d087db6935e7cd019e5eefb0f/modeling_mpt.py", line 258, in forward
    logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
RuntimeError: size mismatch, got 8192, 8192x4096,51486720
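Both reported 1-D sizes divide exactly by the model width, which suggests each is a flattened slice of the embedding matrix rather than corrupted data. The arithmetic below checks this. The 50,280-row vocabulary and 4-GPU world size at the end are a hypothetical configuration that happens to give the same 12,570 rows; the thread states neither value.

```python
# Reported 1-D sizes of `transformer.wte.weight` and the widths from the tracebacks.
reports = {
    "mpt-1b (d_model=2048)": (25743360, 2048),
    "mpt-7b (d_model=4096)": (51486720, 4096),
}

for name, (numel, d_model) in reports.items():
    rows, remainder = divmod(numel, d_model)
    print(f"{name}: {numel} = {rows} x {d_model} (remainder {remainder})")
# Both lines report 12570 rows with remainder 0.

# Hypothetical configuration consistent with both numbers:
VOCAB_ROWS = 50280  # assumed embedding rows
WORLD_SIZE = 4      # assumed number of FSDP ranks
print(VOCAB_ROWS // WORLD_SIZE)  # 12570
```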
alyakin314 commented 2 months ago

related: https://github.com/mlfoundations/open_flamingo/issues/129#issuecomment-1696570150