modelscope / ms-swift

Use PEFT or Full-parameter to finetune 400+ LLMs or 100+ MLLMs. (LLM: Qwen2.5, Llama3.2, GLM4, Internlm2.5, Yi1.5, Mistral, Baichuan2, DeepSeek, Gemma2, ...; MLLM: Qwen2-VL, Qwen2-Audio, Llama3.2-Vision, Llava, InternVL2, MiniCPM-V-2.6, GLM4v, Xcomposer2.5, Yi-VL, DeepSeek-VL, Phi3.5-Vision, ...)
https://swift.readthedocs.io/zh-cn/latest/Instruction/index.html
Apache License 2.0
4.37k stars 385 forks source link

集成多类损失函数的sft训练(如对比损失) #2117

Open YasmineXXX opened 2 months ago

YasmineXXX commented 2 months ago

Describe the feature 提供多种损失函数的sft训练,比如对比损失

Paste any useful information sft时,除了交叉熵损失,有时需要针对某个特定token计算对比损失、pairloss等等,可否集成这样一个功能呢?

Additional context

YasmineXXX commented 1 month ago
class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor, world_size, rank):
        output = [torch.empty_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        return (
            grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
            None, None
        )

class Seq2SeqTrainer(PushToMsHubMixin, SwiftMixin, HfSeq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=None):
        if not hasattr(self, '_custom_metrics'):
            self._custom_metrics = {}

        labels = None
        loss_scale = None
        if 'loss_scale' in inputs:
            labels = inputs.pop('labels')
            loss_scale = inputs.pop('loss_scale')

        if self.label_smoother is not None and 'labels' in inputs:
            labels = inputs.pop('labels')

        query_text = inputs.pop('query_text', None)
        gt_video_id = inputs.pop('gt_video_id', None)
        video_id = inputs.pop('video_id', None)

        outputs = model(**inputs)
        if loss_scale is not None:
            outputs['loss'] = self.compute_scaled_loss(labels, outputs.logits, loss_scale)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None and loss_scale is None:
            unwrapped_model = unwrap_model(model)
            if is_peft_available() and isinstance(unwrapped_model, PeftModel):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]

        if self.sequence_parallel_size > 1:
            from swift.trainers.xtuner import reduce_xtuner_sequence_parallel_loss
            loss = reduce_xtuner_sequence_parallel_loss(loss, labels)

        if labels is None:
            labels = inputs['labels']
        if self.is_encoder_decoder:
            preds = outputs.logits.argmax(dim=2)[..., :]
            labels = labels[..., :]
        else:
            preds = outputs.logits.argmax(dim=2)[..., :-1]
            labels = labels[..., 1:]

        masks = labels != -100
        acc_strategy = getattr(self.args, 'acc_strategy', 'token')
        acc: Optional[Tensor] = None

        if self.state.global_step % self.sft_args.acc_steps == 0:
            if preds.shape != labels.shape:
                pass
            elif acc_strategy == 'sentence':
                acc_list = []
                for i, m in enumerate(masks):
                    acc_list.append(torch.all(preds[i, m] == labels[i, m]).to(torch.int64).item())
                acc = torch.tensor(acc_list, device=preds.device).float().mean()
            else:
                if use_torchacc():
                    ta_trim_graph()
                    preds = preds.to('cpu')
                    masks = masks.to('cpu')
                    labels = labels.to('cpu')
                acc = (torch.masked_select(preds, masks) == torch.masked_select(labels, masks)).float().mean()
            if model.training and acc is not None:
                if 'acc' not in self._custom_metrics:
                    self._custom_metrics['acc'] = self._acc
                self._custom_metrics['acc'] = self._custom_metrics['acc'] + acc / self.args.gradient_accumulation_steps

        # 对比损失begins*********************************
        logits = outputs.logits

        # 根据 inputs['labels'] 的 -9 位置的值来确定正例/负例
        labels = inputs['labels']
        pos = (labels[:, -9] != 29900).long()  # 如果 -9 位置的值为 29900,则 pos 为 0表示负例,否则为 1表示正例

        # AllGather用于聚合同一个batch不同gpu上的变量,确保反向传播的正确进行
        gathered_logits = AllGather.apply(logits, dist.get_world_size(), dist.get_rank())
        gathered_pos = AllGather.apply(pos, dist.get_world_size(), dist.get_rank())
        print(f"gathered_logits shape: {gathered_logits.shape}")
        print(f"gathered_pos shape: {gathered_pos.shape}")

        def get_rele_scores(logits, tokenizer, label_ids, prefix):
            final_level_idx = -9
            selected_logits = logits[:, final_level_idx, :]
            rele_logits = selected_logits[:, label_ids].float()
            probs = torch.nn.functional.softmax(rele_logits, dim=-1)
            rele_scores = (probs * torch.tensor([0.0, 1.0], device=logits.device)).sum(dim=-1)

            for i, score in enumerate(rele_scores):
                print(f"{prefix}_rele_score_sample_{i}: {score.item()}")
            return rele_scores

        # Compute InfoNCE loss
        pos_mask = gathered_pos == 1
        neg_mask = gathered_pos == 0
        pos_logits = gathered_logits[pos_mask] # 正例logits
        neg_logits = gathered_logits[neg_mask] # 负例logits

        if pos_logits.size(0) == 0 or neg_logits.size(0) == 0: # 若当前batch中没有正例/没有负例,跳过
            infonce_loss = torch.tensor(0.0, device=logits.device)
        else:
            tokenizer = self.tokenizer
            label_ids = [x[-1] for x in tokenizer(['0', '1']).input_ids] # 找到logits中0/1token对应的prob
            pos_score = get_rele_scores(pos_logits, tokenizer, label_ids, "pos")
            neg_scores = get_rele_scores(neg_logits, tokenizer, label_ids, "neg")
            temperature = 0.3

            # 处理 pos_score 和 neg_scores 维度不同的情况
            pos_score_exp = pos_score.unsqueeze(1)  # [N_pos] -> [N_pos, 1]
            neg_scores_exp = neg_scores.unsqueeze(0)  # [N_neg] -> [1, N_neg]

            # 计算 InfoNCE 损失
            exp_pos = torch.exp(pos_score_exp / temperature)  # [N_pos, 1]
            exp_neg = torch.exp(neg_scores_exp / temperature)  # [1, N_neg]
            denominator = exp_pos + exp_neg.sum(dim=1, keepdim=True)  # [N_pos, 1]

            infonce_loss = -torch.log(exp_pos / denominator).mean()

        # Combine original loss and InfoNCE loss
        alpha = 1.0
        combined_loss = loss.mean() + alpha * infonce_loss
        print(f"loss: {loss.item()}, infonce_loss: {infonce_loss.item()}")        
        # return (loss, outputs) if return_outputs else loss
        return (combined_loss, outputs) if return_outputs else combined_loss

