InternLM / xtuner

An efficient, flexible and full-featured toolkit for fine-tuning LLM (InternLM2, Llama3, Phi3, Qwen, Mistral, ...)
https://xtuner.readthedocs.io/zh-cn/latest/
Apache License 2.0

llava-llama3-8b llm+llm adapter merge error #662

Status: Open · opened by ztfmars 4 months ago

ztfmars commented 4 months ago

I used llava_llama3_8b_instruct_qlora_clip_vit_large_p14_336_lora_e1_finetune.py to fine-tune on my own dataset, aiming to get a LLaVA-Llama3-8B multimodal model trained on my data. After training and the pth -> hf conversion, I got the LLM adapter, the visual encoder adapter, and the projector.

But I can't merge the LLM and the LLM adapter together, so I can't get the merged LLM weights as described in the tutorial: https://github.com/InternLM/xtuner/tree/main/xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336
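For context, here is a sketch of the steps I am following (they mirror the tutorial; the `$...` variables below are placeholders for my local paths):

```bash
# Step 1: convert the trained .pth checkpoint to HuggingFace-format adapters
# (this works and gives me llm_adapter / visual_encoder_adapter / projector)
xtuner convert pth_to_hf $FINETUNE_CFG $PTH_PATH $SAVE_PATH

# Step 2: merge the LLM LoRA adapter into the base Llama-3 weights
# (this is the step that fails with the error below)
xtuner convert merge $LLM $SAVE_PATH/llm_adapter $MERGED_LLM_PATH

# Step 3 (per the tutorial): merge the visual encoder adapter separately with --is-clip
xtuner convert merge $VISUAL_ENCODER $SAVE_PATH/visual_encoder_adapter $MERGED_CLIP_PATH --is-clip
```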


The error output is as follows:

```
/llava_train_20240506$ xtuner convert merge /home/fusionai/.cache/modelscope/hub/LLM-Research/Meta-Llama-3-8B-Instruct /home/fusionai/project/internllm_demo/llama3/llama3-ft/llava_train_20240506/iter_1000_hf/llm_adapter /home/fusionai/project/internllm_demo/llama3/llama3-ft/llava_train_20240506/iter_1000_llava
[2024-05-08 09:51:48,946] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
[2024-05-08 09:51:53,816] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
Loading checkpoint shards: 100%|███████████████████████████████████████| 4/4 [00:05<00:00,  1.25s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Traceback (most recent call last):
  File "/home/fusionai/project/internllm/xtuner/xtuner/tools/model_converters/merge.py", line 73, in <module>
    main()
  File "/home/fusionai/project/internllm/xtuner/xtuner/tools/model_converters/merge.py", line 56, in main
    model_unmerged = PeftModel.from_pretrained(
  File "/home/fusionai/anaconda3/envs/llama3/lib/python3.10/site-packages/peft/peft_model.py", line 324, in from_pretrained
    config = PEFT_TYPE_TO_CONFIG_MAPPING[
  File "/home/fusionai/anaconda3/envs/llama3/lib/python3.10/site-packages/peft/config.py", line 151, in from_pretrained
    return cls.from_peft_type(**kwargs)
  File "/home/fusionai/anaconda3/envs/llama3/lib/python3.10/site-packages/peft/config.py", line 118, in from_peft_type
    return config_cls(**kwargs)
TypeError: LoraConfig.__init__() got an unexpected keyword argument 'layer_replication'
```
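For reference, the peft / transformers versions in the active environment can be checked like this (nothing special assumed, just the current Python environment):

```bash
# Show which peft and transformers releases are installed; the unexpected
# `layer_replication` keyword is a LoraConfig field that only newer peft
# releases know about.
python -c "import peft, transformers; print('peft', peft.__version__, 'transformers', transformers.__version__)"
```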

Additional description: my full fine-tuning config is shown below.

```python
# Standard imports used by this config (torch / mmengine / peft / transformers)
import torch
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, CLIPImageProcessor,
                          CLIPVisionModel)

from xtuner.dataset import LLaVADataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import LLaVAModel
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
llm_name_or_path = '/home/fusionai/.cache/modelscope/hub/LLM-Research/Meta-Llama-3-8B-Instruct'
visual_encoder_name_or_path = '/home/fusionai/.cache/modelscope/hub/AI-ModelScope/clip-vit-large-patch14-336'

# Specify the pretrained pth
pretrained_pth = '/home/fusionai/project/internllm_demo/llama3/pretrained-model/llama3-llava-iter_2181.pth'  # noqa: E501

# Data
data_root = '/home/fusionai/project/datasets/llama3_test001/'
data_path = data_root + 'repeated_data.json'
image_folder = data_root
prompt_template = PROMPT_TEMPLATE.llama3_chat
max_length = int(2048 - (336 / 14)**2)

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 1
dataloader_num_workers = 0
max_epochs = 1
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = ''
evaluation_images = '/home/fusionai/project/datasets/llama3_test001/imgs/test0001.png'
# "What logic does this figure represent?", "Which logic symbols are in the figure?"
evaluation_inputs = ['此图表示什么逻辑?', '图中都有哪些逻辑符号?']

#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=llm_name_or_path,
    trust_remote_code=True,
    padding_side='right')

image_processor = dict(
    type=CLIPImageProcessor.from_pretrained,
    pretrained_model_name_or_path=visual_encoder_name_or_path,
    trust_remote_code=True)

model = dict(
    type=LLaVAModel,
    freeze_llm=True,
    freeze_visual_encoder=True,
    pretrained_pth=pretrained_pth,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    llm_lora=dict(
        type=LoraConfig,
        r=512,
        lora_alpha=256,
        lora_dropout=0.05,
        bias='none',
        task_type='CAUSAL_LM'),
    visual_encoder=dict(
        type=CLIPVisionModel.from_pretrained,
        pretrained_model_name_or_path=visual_encoder_name_or_path),
    visual_encoder_lora=dict(
        type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05, bias='none'))

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
llava_dataset = dict(
    type=LLaVADataset,
    data_path=data_path,
    image_folder=image_folder,
    tokenizer=tokenizer,
    image_processor=image_processor,
    dataset_map_fn=llava_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        image_processor=image_processor,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        evaluation_images=evaluation_images,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable deterministic
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)
```



**How can I solve this problem? Waiting for help, thanks!**
LZHgrla commented 4 months ago

@ztfmars Hi

This issue is caused by a version mismatch between transformers and peft.

This PR https://github.com/huggingface/peft/pull/1368/files adds `layer_replication` support to LoraConfig, so we recommend updating peft to v0.10.0 and re-running your merge script.
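A minimal sketch of that suggestion, assuming a pip-managed environment (the merge arguments are placeholders for the same paths used in the failing command above):

```bash
# Upgrade peft to a release whose LoraConfig accepts `layer_replication`
# (v0.10.0 or newer), then re-run the merge.
pip install -U "peft>=0.10.0"
xtuner convert merge $LLM $LLM_ADAPTER $SAVE_PATH
```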

ztfmars commented 4 months ago

> @ztfmars Hi
>
> This issue is caused by a version mismatch between transformers and peft.
>
> This PR https://github.com/huggingface/peft/pull/1368/files adds `layer_replication` support to LoraConfig, so we recommend updating peft to v0.10.0 and re-running your merge script.

Yes, it works! But there is an obvious peft version conflict between xtuner and lmdeploy, so I will set up a separate virtual environment for lmdeploy and continue.
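A minimal sketch of that workaround, assuming conda is available (environment name and Python version are illustrative):

```bash
# Keep lmdeploy in its own environment so its peft requirement does not
# clash with the peft version that xtuner's merge step needs.
conda create -n lmdeploy_env python=3.10 -y
conda activate lmdeploy_env
pip install lmdeploy
```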


thx very much!