Open yinjun622 opened 1 month ago
Hey, yes, +1 for the above comment for llava + llama3.1.
Hey, yes, +1 for the above comment for llava + llama3.1.
+1024 ... I'm waiting llava + llama3.1
It seems to work — I ran this. However, you have to upgrade transformers, PyTorch, and DeepSpeed.
from mmengine.dataset import DefaultSampler from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, LoggerHook, ParamSchedulerHook) from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR from torch.optim import AdamW from transformers import (AutoModelForCausalLM, AutoTokenizer, SiglipImageProcessor, SiglipVisionModel)
from xtuner.dataset import LLaVADataset from xtuner.dataset.collate_fns import default_collate_fn from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook from xtuner.engine.runner import TrainLoop from xtuner.model import LLaVAModel from xtuner.utils import PROMPT_TEMPLATE
#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
llm_name_or_path = 'Meta-Llama/Meta-Llama-3.1-8B-Instruct-abliterated'
visual_encoder_name_or_path = 'google/siglip-so400m-patch14-384'

# Data
data_root = './data/llava_data/'
data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
image_folder = data_root + 'LLaVA-Pretrain/images'
prompt_template = PROMPT_TEMPLATE.llama3_chat

# NOTE(review): (336 / 14)**2 = 576 is the CLIP-ViT-L/14-336 patch count,
# but the visual encoder configured above is SigLIP patch14-384, whose
# patch count would be (384 / 14)**2. Kept as-is to preserve behavior —
# confirm which image-token budget is intended.
max_length = int(131072 - (336 / 14)**2)
# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 1
dataloader_num_workers = 5
max_epochs = 1
optim_type = AdamW
lr = 1e-3
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 500
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during training
evaluation_freq = 500
SYSTEM = ''
evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
# Lazy-built (mmengine `dict(type=...)` style) tokenizer config.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=llm_name_or_path,
    trust_remote_code=True,
    padding_side='right')

# SigLIP image processor matching the visual encoder checkpoint.
image_processor = dict(
    type=SiglipImageProcessor.from_pretrained,
    pretrained_model_name_or_path=visual_encoder_name_or_path,
    trust_remote_code=True)

# Pretrain stage: both the LLM and the visual encoder are frozen;
# only the projector is trained.
model = dict(
    type=LLaVAModel,
    freeze_llm=True,
    freeze_visual_encoder=True,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path,
        trust_remote_code=True),
    visual_encoder=dict(
        type=SiglipVisionModel.from_pretrained,
        pretrained_model_name_or_path=visual_encoder_name_or_path))
#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
# LLaVA pretrain dataset (BLIP/LAION/CC/SBU 558k caption pairs).
llava_dataset = dict(
    type=LLaVADataset,
    data_path=data_path,
    image_folder=image_folder,
    tokenizer=tokenizer,
    image_processor=image_processor,
    dataset_map_fn=llava_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=False)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    pin_memory=True,
    dataset=llava_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))
#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer: AMP wrapper with dynamic loss scaling and grad clipping
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy: linear warmup for the first `warmup_ratio` fraction of
# training, then cosine decay to 0 for the remainder.
# FIX: the pasted config dropped the `*` operator in
# `warmup_ratio max_epochs` (a SyntaxError); restored as
# `warmup_ratio * max_epochs` in both schedulers.
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Extra hooks: dataset summary at startup, periodic chat-style evaluation.
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        image_processor=image_processor,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        evaluation_images=evaluation_images,
        system=SYSTEM,
        prompt_template=prompt_template)
]
# Default runtime hooks (mmengine).
default_hooks = dict(
    # per-iteration timing
    timer=dict(type=IterTimerHook),
    # log every 10 train iterations (iteration-based, not epoch-based)
    logger=dict(
        type=LoggerHook,
        log_metric_by_epoch=False,
        interval=10),
    # step the LR schedulers defined in `param_scheduler`
    param_scheduler=dict(type=ParamSchedulerHook),
    # checkpoint every `save_steps` iterations, keep `save_total_limit` latest
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # reseed the distributed sampler each epoch
    sampler_seed=dict(type=DistSamplerSeedHook),
)
# environment configuration
env_cfg = dict(
    # disable cudnn autotuning (variable-size inputs; reproducibility)
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# FIX: removed a stray bare `deterministic` token left over from the paste —
# it would raise NameError at import time. The setting already lives inside
# `randomness` below.
randomness = dict(seed=None, deterministic=False)

# set log processor: report by iteration, not by epoch
log_processor = dict(by_epoch=False)
I hope we can use llama3.1 on Ollama soon.