以上是目前集成的一版对比损失,但是在运行第一个batch时就会导致死锁,相关报错是:

Train:   0%|          | 0/8300 [00:00<?, ?it/s]WARNING:transformers_modules.InternVL2-4B.modeling_phi3:`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
WARNING:transformers_modules.InternVL2-4B.modeling_phi3:`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
gathered_logits shape: torch.Size([4, 906, 32020])
gathered_pos shape: torch.Size([4])
gathered_logits shape: torch.Size([4, 966, 32020])
gathered_pos shape: torch.Size([4])
dynamic ViT batch size: 2, images per sample: 2.0, dynamic token length: 902
WARNING:transformers_modules.InternVL2-4B.modeling_phi3:`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
gathered_logits shape: torch.Size([4, 902, 32020])
gathered_pos shape: torch.Size([4])
WARNING:transformers_modules.InternVL2-4B.modeling_phi3:`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
gathered_logits shape: torch.Size([4, 902, 32020])
gathered_pos shape: torch.Size([4])
[E ProcessGroupNCCL.cpp:475] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4, OpType=ALLGATHER, NumelIn=30931320, NumelOut=123725280, Timeout(ms)=1800000) ran for 1800304 milliseconds before timing out.
loss: 7.311283111572266, infonce_loss: 0.0
[E ProcessGroupNCCL.cpp:489] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:495] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:916] [Rank 2] NCCL watchdog thread terminated with exception: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4, OpType=ALLGATHER, NumelIn=30931320, NumelOut=123725280, Timeout(ms)=1800000) ran for 1800304 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 2] NCCL watchdog thread terminated with exception: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4, OpType=ALLGATHER, NumelIn=30931320, NumelOut=123725280, Timeout(ms)=1800000) ran for 1800304 milliseconds before timing out.
[2024-09-27 15:17:22,821] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 4059737 closing signal SIGTERM
[2024-09-27 15:17:22,822] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 4059738 closing signal SIGTERM
[2024-09-27 15:17:22,823] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 4059740 closing signal SIGTERM
[2024-09-27 15:17:23,805] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -6) local_rank: 2 (pid: 4059739) of binary: /opt/conda/bin/python
Traceback (most recent call last):
  File "/opt/conda/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
========================================================
/mnt/code/swift_internvl2_cot_cl/swift/cli/sft.py FAILED
--------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
--------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-09-27_15:17:22
  host      : 9129914748d9
  rank      : 2 (local_rank: 2)
  exitcode  : -6 (pid: 4059739)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 4059739
========================================================

想问下代码实现地哪里有问题呢?

Jintao-Huang commented 1 month ago

py-spy dump --pid <...>

查看卡在forward还是backward

开启了zero2/zero3不

YasmineXXX commented 1 month ago

py-spy 因为容器权限问题暂时无法使用,但是debug可知在forward过程中就卡住了;使用了zero2,debug配置为:

        {
            "name": "debug_cl",
            "type": "debugpy",
            "request": "launch",
            "module": "torch.distributed.run",
            "console": "integratedTerminal",
            "justMyCode": false,
            "env": {
                "CUDA_VISIBLE_DEVICES": "0,1",
                "PYTHONPATH": "./"
            },
            "args": [
                "--master_port", "29510",
                "--nproc_per_node", "2",
                "swift/cli/main.py",  // 你的主程序文件
                "sft",
                "--sft_type", "lora",
                "--model_type", "internvl2-4b",
                "--custom_train_dataset_path", "/mnt/data/msrvtt/llm_data/msrvtt_cot_3500.jsonl",
                "--resume_from_checkpoint", "/mnt/code/swift_all/output/internvl2-4b/v2-20240827-161031/checkpoint-39827",
                "--resume_only_model", "True",
                "--save_strategy", "epoch",
                "--num_train_epochs", "10",
                "--save_total_limit", "10000",
                "--ddp_find_unused_parameters", "true",
                "--max_length", "4096",
                "--lora_rank", "8",
                "--lora_alpha", "32",
                "--lora_dropout_p", "0.05",
                "--lora_target_modules", "ALL",
                "--gradient_checkpointing", "true",
                "--batch_size", "1",
                "--weight_decay", "0.01",
                "--learning_rate", "5e-5",
                "--save_steps", "3000",
                "--gradient_accumulation_steps", "1",
                "--max_grad_norm", "0.5",
                "--warmup_ratio", "0.03",
                "--dtype", "bf16",
                "--deepspeed", "default-zero2"
            ]
        },        
Jintao-Huang commented 1 month ago

在哪一行知道吗

Jintao-Huang commented 1 month ago

你要不要升级一下ms-swift试试