THUDM / GLM-4

GLM-4 series: Open Multilingual Multimodal Chat LMs | 开源多语言多模态对话模型

Validation cannot be executed when fine-tuning glm-4-9b-chat #347

Closed: Chenhong-Zhang closed this issue 3 months ago

Chenhong-Zhang commented 3 months ago

System Info / 系統信息

Installed packages: accelerate 0.31.0 aiohttp 3.9.5 aiosignal 1.3.1 annotated-types 0.7.0 asttokens 2.4.1 async-timeout 4.0.3 attrs 23.2.0 certifi 2024.6.2 charset-normalizer 3.3.2 click 8.1.7 comm 0.2.2 datasets 2.20.0 debugpy 1.6.7 decorator 5.1.1 deepspeed 0.14.4 dill 0.3.8 et-xmlfile 1.1.0 evaluate 0.4.2 exceptiongroup 1.2.0 executing 2.0.1 filelock 3.15.4 frozenlist 1.4.1 fsspec 2024.5.0 hjson 3.1.0 huggingface-hub 0.23.4 idna 3.7 importlib_metadata 8.0.0 ipykernel 6.29.5 ipython 8.26.0 jedi 0.19.1 jieba 0.42.1 Jinja2 3.1.4 joblib 1.4.2 jupyter_client 8.6.2 jupyter_core 5.7.2 markdown-it-py 3.0.0 MarkupSafe 2.1.5 matplotlib-inline 0.1.7 mdurl 0.1.2 mpi4py 3.1.4 mpmath 1.3.0 multidict 6.0.5 multiprocess 0.70.16 nest_asyncio 1.6.0 networkx 3.3 ninja 1.11.1.1 nltk 3.8.1 numpy 1.26.4 nvidia-cublas-cu12 12.1.3.1 nvidia-cuda-cupti-cu12 12.1.105 nvidia-cuda-nvrtc-cu12 12.1.105 nvidia-cuda-runtime-cu12 12.1.105 nvidia-cudnn-cu12 8.9.2.26 nvidia-cufft-cu12 11.0.2.54 nvidia-curand-cu12 10.3.2.106 nvidia-cusolver-cu12 11.4.5.107 nvidia-cusparse-cu12 12.1.0.106 nvidia-ml-py 12.555.43 nvidia-nccl-cu12 2.20.5 nvidia-nvjitlink-cu12 12.5.40 nvidia-nvtx-cu12 12.1.105 openpyxl 3.1.5 packaging 24.1 pandas 2.2.2 parso 0.8.4 peft 0.11.1 pexpect 4.9.0 pickleshare 0.7.5 pip 24.0 platformdirs 4.2.2 prompt_toolkit 3.0.47 psutil 6.0.0 ptyprocess 0.7.0 pure-eval 0.2.2 py-cpuinfo 9.0.0 pyarrow 16.1.0 pyarrow-hotfix 0.6 pydantic 2.7.4 pydantic_core 2.18.4 Pygments 2.18.0 python-dateutil 2.9.0 pytz 2024.1 PyYAML 6.0.1 pyzmq 25.1.2 regex 2024.5.15 requests 2.32.3 rich 13.7.1 rouge-chinese 1.0.3 ruamel.yaml 0.18.6 ruamel.yaml.clib 0.2.8 safetensors 0.4.3 scikit-learn 1.5.0 scipy 1.14.0 setuptools 69.5.1 shellingham 1.5.4 six 1.16.0 stack-data 0.6.2 sympy 1.12.1 threadpoolctl 3.5.0 tiktoken 0.7.0 tokenizers 0.19.1 torch 2.3.1 tornado 6.4.1 tqdm 4.66.4 traitlets 5.14.3 transformers 4.40.0 triton 2.3.1 typer 0.12.3 typing_extensions 4.12.2 tzdata 2024.1 urllib3 2.2.2 wcwidth 0.2.13 wheel 0.43.0 xxhash 3.4.1 yarl 1.9.4 zipp 3.19.2

CUDA Version: 12.1; OS: Linux SLM3090 5.4.0-189-generic

Who can help? / 谁可以帮助到您?

No response

Information / 问题信息

Reproduction / 复现过程

I fine-tune with DeepSpeed ZeRO Stage 3; the DeepSpeed configuration is:

{ "train_micro_batch_size_per_gpu": "auto", "zero_allow_untested_optimizer": true, "bf16": { "enabled": "auto" }, "optimizer": { "type": "AdamW", "params": { "lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto" } }, "zero_optimization": { "stage": 3, "offload_optimizer": { "device": "cpu", "pin_memory": true }, "offload_param": { "device": "cpu", "pin_memory": true }, "overlap_comm": true, "contiguous_gradients": true, "sub_group_size": 1e9, "reduce_bucket_size": "auto", "stage3_prefetch_bucket_size": "auto", "stage3_param_persistence_threshold": "auto", "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9, "stage3_gather_16bit_weights_on_model_save": true } }

During fine-tuning, training runs normally, but the run gets stuck at Validation.

I used a debugger to locate where it gets stuck:

class Seq2SeqTrainer(_Seq2SeqTrainer):
    # apex is not supported
    def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:

        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()
        self.accelerator.backward(loss)
        detached_loss = loss.detach() / self.args.gradient_accumulation_steps
        del inputs
        torch.cuda.empty_cache()
        return detached_loss

    def prediction_step(
            self,
            model: nn.Module,
            inputs: dict[str, Any],
            prediction_loss_only: bool,
            ignore_keys=None,
            **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:

        with torch.no_grad():  # Ensure no gradient computation
            if self.args.predict_with_generate:
                output_ids = inputs.pop('output_ids')
            input_ids = inputs['input_ids']

            loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
            )

            generated_tokens = generated_tokens[:, input_ids.size()[1]:]
            labels = output_ids

            del inputs, input_ids, output_ids
            torch.cuda.empty_cache()

        return loss, generated_tokens, labels

It gets stuck at the prediction call inside the prediction_step function:

loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
            )
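
For context, with predict_with_generate enabled this super().prediction_step() call ends up in model.generate(), which decodes token by token while ZeRO-3 gathers the CPU-offloaded parameters for every forward pass, so a single eval batch can run for a long time without actually being deadlocked. Below is a minimal diagnostic sketch (hypothetical, not part of the original script) that wraps the call with a timer to tell a real hang from slow generation:

import time

class TimedSeq2SeqTrainer(Seq2SeqTrainer):
    # Hypothetical subclass, for diagnosis only: logs how long each
    # generate-based prediction step takes on every rank.
    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None, **gen_kwargs):
        start = time.perf_counter()
        outputs = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
        )
        print(f"[rank {self.args.local_rank}] prediction_step took {time.perf_counter() - start:.1f}s")
        return outputs

Swapping this subclass in for Seq2SeqTrainer when constructing the trainer would show whether eval batches are progressing, just slowly.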

The fine-tuning script is the official one; I only adjusted compute_metrics, which should not affect this part. The complete code is below:

import os
import json
import dataclasses as dc
import functools
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Union
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, Split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer
from datasets import load_dataset, DatasetDict, NamedSplit
from typing import Optional

app = typer.Typer(pretty_exceptions_show_locals=False)

class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        output_ids = ([feature['output_ids'] for feature in features] if 'output_ids' in features[0].keys() else None)
        if output_ids is not None:
            max_output_length = max(len(out) for out in output_ids)
            if self.pad_to_multiple_of is not None:
                max_output_length = (
                        (
                                max_output_length + self.pad_to_multiple_of - 1) //
                        self.pad_to_multiple_of * self.pad_to_multiple_of
                )
            for feature in features:
                remainder = [self.tokenizer.pad_token_id] * (
                        max_output_length - len(feature['output_ids'])
                )
                if isinstance(feature['output_ids'], list):
                    feature['output_ids'] = feature['output_ids'] + remainder
                else:
                    feature['output_ids'] = np.concatenate(
                        [feature['output_ids'], remainder]
                    ).astype(np.int64)
        return super().__call__(features, return_tensors)

class Seq2SeqTrainer(_Seq2SeqTrainer):
    # apex is not supported
    def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:

        model.train()
        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()
        self.accelerator.backward(loss)
        detached_loss = loss.detach() / self.args.gradient_accumulation_steps
        del inputs
        torch.cuda.empty_cache()
        return detached_loss

    def prediction_step(
            self,
            model: nn.Module,
            inputs: dict[str, Any],
            prediction_loss_only: bool,
            ignore_keys=None,
            **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:

        with torch.no_grad():  # Ensure no gradient computation
            if self.args.predict_with_generate:
                output_ids = inputs.pop('output_ids')
            input_ids = inputs['input_ids']

            loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
            )

            generated_tokens = generated_tokens[:, input_ids.size()[1]:]
            labels = output_ids

            del inputs, input_ids, output_ids
            torch.cuda.empty_cache()

        return loss, generated_tokens, labels

@dc.dataclass
class DataConfig(object):
    train_file: Optional[str] = None
    val_file: Optional[str] = None
    test_file: Optional[str] = None
    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }

@dc.dataclass
class FinetuningConfig(object):
    data_config: DataConfig

    max_input_length: int
    max_output_length: int

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
    )
    peft_config: Optional[PeftConfig] = None

    def __post_init__(self):
        if not self.training_args.do_eval or self.data_config.val_file is None:
            self.training_args.do_eval = False
            self.training_args.evaluation_strategy = 'no'
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                    self.training_args.per_device_eval_batch_size
                    or self.training_args.per_device_train_batch_size
            )

    @classmethod
    def from_dict(cls, **kwargs) -> 'FinetuningConfig':
        training_args = kwargs.get('training_args', None)
        if training_args is not None and not isinstance(
                training_args, Seq2SeqTrainingArguments
        ):
            gen_config = training_args.get('generation_config')
            if not isinstance(gen_config, GenerationConfig):
                training_args['generation_config'] = GenerationConfig(
                    **gen_config
                )
            kwargs['training_args'] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get('data_config')
        if not isinstance(data_config, DataConfig):
            kwargs['data_config'] = DataConfig(**data_config)

        peft_config = kwargs.get('peft_config', None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs['peft_config'] = get_peft_config(config_dict=peft_config)
        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> 'FinetuningConfig':
        path = Path(path)
        parser = yaml.YAML(typ='safe', pure=True)
        parser.indent(mapping=2, offset=2, sequence=4)
        parser.default_flow_style = False
        kwargs = parser.load(path)
        return cls.from_dict(**kwargs)

def _load_datasets(
        data_dir: str,
        data_format: str,
        data_files: dict[NamedSplit, str],
        num_proc: Optional[int],
) -> DatasetDict:
    if data_format == '.json':
        dataset_dct = load_dataset(
            data_dir,
            data_files=data_files,
            split=None,
            num_proc=num_proc,
        )
    else:
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return dataset_dct

class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
            self,
            split: NamedSplit,
            process_fn: Callable[[dict[str, Any]], dict[str, Any]],
            batched: bool = True,
            remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return

        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )

def process_message(message):
    if 'tools' in message and message['role'] == 'system':
        for tool in message['tools']:
            parameters = tool['function']['parameters']['properties']
            tool['function']['parameters']['properties'] = \
                {k: v for k, v in parameters.items() if
                 v is not None}
    elif 'tools' in message:
        del message['tools']
    return message

def process_batch(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_labels = []

    for conv in batched_conv:
        input_ids = [151331, 151333]
        loss_masks = [False, False]
        for message in conv:
            message = process_message(message)
            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
            new_loss_masks = [loss_mask_val] * len(new_input_ids)
            input_ids += new_input_ids
            loss_masks += new_loss_masks
        input_ids.append(tokenizer.eos_token_id)
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                labels.append(input_id)
            else:
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])        

    del batched_conv, conv, input_ids, loss_masks, message, new_input_ids, new_loss_masks, labels, input_id, mask
    torch.cuda.empty_cache()

    return {'input_ids': batched_input_ids, 'labels': batched_labels}

def process_batch_eval(
        batch: Mapping[str, Sequence],
        tokenizer: PreTrainedTokenizer,
        max_input_length: int,
        max_output_length: int,
) -> dict[str, list]:
    batched_conv = batch['messages']
    batched_input_ids = []
    batched_output_ids = []

    for conv in batched_conv:
        input_ids = [151331, 151333]
        for message in conv:
            if len(input_ids) >= max_input_length:
                break
            else:
                message = process_message(message)
                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
                if message['role'] == 'assistant':
                    output_prompt, output_ids = (
                        new_input_ids[:1],
                        new_input_ids[1:],
                    )
                    output_ids.append(tokenizer.eos_token_id)
                    batched_input_ids.append(
                        input_ids[:max_input_length] + output_prompt[:1]
                    )
                    batched_output_ids.append(output_ids[:max_output_length])
                input_ids += new_input_ids

    del batched_conv, conv, input_ids, message, new_input_ids, output_prompt, output_ids
    torch.cuda.empty_cache()

    return {'input_ids': batched_input_ids, 'output_ids': batched_output_ids}

def load_tokenizer_and_model(
        model_dir: str,
        peft_config: Optional[PeftConfig] = None,
):
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    if peft_config is not None:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            empty_init=False,
            use_cache=False,
            torch_dtype=torch.bfloat16  # Must use BFloat 16
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            empty_init=False,
            use_cache=False,
            torch_dtype=torch.bfloat16
        )
    return tokenizer, model

# def compute_metrics(eval_preds: EvalPrediction, tokenizer):
#     batched_pred_ids, batched_label_ids = eval_preds
#     metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
#     for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
#         pred_txt = tokenizer.decode(pred_ids).strip()
#         label_txt = tokenizer.decode(label_ids).strip()
#         pred_tokens = list(jieba.cut(pred_txt))
#         label_tokens = list(jieba.cut(label_txt))
#         rouge = Rouge()
#         scores = rouge.get_scores(' '.join(pred_tokens), ' '.join(label_tokens))
#         for k, v in scores[0].items():
#             metrics_dct[k].append(round(v['f'] * 100, 4))
#         metrics_dct['bleu-4'].append(
#             sentence_bleu([label_tokens], pred_tokens, smoothing_function=SmoothingFunction().method3))
#     return {k: np.mean(v) for k, v in metrics_dct.items()}

def find_and_extract_json(input_string):
    # Find the positions of the first '{' and the last '}'
    start_index = input_string.find('{')
    end_index = input_string.rfind('}')

    json_obj = {}

    # If a valid pair of braces was found
    if start_index != -1 and end_index != -1 and end_index > start_index:
        # Extract the content between the braces, including the braces themselves
        json_str = input_string[start_index:end_index+1]
        # Try to parse the extracted string as JSON
        try:
            json_obj = json.loads(json_str)
        except Exception:
            pass  # If parsing fails, fall back to returning the empty dict
    return json_obj

def calculate_score_by_entity(pred:str, reference:str):
    pred_json = find_and_extract_json(pred)
    reference_json = find_and_extract_json(reference)
    reference_entity = ["".join([x["defect type"] for x in reference_json["defects"]]),"".join([x["defect location"] for x in reference_json["defects"]]), 
                        "".join([x["defect number"] for x in reference_json["defects"]]), "".join([x["defect dimension"] for x in reference_json["defects"]])]
    reference_entity = [' '.join(x.strip()) for x in reference_entity]
    if pred_json:  # both JSON objects were extracted successfully
        json_score = 1
        try:
            rouge = Rouge(metrics=["rouge-1"])
            pred_entity = ["".join([x["defect type"] for x in pred_json["defects"]]),"".join([x["defect location"] for x in pred_json["defects"]]), 
                           "".join([x["defect number"] for x in pred_json["defects"]]), "".join([x["defect dimension"] for x in pred_json["defects"]])]
            pred_entity = [' '.join(x.strip()) for x in pred_entity]
            identification_score = rouge.get_scores(pred_entity, reference_entity, avg=True)["rouge-1"]["f"]
        except Exception:
            identification_score = 0
    else:
        json_score = 0
        identification_score = 0
    return [json_score, identification_score]    

def compute_metrics(eval_preds, tokenizer):
    batched_pred_ids, batched_label_ids = eval_preds
    batched_pred_txt = tokenizer.batch_decode(batched_pred_ids)  # decode at batch level
    batched_label_txt = tokenizer.batch_decode(batched_label_ids)  # decode at batch level
    batched_pred_tokens = [' '.join(pred_txt.strip()) for pred_txt in batched_pred_txt]  # space-separate characters for rouge_chinese
    batched_label_tokens = [' '.join(label_txt.strip()) for label_txt in batched_label_txt]
    rouge = Rouge(metrics=["rouge-1"])
    scores = np.array([calculate_score_by_entity(x,y) for x, y in zip(batched_pred_txt, batched_label_txt)])
    results = {"rouge-1": rouge.get_scores(batched_pred_tokens, batched_label_tokens, avg=True)["rouge-1"]["f"], 
               "json_identification": scores[:,0].mean(),
               "rouge-1_by_entity": scores[:,1].mean()}
    return results

@app.command()
def main(
        data_dir: Annotated[str, typer.Argument(help='')],
        model_dir: Annotated[
            str,
            typer.Argument(
                help='A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file.'
            ),
        ],
        config_file: Annotated[str, typer.Argument(help='')],
        auto_resume_from_checkpoint: str = typer.Argument(
            default='',
            help='If "yes", automatically resume from the latest saved checkpoint; if a checkpoint number such as 12 or 15, resume from that checkpoint; if empty, start training from scratch.'),
        deepspeed: Optional[str] = typer.Option(None, "--deepspeed", help="Path to the DeepSpeed config file"),
        local_rank: int = typer.Option(0, "--local_rank", help="Local rank for distributed training")
        ):
    ft_config = FinetuningConfig.from_file(config_file)
    ft_config.training_args.local_rank = local_rank
    ft_config.training_args.deepspeed = deepspeed
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
    data_manager = DataManager(data_dir, ft_config.data_config)

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print('train_dataset:', train_dataset)
    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if val_dataset is not None:
        print('val_dataset:', val_dataset)
    test_dataset = data_manager.get_dataset(
        Split.TEST,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if test_dataset is not None:
        print('test_dataset:', test_dataset)

    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding='longest',
            return_tensors='pt',
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset.select(list(range(32))),
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=20)]
    )

    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.upper() == "":
        trainer.train()
    else:
        output_dir = ft_config.training_args.output_dir
        dirlist = os.listdir(output_dir)
        checkpoint_sn = 0
        for checkpoint_str in dirlist:
            if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
                checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
                if checkpoint > checkpoint_sn:
                    checkpoint_sn = checkpoint
        if auto_resume_from_checkpoint.upper() == "YES":
            if checkpoint_sn > 0:
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                trainer.train()
        else:
            if auto_resume_from_checkpoint.isdigit():
                if int(auto_resume_from_checkpoint) > 0:
                    checkpoint_sn = int(auto_resume_from_checkpoint)
                    model.gradient_checkpointing_enable()
                    model.enable_input_require_grads()
                    checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                    trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                print(auto_resume_from_checkpoint,
                      "The specified checkpoint sn(" + auto_resume_from_checkpoint + ") has not been saved. Please search for the correct checkpoint in the model output directory")

    if test_dataset is not None:
        trainer.predict(test_dataset)

if __name__ == '__main__':
    app()

Expected behavior / 期待表现

I expect fine-tuning to run normally instead of getting stuck at Validation.

I have not tried running without DeepSpeed, because only ZeRO Stage 3 leaves enough GPU memory for fine-tuning.

The machine has 8x RTX 3090 GPUs, and the problem occurs with both multi-GPU and single-GPU runs.

Chenhong-Zhang commented 3 months ago

I found that it is not actually stuck; validation just takes a very long time, about half an hour.
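
For reference, this is consistent with how predict_with_generate behaves under ZeRO-3 with CPU offload: every eval sample is decoded autoregressively while parameters are gathered from CPU, so evaluation time grows with the number of generated tokens. Below is a minimal sketch (assuming the transformers 4.40 Seq2SeqTrainingArguments / GenerationConfig API; all values are hypothetical) of capping generation length so that a validation pass stays short:

from transformers import GenerationConfig, Seq2SeqTrainingArguments

# Hypothetical values: max_new_tokens only needs to cover the longest expected JSON answer.
eval_args = Seq2SeqTrainingArguments(
    output_dir='./output',
    predict_with_generate=True,
    per_device_eval_batch_size=4,   # batching eval samples amortizes the ZeRO-3 parameter gathering
    generation_config=GenerationConfig(
        max_new_tokens=512,         # stop decoding early instead of running to the model's maximum length
        do_sample=False,
    ),
)

In this script the equivalent settings live in the YAML config loaded by FinetuningConfig.from_file, under training_args and its generation_config entry.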