mlfoundations / open_flamingo

An open-source framework for training large multimodal models.
MIT License

FSDP error report #232

Open · liyongqi67 opened this issue 1 year ago

liyongqi67 commented 1 year ago

Thanks for this wonderful project. I used the following script to train the model.

torchrun --nnodes=1 --nproc_per_node=2 /home/share/yongqi/project/AutoregressiveImageRetrieval/code/open_flamingo/open_flamingo/train/finetuning.py \
  --lm_path anas-awadalla/mpt-1b-redpajama-200b \
  --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
  --cross_attn_every_n_layers 1 \
  --dataset_resampled \
  --batch_size_mmc4 1 \
  --train_num_samples_mmc4 150000 \
  --workers=2 \
  --run_name OpenFlamingo-3B-vitl-mpt1b \
  --num_epochs 2 \
  --warmup_steps 2000 \
  --mmc4_textsim_threshold 0.01 \
  --laion_shards \
  --mmc4_shards \
  --logging_steps 1 \
  --mmc4_max_num_images 1 \
  --precision fp16 \
  --fsdp

However, if I set the fsdp flag, it will report an error as follows:

AttributeError: 'MosaicGPT' object has no attribute 'set_output_embeddings'

The error is raised at flamingo.py, line 294.

If I remove this flag, there is no error. Do you have any idea about this?

i-gao commented 1 year ago

Ah, sorry about that! The issue comes from this line of the FSDP wrapping function: the MPT models are still missing some standard HF Transformers methods, set_output_embeddings among them.

Would it work for your use case to comment out the aforementioned line in our codebase? The output embedding weight would then simply not be sharded. Alternatively, we could add a hack to get around this, similar to this part of the code for MPT-1B.
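
For concreteness, a rough, untested sketch of what that MPT-1B-style hack could look like, mirroring the input-embedding workaround already in open_flamingo/src/factory.py (the mixin name is invented here; extend_instance is the small helper in open_flamingo/src/utils.py that rewrites an instance's class to include a mixin):

    from open_flamingo.src.utils import extend_instance

    class OutputEmbeddingFnMixin:
        # MosaicGPT ties its output projection to the input embedding
        # (transformer.wte), so both accessors point at the same module.
        def get_output_embeddings(self):
            return self.transformer.wte

        def set_output_embeddings(self, new_embeddings):
            self.transformer.wte = new_embeddings

    # after loading the MPT language model (e.g. in factory.py):
    extend_instance(lang_encoder, OutputEmbeddingFnMixin)

Note that because the weights are tied, anything set through set_output_embeddings also replaces the input embedding.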

liyongqi67 commented 1 year ago

Thanks for your quick reply!

  1. If I comment out only the corresponding line "self.lang_encoder.set_output_embeddings( wrap(wrap(self.lang_encoder.get_output_embeddings())) )", it reports another error:

    File "/home/yongqi/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    loss_mmc4 = model(
    File "/home/yongqi/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
    File "/home/share/yongqi/project/AutoregressiveImageRetrieval/code/open_flamingo/open_flamingo/src/flamingo.py", line 111, in forward
    return forward_call(*args, **kwargs)
    File "/home/share/yongqi/project/AutoregressiveImageRetrieval/code/open_flamingo/open_flamingo/src/flamingo.py", line 111, in forward
    output = self.lang_encoder(
    File "/home/yongqi/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    output = self.lang_encoder(
    File "/home/yongqi/miniconda3/envs/openflamingo/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
    File "/home/share/yongqi/project/AutoregressiveImageRetrieval/code/open_flamingo/open_flamingo/src/flamingo_lm.py", line 157, in forward
    return forward_call(*args, **kwargs)
    File "/home/share/yongqi/project/AutoregressiveImageRetrieval/code/open_flamingo/open_flamingo/src/flamingo_lm.py", line 157, in forward
    return super().forward(**kwargs)  # Call the other parent's forward method
    return super().forward(**kwargs)  # Call the other parent's forward method  File "/home/yongqi/.cache/huggingface/modules/transformers_modules/anas-awadalla/mpt-1b-redpajama-200b/bfa38d4f431e091fe599d7b4cdb62972532f3c7c/mosaic_gpt.py", line 366, in forward
    
    File "/home/yongqi/.cache/huggingface/modules/transformers_modules/anas-awadalla/mpt-1b-redpajama-200b/bfa38d4f431e091fe599d7b4cdb62972532f3c7c/mosaic_gpt.py", line 366, in forward
        logits = F.linear(x, self.transformer.wte.weight, None)
    RuntimeErrorRuntimeError: : size mismatch, got 15, 15x2048,51486720

    It is very strange. (A possible explanation is sketched after this comment.)

  2. How should set_output_embeddings() be implemented in open_flamingo? I noticed corresponding code in https://huggingface.co/mosaicml/mpt-7b/blob/main/modeling_mpt.py. Should I copy this code into open_flamingo/src/factory.py?
    def get_input_embeddings(self):
        return self.transformer.wte

    def set_input_embeddings(self, value):
        self.transformer.wte = value

    def get_output_embeddings(self):
        return self.transformer.wte

    def set_output_embeddings(self, new_embeddings):
        self.transformer.wte = new_embeddings

    But it still reports the size-mismatch error above. Could you clarify how get_output_embeddings() and set_output_embeddings() should be implemented?
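
A plausible explanation for the persistent mismatch, offered here only as an unconfirmed reading of the logs: 51,486,720 is exactly half of 50,280 x 2,048, i.e. the local shard of the tied wte weight after FSDP splits it across the two GPUs. MosaicGPT computes its logits by reading self.transformer.wte.weight directly rather than by calling a module, and FSDP only all-gathers a module's full parameters around that module's own forward pass, so a bare .weight read from another module's forward sees the flattened shard. This would also explain why adding get_output_embeddings()/set_output_embeddings() does not help: the accessors change what gets wrapped, but line 366 of mosaic_gpt.py never routes through the wrapped module. A minimal sketch of one possible workaround, with all names hypothetical:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class TiedLMHead(nn.Module):
        # Hypothetical module: routing the tied-embedding projection through
        # a real nn.Module gives FSDP a forward call it can intercept, so the
        # full weight is gathered instead of the local shard being read raw.
        def __init__(self, wte: nn.Embedding):
            super().__init__()
            self.wte = wte  # shared with the input embedding; weights stay tied

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.linear(x, self.wte.weight, None)

    # mosaic_gpt.py line 366 would then become something like:
    #     logits = self.lm_head(x)
    # with self.lm_head = TiedLMHead(self.transformer.wte) created in __init__,
    # and the head kept in the same FSDP unit as the embedding so the tied
    # parameter is not sharded twice.

Whether this plays well with open_flamingo's wrapping policy is untested; in general, tied parameters need to live inside a single FSDP unit.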

hungvo304ml commented 1 year ago

Have you solved the issue? I have the same problem when training with FSDP.