eljandoubi opened 1 week ago
cc @muellerzr and @SunMarc
Hi! Can you please show your entire script you are running?
```python
"""
Training script for Vision2Seq model-like.
"""
import logging
from os import getenv
from typing import Union

from torch import cuda
from datasets import load_dataset, DatasetDict
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, set_seed, \
    AutoProcessor, AutoModelForVision2Seq, HfArgumentParser, EarlyStoppingCallback, \
    TrainingArguments, Trainer, BitsAndBytesConfig
from accelerate import PartialState  # , logging
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.optimizers import create_loraplus_optimizer

from utils.shared_dataclass import SharedArugments
import utils.args_parsers as parsers
import utils.data_processing as process
import utils.model_preparation as prep
from utils.post_process_args import post_process, get_sorted_key
from utils.eval_metrics import EvalMetrics

main_device = "cuda" if cuda.is_available() else "cpu"
main_state = PartialState(cpu=main_device == "cpu")
idx = main_state.process_index

logger = logging.getLogger(__name__)
# logger = logging.get_logger(__name__)

FORMAT = f'process {idx} %(levelname)s %(asctime)s,%(msecs)d %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, datefmt='%H:%M:%S')

def main():
    """Train Vision2Seq model-like"""
    logger.info("Start train script using \n%s", main_state)

    load_datasets_data_class = parsers.method_to_dataclass(load_dataset,
                                                           "load_dataset")
    early_stop_data_class = parsers.method_to_dataclass(EarlyStoppingCallback,
                                                        "EarlyStopping")

    train_type = getenv('TRAIN_TYPE')
    if train_type == "SEQ2SEQ":
        logger.info("Training type is SEQ2SEQ")
        train_args_class = Seq2SeqTrainingArguments
        trainer_class = Seq2SeqTrainer
    else:
        logger.info("Training type is %s", train_type)
        train_type = "CAUSAL"
        train_args_class = TrainingArguments
        trainer_class = Trainer

    dataclass_instances = [
        load_datasets_data_class,
        early_stop_data_class,
        SharedArugments,
        train_args_class,
    ]
    if getenv("USE_LORA") == "True":
        logger.info("LORA is used.")
        use_lora = True
        lora_cfg = parsers.method_to_dataclass(LoraConfig,
                                               "LoraConfig",
                                               ["revision"])
        dataclass_instances.append(lora_cfg)
    else:
        logger.info("LORA is NOT used.")
        use_lora = False

    if getenv("USE_BITS") == "True":
        logger.info("BitsAndBytes is used.")
        use_bits = True
        dataclass_instances.append(BitsAndBytesConfig)
    else:
        logger.info("BitsAndBytes is NOT used.")
        use_bits = False

    logger.info("Parse arguments")
    tuple_dataclass = parsers.parse_args(
        parser=HfArgumentParser(dataclass_instances)
    )
    classes = post_process(tuple_dataclass, logger)

    (load_datasets_args, early_stop_args, shared_args,
     training_args, *other_args) = tuple_dataclass
    set_seed(training_args.seed)

    list_dataclases = list(tuple_dataclass[:4])
    for other_arg in other_args:
        if use_lora and other_arg.__class__.__name__ == "LoraConfig":
            lora_args = parsers.convert_dataclasses(other_arg, LoraConfig)
            list_dataclases.append(lora_args)
        if use_bits and isinstance(other_arg, BitsAndBytesConfig):
            bits_args = other_arg
            list_dataclases.append(bits_args)
    tuple_dataclass = tuple(list_dataclases)
    del list_dataclases, other_args

    training_args: Union[TrainingArguments, Seq2SeqTrainingArguments]
    shared_args: SharedArugments

    for data_cls in tuple_dataclass:
        logger.info("%s", data_cls)

    logger.info("Save parsed arguments")
    parsers.save_config(data_classes=tuple_dataclass,
                        file_name=f"{training_args.output_dir}/script_config.json")
    if shared_args.do_preprocessing:
        dataset = load_dataset(**parsers.module_asdict(
            data_class=load_datasets_args))

        logger.info("Prepare data for processor")
        prepared_dataset = dataset.map(
            process.prepare_documents_for_processor,
            fn_kwargs={
                "new_special_tokens_path": shared_args.new_special_tokens_path,
                "bos_token": shared_args.bos_token,
                "eos_token": shared_args.eos_token,
            },
            **process.get_map_arguments(dataset,
                                        batched=False,
                                        num_proc=shared_args.map_num_proc)
        )

        logger.info("Load processor from %s",
                    shared_args.pretrained_model_name_or_path)
        processor = AutoProcessor.from_pretrained(
            pretrained_model_name_or_path=shared_args.pretrained_model_name_or_path,
            clean_up_tokenization_spaces=True
        )
        logger.info("Initial Vocab size %s", len(processor.tokenizer))

        logger.info("Align processor configuration with data")
        prep.prepare_processor(
            processor=processor,
            new_special_tokens_or_path=shared_args.new_special_tokens_path,
            height=shared_args.height,
            width=shared_args.width,
        )
        logger.info("Final Vocab size %s", len(processor.tokenizer))

        logger.info("Save processor to %s", shared_args.processor_save)
        processor.save_pretrained(shared_args.processor_save)

        logger.info("Transform image data and tokenize text data")
        processed_dataset = prepared_dataset.map(
            process.transform_and_tokenize,
            fn_kwargs={"processor": processor},
            **process.get_map_arguments(
                prepared_dataset,
                batch_size=shared_args.map_batch_size,
                writer_batch_size=shared_args.writer_batch_size)
        )

        logger.info("Save Dataset to %s", shared_args.data_save)
        processed_dataset.save_to_disk(shared_args.data_save)
    else:
        logger.info("Load processor from %s", shared_args.processor_save)
        processor = AutoProcessor.from_pretrained(
            pretrained_model_name_or_path=shared_args.processor_save,
            clean_up_tokenization_spaces=True
        )

    logger.info("Load model from %s",
                shared_args.pretrained_model_name_or_path)
    model_kwgs = {
        "pretrained_model_name_or_path": shared_args.pretrained_model_name_or_path,
        "attn_implementation": "flash_attention_2",
        "cache_dir": shared_args.cache_model,
        "trust_remote_code": True,
        "quantization_config": bits_args if use_bits else None,
    }
    try:
        model = AutoModelForVision2Seq.from_pretrained(**model_kwgs)
    except ValueError as valerr:
        # Fall back to the default attention implementation when
        # flash_attention_2 is not supported for this model.
        logger.error(valerr)
        model_kwgs.pop("attn_implementation")
        model = AutoModelForVision2Seq.from_pretrained(**model_kwgs)
logger.info("Align model configuration with processor and data")
prep.prepare_model(model=model, processor=processor,
train_type=train_type,
max_length=shared_args.max_length,
bos_token=shared_args.bos_token,
eos_token=shared_args.eos_token,)
if use_bits:
logger.info("Prepare model for kbit training.")
model = prepare_model_for_kbit_training(model)
if use_lora:
logger.info("Get Peft Model.")
model = get_peft_model(model=model,
peft_config=lora_args)
model.print_trainable_parameters()
logger.info("Configure metrics")
metrics = EvalMetrics(
device=main_device,
tokenizer=processor.tokenizer,
cls=classes,
log_dir=training_args.logging_dir,
new_special_tokens_path=shared_args.new_special_tokens_path,
batched=training_args.batch_eval_metrics
)
logger.info("Set up callbacks")
early_stop = EarlyStoppingCallback(**parsers.module_asdict(
data_class=early_stop_args))
logger.info("Load Dataset from %s", shared_args.data_save)
processed_dataset = DatasetDict.load_from_disk(shared_args.data_save)
keys = get_sorted_key(dataset=processed_dataset)
logger.info("Set dataset format to torch tensor")
processed_dataset.set_format("pt")
logger.info("Create Trainer")
trainer = trainer_class(
model=model,
args=training_args,
train_dataset=processed_dataset[keys[0]],
eval_dataset=processed_dataset[keys[2]],
compute_metrics=metrics.compute_metrics,
preprocess_logits_for_metrics=EvalMetrics.preprocess_logits_for_metrics,
callbacks=[early_stop]
)
logger.info("Distributed type: %s",
trainer.accelerator.distributed_type)
logger.info("The mixed precision is %s",
trainer.accelerator.mixed_precision)
logger.info("Start training")
out = trainer.train()
logger.info("train output: %s", out)
logger.info("Save model to %s", shared_args.model_save)
if trainer.is_fsdp_enabled:
trainer.accelerator.state.\
fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model(output_dir=shared_args.model_save)
logger.info("Start evaluation")
eval_res = trainer.evaluate(eval_dataset=processed_dataset[keys[1]])
logger.info("The evaluation result is %s", eval_res)
if __name__ == "__main__":
main()
The same problem occurs with a PEFT LoRA model and `NO_WRAP`. @muellerzr @SunMarc
@muellerzr When `auto_find_batch_size=True`, the `find_executable_batch_size` function runs `_inner_training_loop` multiple times. Since `self.accelerator.prepare` wraps `self.model` (applying the FSDP auto-wrap policy in a way that `self._wrap_model` cannot detect) before the CUDA out-of-memory error is raised, the already-wrapped `self.model` triggers the FSDP auto-wrapping error on the next iteration.
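
A minimal sketch of the failure mode (hypothetical `ToyTrainer` and `inner_loop` stand in for `Trainer` and `_inner_training_loop`; the string wrapping mimics what `accelerator.prepare` does to `self.model`):

```python
from accelerate.utils import find_executable_batch_size


class ToyTrainer:
    """Hypothetical stand-in for transformers.Trainer."""

    def __init__(self):
        self.model = "raw_model"

    def prepare(self):
        # Mimics self.accelerator.prepare: the model is wrapped in place,
        # and the wrapping is NOT undone when the retry decorator catches OOM.
        self.model = f"FSDP({self.model})"

    def train(self):
        # Same retry pattern Trainer uses when auto_find_batch_size=True.
        @find_executable_batch_size(starting_batch_size=64)
        def inner_loop(batch_size):
            self.prepare()
            print(f"batch_size={batch_size}, model={self.model}")
            if batch_size > 16:
                # Simulated CUDA OOM: the decorator halves the batch size
                # and calls inner_loop again, with self.model already wrapped.
                raise RuntimeError("CUDA out of memory.")

        inner_loop()


ToyTrainer().train()
# batch_size=64, model=FSDP(raw_model)
# batch_size=32, model=FSDP(FSDP(raw_model))
# batch_size=16, model=FSDP(FSDP(FSDP(raw_model)))
```

On the second pass, the object handed to FSDP auto-wrapping is already an FSDP module; that double wrapping is what the real run trips over.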
System Info

acc_cfg.yml:

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: true
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: NO_PREFETCH
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: true
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_process_ip: 0.0.0.0
main_process_port: 0
main_training_function: main
mixed_precision: bf16
num_machines: 3
num_processes: 24
rdzv_backend: etcd-v2
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
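For reference, roughly the same FSDP setup can be built programmatically when debugging outside `accelerate launch`. This is a sketch assuming a recent accelerate where `FullyShardedDataParallelPlugin` accepts string values for the enum-typed fields (older versions expect the torch FSDP enums):

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Sketch mirroring the fsdp_config above; strings are assumed to be
# converted to the torch FSDP enums by recent accelerate versions.
fsdp_plugin = FullyShardedDataParallelPlugin(
    sharding_strategy="FULL_SHARD",             # fsdp_sharding_strategy
    auto_wrap_policy="transformer_based_wrap",  # fsdp_auto_wrap_policy
    state_dict_type="SHARDED_STATE_DICT",       # fsdp_state_dict_type
    cpu_offload=True,                           # fsdp_offload_params
    forward_prefetch=True,
    use_orig_params=True,
    sync_module_states=True,
    activation_checkpointing=True,
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin, mixed_precision="bf16")
```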
Who can help?

No response

Reproduction
```bash
accelerate launch --config_file acc_cfg.yml train.py $TRAINING_ARGS
```

`train.py` is any training script that trains with `transformers.Trainer`; `$TRAINING_ARGS` are the `TrainingArguments` plus some paths to the data.
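For concreteness, a hypothetical invocation could look like the one below; the batch-size and `--auto_find_batch_size` flags are standard `TrainingArguments` options, while `--data_save` stands in for the script's shared-argument data paths and every path is a placeholder:

```bash
accelerate launch --config_file acc_cfg.yml train.py \
  --output_dir ./out \
  --per_device_train_batch_size 8 \
  --auto_find_batch_size True \
  --num_train_epochs 3 \
  --bf16 True \
  --data_save ./processed_data
```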
Expected behavior

Train the PaliGemma model with FSDP and have `PaliGemmaMultiModalProjector` wrapped.
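One thing that might help with the wrapping itself (my assumption, not confirmed in this thread): with `TRANSFORMER_BASED_WRAP`, the classes to wrap can be listed explicitly via `fsdp_transformer_layer_cls_to_wrap`. The class names below are guesses to be verified against the transformers PaliGemma implementation:

```yaml
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  # Hypothetical class list -- verify the exact names in transformers.
  fsdp_transformer_layer_cls_to_wrap: GemmaDecoderLayer,SiglipEncoderLayer,PaliGemmaMultiModalProjector
```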