In short, are there any ways to directly quantize this model?
class Flamingo(nn.Module):
    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        vis_dim: int,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            vision_encoder (nn.Module): HF CLIPModel
            lang_encoder (nn.Module): HF causal language model
            eoc_token_id (int): Token id for <|endofchunk|>
            media_token_id (int): Token id for <image>
            vis_dim (int): Dimension of the visual features.
                Visual features are projected to match this shape along the last dimension.
            cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
        """
        super().__init__()
        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.vis_dim = vis_dim
        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size
        self.vision_encoder = vision_encoder.visual
        self.perceiver = PerceiverResampler(dim=self.vis_dim)
        self.lang_encoder = lang_encoder
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            lang_hidden_size=self.lang_dim,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
        )
        self._use_gradient_checkpointing = gradient_checkpointing
        self.perceiver._use_gradient_checkpointing = gradient_checkpointing
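For the direct-quantization question above: since the constructor takes lang_encoder as a plain nn.Module, one rough option is to load only the language model in 4-bit with bitsandbytes and hand that to the Flamingo builder. The snippet below is only a sketch of that idea, not code from this repo; the checkpoint path is a placeholder (reused from the command further down), and the vision tower, Perceiver, and cross-attention layers are assumed to stay in full precision.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Sketch: quantize only the language encoder; everything Flamingo adds on top
# (Perceiver resampler, gated cross-attention) would remain fp16/fp32.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

llm_path = "/scratch/yerong/.cache/pyllama/hf/7B"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(llm_path)
lang_encoder = AutoModelForCausalLM.from_pretrained(
    llm_path,
    quantization_config=bnb_config,
    device_map={"": 0},  # keep the whole LLM on one GPU so DDP still works
)
# lang_encoder can then be passed to the Flamingo wrapper in place of the
# full-precision language model.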
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
System Info
https://github.com/open-mmlab/Multimodal-GPT
Are there any good ways to quantize open-flamingo? I found that after using prepare_model_for_kbit_training, the init_flamingo() setup is reverted.
Here is a modified version of https://github.com/open-mmlab/Multimodal-GPT/blob/main/mmgpt/models/open_flamingo/builder.py.
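Separately from that builder, here is a hedged sketch of one way to order things so that init_flamingo() runs last and is not disturbed by the k-bit preparation. It assumes lang_encoder has already been extended with open_flamingo's language-model mixin (as create_model_and_transforms does); vis_dim and cross_attn_every_n_layers are placeholder values.

from peft import prepare_model_for_kbit_training

def build_kbit_flamingo(vision_encoder, lang_encoder, tokenizer):
    # Hypothetical ordering: k-bit preparation first, Flamingo wiring last,
    # so nothing runs after init_flamingo() that could undo its changes.
    lang_encoder = prepare_model_for_kbit_training(lang_encoder)
    model = Flamingo(
        vision_encoder=vision_encoder,
        lang_encoder=lang_encoder,
        eoc_token_id=tokenizer.encode("<|endofchunk|>")[-1],
        media_token_id=tokenizer.encode("<image>")[-1],
        vis_dim=1024,                  # placeholder: use the vision tower's width
        cross_attn_every_n_layers=4,   # placeholder
    )
    return model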
Reproduction
language_datasets = [
    dict(
        type="alpaca_gpt4",
        ann_path="data/alpaca_gpt4/alpaca_gpt4_data.json",
    ),
]
torchrun --nproc_per_node=1 mmgpt/train/instruction_finetune.py \
  --lm_path /scratch/yerong/.cache/pyllama/hf/7B \
  --tokenizer_path /scratch/yerong/.cache/pyllama/hf/7B \
  --pretrained_path OpenFlamingo-9B/checkpoint.pt \
  --run_name train-my-gpt4 \
  --learning_rate 1e-5 \
  --lr_scheduler cosine \
  --batch_size 1 \
  --tuning_config configs/lora_config.py \
  --dataset_config configs/dataset_config.py
Total training steps: 99883
Found no checkpoints for run train-my-gpt4.
0%| | 0/99883 [00:00<?, ?it/s]/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
0%| | 0/99883 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/scratch/yerong/Multimodal-GPT/mmgpt/train/instruction_finetune.py", line 470, in <module>
    main()
  File "/scratch/yerong/Multimodal-GPT/mmgpt/train/instruction_finetune.py", line 302, in main
    train_one_epoch(
  File "/scratch/yerong/Multimodal-GPT/mmgpt/train/instruction_finetune.py", line 390, in train_one_epoch
    loss_batch = model(
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/yerong/Multimodal-GPT/mmgpt/models/open_flamingo/flamingo.py", line 104, in forward
    output = self.lang_encoder(
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/peft/peft_model.py", line 416, in forward
    return self.get_base_model()(*args, **kwargs)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 688, in forward
    outputs = self.model(
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 570, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 566, in custom_forward
    return module(*inputs, output_attentions, None)
  File "/scratch/yerong/.conda/envs/mmgpt/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() takes from 2 to 3 positional arguments but 7 were given
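A guess at the cause, not a confirmed fix: prepare_model_for_kbit_training enables the HF model's own gradient checkpointing by default, so LLaMA routes each decoder layer through torch.utils.checkpoint with the full positional argument list (seven arguments in this trace), while open_flamingo wraps those layers in FlamingoLayer objects whose forward() only accepts two or three positional arguments. If that is what is happening, keeping HF-side checkpointing off and relying on Flamingo's own gradient_checkpointing flag might avoid the clash; the helper name below is hypothetical.

from peft import prepare_model_for_kbit_training

def prepare_lang_encoder_without_hf_checkpointing(lang_encoder):
    # Hypothetical workaround: skip the HF/torch-level checkpointing that the
    # k-bit preparation would otherwise turn on for the wrapped decoder layers.
    lang_encoder = prepare_model_for_kbit_training(
        lang_encoder,
        use_gradient_checkpointing=False,
    )
    # If the model was already prepared elsewhere, the flag can also be
    # switched off after the fact:
    # lang_encoder.gradient_checkpointing_disable()
    return lang_encoder

The earlier "None of the inputs have requires_grad=True" warning is also what checkpointing prints when it runs over fully frozen inputs; if HF-side checkpointing is kept, calling lang_encoder.enable_input_require_grads() is the usual companion step.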