huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Have `_is_peft_model` check if there's any peft submodule/Allow quantised training #30878

Closed: ambroser53 closed this issue 2 months ago

ambroser53 commented 6 months ago

Feature request

I have multi-modal models with several different peft models attached to different submodules, because each submodule needs its own LoRA configuration. The following is an example:

model = AutoModelForVision2Seq.from_pretrained(
  args.pretrained_ckpt,
  torch_dtype=compute_dtype,
  quantization_config=BitsAndBytesConfig(
      load_in_4bit=bits == 4,
      load_in_8bit=bits == 8,
      llm_int8_threshold=6.0,
      int8_quant_skip_modules=int8_quant_skip_modules,
      llm_int8_has_fp16_weight=False,
      bnb_4bit_compute_dtype=compute_dtype,
      bnb_4bit_use_double_quant=True,
      bnb_4bit_quant_type='nf4'  # {'fp4', 'nf4'}
  ) if bits < 16 else None,
  attn_implementation=args.attn_implementation,
)

if (args.use_lora and not resume_from_checkpoint and not ft_checkpoint_dir):
  target_modules = get_target_modules(model.model.text_model, args, bits)
  peft_config = LoraConfig(
      target_modules=target_modules,
      inference_mode=args.inference_mode,
      r=args.lora_r,
      lora_alpha=args.lora_alpha,
      lora_dropout=args.lora_dropout,
      use_dora=args.use_dora
  )
  model.model.text_model = get_peft_model(model.model.text_model, peft_config)

  if args.vit_train:
      target_modules = get_target_modules(model.model.vision_model, args, args.vit_bits, vit=True)
      peft_config = LoraConfig(
          target_modules=target_modules,
          inference_mode=args.inference_mode,
          r=args.vit_lora_r,
          lora_alpha=args.vit_lora_alpha,
          lora_dropout=args.lora_dropout,
          use_dora=args.use_dora_vit
      )
      model.model.vision_model = get_peft_model(model.model.vision_model, peft_config)

  if args.lora_abstractor:
      target_modules = get_target_modules(model.model.connector, args, args.bits)
      peft_config = LoraConfig(
          target_modules=target_modules,
          inference_mode=args.inference_mode,
          r=args.lora_r,
          lora_alpha=args.lora_alpha,
          lora_dropout=args.lora_dropout,
          use_dora=args.use_dora
      )
      model.model.connector = get_peft_model(model.model.connector, peft_config)

This works fine until you quantise the model, because the Hugging Face Trainer requires a quantised model to pass the `_is_peft_model` check; otherwise it assumes the entire model is quantised and therefore untrainable. A really simple fix would be to change the check from `_is_peft_model` to simply looking for any parameter that is unquantised AND of a trainable datatype (>=16-bit); alternatively, if it really should stay a peft-model check, it could check whether any submodule is a peft model. In general, letting this functionality work regardless of the class or type of the top-level model would be helpful for researchers in the multi-modal space such as myself.
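
For context, the Trainer check I'm hitting looks roughly like this (paraphrased from memory of trainer.py, so the attribute names and the exact message may differ between versions):

# Rough paraphrase of the gate in Trainer.__init__, not an exact copy.
_is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
    model, "_hf_peft_config_loaded", False
)
if _is_quantized_and_base_model and not _is_peft_model(model):
    raise ValueError(
        "You cannot perform fine-tuning on purely quantized models. Please attach "
        "trainable adapters on top of the quantized model."
    )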

I understand that there is PeftMixedModel, but it seems largely unsupported and unwieldy compared to just doing it like this, where I have complete control.
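
For reference, as far as I understand the API, the PeftMixedModel route would look roughly like this (untested sketch, assuming peft>=0.7.0; the adapter names and target_modules below are placeholders, not my real configs):

from peft import LoraConfig, get_peft_model

# Untested sketch: wrap the *whole* model once and register one adapter per config,
# instead of wrapping each submodule separately as above.
text_cfg = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])  # placeholder modules
vision_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["qkv"])  # placeholder modules

mixed = get_peft_model(model, text_cfg, adapter_name="text", mixed=True)  # returns a PeftMixedModel
mixed.add_adapter("vision", vision_cfg)
mixed.set_adapter(["text", "vision"])  # activate both adapters for training

That would pass the existing _is_peft_model check because the wrapping happens at the top level, but it gives up the per-submodule control shown above.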

Motivation

Let me train quantised models that aren't PeftModels at their top level, i.e. that aren't wrapped in a PeftModel.

Your contribution

Here's a version of _is_peft_model that checks for any parameter that requires gradients and is not a 4-bit or 8-bit quantised parameter:

def _is_peft_model(model):
    if is_peft_available():
        classes_to_check = (PeftModel,) if is_peft_available() else ()
        # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
        if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
            from peft import PeftMixedModel

            classes_to_check = (*classes_to_check, PeftMixedModel)
        is_peft = isinstance(model, classes_to_check)
        if is_peft:
            return True
        else:
            for _, param in model.named_parameters():
                # trainable (requires_grad) and not a bitsandbytes 4-/8-bit quantised parameter
                if param.requires_grad and "4bit" not in param.__class__.__name__ and "8bit" not in param.__class__.__name__:
                    return True
    return False
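
One caveat worth double-checking: as far as I remember, bitsandbytes stores 4-bit weights as Params4bit but 8-bit weights as Int8Params, so the substring check may also need to look for "Int8" rather than just "8bit". A quick way to verify the class names (assuming bitsandbytes is installed):

import bitsandbytes as bnb

# Inspect the parameter classes bitsandbytes uses for quantised weights.
lin4 = bnb.nn.Linear4bit(16, 16)
lin8 = bnb.nn.Linear8bitLt(16, 16, has_fp16_weights=False)
print(type(lin4.weight).__name__)  # expected: Params4bit -> caught by the "4bit" check
print(type(lin8.weight).__name__)  # expected: Int8Params -> not caught by "8bit"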

Here is probably a more acceptable version that specifically checks whether any submodule is a PeftModel:

def _is_peft_model(model):
    if is_peft_available():
        classes_to_check = (PeftModel,) if is_peft_available() else ()
        # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
        if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
            from peft import PeftMixedModel

            classes_to_check = (*classes_to_check, PeftMixedModel)
        is_peft = isinstance(model, classes_to_check)
        if is_peft:
            return True
        else:
            for submodule in model.modules():
                if isinstance(submodule, classes_to_check):
                    return True
    return False
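
With either version, here is a quick sanity check against the setup from the top of this issue (the top-level model is not a PeftModel, but several of its submodules are):

from peft import PeftModel

print(isinstance(model, PeftModel))                            # False: the top level is unwrapped
print(any(isinstance(m, PeftModel) for m in model.modules()))  # True: text/vision/connector are wrapped
print(_is_peft_model(model))                                   # True with either patched version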

amyeroberts commented 6 months ago

cc @younesbelkada @pacman100

younesbelkada commented 6 months ago

Hi! This makes sense, yes. Can you open a PR with the suggested changes? 🙏

github-actions[bot] commented 3 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.