YasmineXXX opened this issue 2 months ago
class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor, world_size, rank):
        output = [torch.empty_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Only the slice corresponding to this rank's own input receives a gradient.
        return (
            grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
            None, None
        )
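# Note: torch.distributed.all_gather assumes `tensor` has exactly the same shape on
# every rank; gathering per-rank logits whose sequence lengths differ can stall the
# NCCL collective. Below is a hedged sketch (the helper name pad_to_common_seq_len is
# hypothetical, not part of the original code) of padding to a common length before
# gathering; alternatively, one could gather fixed-shape per-sample scores instead of
# full logits. Position-relative indexing such as logits[:, -9, :] would need to be
# adjusted if right padding is applied.
def pad_to_common_seq_len(t, world_size):
    """Right-pad t ([batch, seq, vocab]) along seq so every rank shares the max length."""
    local_len = torch.tensor([t.shape[1]], device=t.device)
    all_lens = [torch.empty_like(local_len) for _ in range(world_size)]
    torch.distributed.all_gather(all_lens, local_len)
    max_len = int(torch.cat(all_lens).max())
    return torch.nn.functional.pad(t, (0, 0, 0, max_len - t.shape[1]))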
class Seq2SeqTrainer(PushToMsHubMixin, SwiftMixin, HfSeq2SeqTrainer):

    def compute_loss(self, model, inputs, return_outputs=None):
        if not hasattr(self, '_custom_metrics'):
            self._custom_metrics = {}

        labels = None
        loss_scale = None
        if 'loss_scale' in inputs:
            labels = inputs.pop('labels')
            loss_scale = inputs.pop('loss_scale')
        if self.label_smoother is not None and 'labels' in inputs:
            labels = inputs.pop('labels')
        query_text = inputs.pop('query_text', None)
        gt_video_id = inputs.pop('gt_video_id', None)
        video_id = inputs.pop('video_id', None)

        outputs = model(**inputs)
        if loss_scale is not None:
            outputs['loss'] = self.compute_scaled_loss(labels, outputs.logits, loss_scale)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None and loss_scale is None:
            unwrapped_model = unwrap_model(model)
            if is_peft_available() and isinstance(unwrapped_model, PeftModel):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]

        if self.sequence_parallel_size > 1:
            from swift.trainers.xtuner import reduce_xtuner_sequence_parallel_loss
            loss = reduce_xtuner_sequence_parallel_loss(loss, labels)

        if labels is None:
            labels = inputs['labels']
        if self.is_encoder_decoder:
            preds = outputs.logits.argmax(dim=2)[..., :]
            labels = labels[..., :]
        else:
            preds = outputs.logits.argmax(dim=2)[..., :-1]
            labels = labels[..., 1:]

        masks = labels != -100
        acc_strategy = getattr(self.args, 'acc_strategy', 'token')
        acc: Optional[Tensor] = None
        if self.state.global_step % self.sft_args.acc_steps == 0:
            if preds.shape != labels.shape:
                pass
            elif acc_strategy == 'sentence':
                acc_list = []
                for i, m in enumerate(masks):
                    acc_list.append(torch.all(preds[i, m] == labels[i, m]).to(torch.int64).item())
                acc = torch.tensor(acc_list, device=preds.device).float().mean()
            else:
                if use_torchacc():
                    ta_trim_graph()
                    preds = preds.to('cpu')
                    masks = masks.to('cpu')
                    labels = labels.to('cpu')
                acc = (torch.masked_select(preds, masks) == torch.masked_select(labels, masks)).float().mean()
        if model.training and acc is not None:
            if 'acc' not in self._custom_metrics:
                self._custom_metrics['acc'] = self._acc
            self._custom_metrics['acc'] = self._custom_metrics['acc'] + acc / self.args.gradient_accumulation_steps
        # ***** contrastive loss begins *****
        logits = outputs.logits
        # Use the value at position -9 of inputs['labels'] to decide positive/negative samples.
        labels = inputs['labels']
        # If the token at position -9 is 29900 ('0'), pos is 0 (negative sample); otherwise 1 (positive sample).
        pos = (labels[:, -9] != 29900).long()

        # AllGather aggregates the per-GPU tensors of the same batch so that backprop works correctly.
        gathered_logits = AllGather.apply(logits, dist.get_world_size(), dist.get_rank())
        gathered_pos = AllGather.apply(pos, dist.get_world_size(), dist.get_rank())
        print(f"gathered_logits shape: {gathered_logits.shape}")
        print(f"gathered_pos shape: {gathered_pos.shape}")

        def get_rele_scores(logits, tokenizer, label_ids, prefix):
            final_level_idx = -9
            selected_logits = logits[:, final_level_idx, :]
            rele_logits = selected_logits[:, label_ids].float()
            probs = torch.nn.functional.softmax(rele_logits, dim=-1)
            rele_scores = (probs * torch.tensor([0.0, 1.0], device=logits.device)).sum(dim=-1)
            for i, score in enumerate(rele_scores):
                print(f"{prefix}_rele_score_sample_{i}: {score.item()}")
            return rele_scores

        # Compute InfoNCE loss
        pos_mask = gathered_pos == 1
        neg_mask = gathered_pos == 0
        pos_logits = gathered_logits[pos_mask]  # logits of positive samples
        neg_logits = gathered_logits[neg_mask]  # logits of negative samples
        if pos_logits.size(0) == 0 or neg_logits.size(0) == 0:  # skip if the batch has no positives or no negatives
            infonce_loss = torch.tensor(0.0, device=logits.device)
        else:
            tokenizer = self.tokenizer
            label_ids = [x[-1] for x in tokenizer(['0', '1']).input_ids]  # ids of the '0'/'1' tokens, used to read their probabilities from the logits
            pos_score = get_rele_scores(pos_logits, tokenizer, label_ids, "pos")
            neg_scores = get_rele_scores(neg_logits, tokenizer, label_ids, "neg")
            temperature = 0.3
            # Handle pos_score and neg_scores having different lengths.
            pos_score_exp = pos_score.unsqueeze(1)    # [N_pos] -> [N_pos, 1]
            neg_scores_exp = neg_scores.unsqueeze(0)  # [N_neg] -> [1, N_neg]
            # Compute the InfoNCE loss.
            exp_pos = torch.exp(pos_score_exp / temperature)   # [N_pos, 1]
            exp_neg = torch.exp(neg_scores_exp / temperature)  # [1, N_neg]
            denominator = exp_pos + exp_neg.sum(dim=1, keepdim=True)  # [N_pos, 1]
            infonce_loss = -torch.log(exp_pos / denominator).mean()

        # Combine the original loss and the InfoNCE loss.
        alpha = 1.0
        combined_loss = loss.mean() + alpha * infonce_loss
        print(f"loss: {loss.item()}, infonce_loss: {infonce_loss.item()}")
        # return (loss, outputs) if return_outputs else loss
        return (combined_loss, outputs) if return_outputs else combined_loss
The above is the contrastive loss as currently integrated, but it deadlocks as soon as the first batch runs. The related error output is:
Train: 0%| | 0/8300 [00:00<?, ?it/s]WARNING:transformers_modules.InternVL2-4B.modeling_phi3:`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
WARNING:transformers_modules.InternVL2-4B.modeling_phi3:`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
gathered_logits shape: torch.Size([4, 906, 32020])
gathered_pos shape: torch.Size([4])
gathered_logits shape: torch.Size([4, 966, 32020])
gathered_pos shape: torch.Size([4])
dynamic ViT batch size: 2, images per sample: 2.0, dynamic token length: 902
WARNING:transformers_modules.InternVL2-4B.modeling_phi3:`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
gathered_logits shape: torch.Size([4, 902, 32020])
gathered_pos shape: torch.Size([4])
WARNING:transformers_modules.InternVL2-4B.modeling_phi3:`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
gathered_logits shape: torch.Size([4, 902, 32020])
gathered_pos shape: torch.Size([4])
[E ProcessGroupNCCL.cpp:475] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4, OpType=ALLGATHER, NumelIn=30931320, NumelOut=123725280, Timeout(ms)=1800000) ran for 1800304 milliseconds before timing out.
loss: 7.311283111572266, infonce_loss: 0.0
[E ProcessGroupNCCL.cpp:489] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:495] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:916] [Rank 2] NCCL watchdog thread terminated with exception: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4, OpType=ALLGATHER, NumelIn=30931320, NumelOut=123725280, Timeout(ms)=1800000) ran for 1800304 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
what(): [Rank 2] NCCL watchdog thread terminated with exception: [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=4, OpType=ALLGATHER, NumelIn=30931320, NumelOut=123725280, Timeout(ms)=1800000) ran for 1800304 milliseconds before timing out.
[2024-09-27 15:17:22,821] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 4059737 closing signal SIGTERM
[2024-09-27 15:17:22,822] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 4059738 closing signal SIGTERM
[2024-09-27 15:17:22,823] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 4059740 closing signal SIGTERM
[2024-09-27 15:17:23,805] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: -6) local_rank: 2 (pid: 4059739) of binary: /opt/conda/bin/python
Traceback (most recent call last):
File "/opt/conda/bin/torchrun", line 8, in <module>
sys.exit(main())
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
return f(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 806, in main
run(args)
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
elastic_launch(
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/opt/conda/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
========================================================
/mnt/code/swift_internvl2_cot_cl/swift/cli/sft.py FAILED
--------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
--------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2024-09-27_15:17:22
host : 9129914748d9
rank : 2 (local_rank: 2)
exitcode : -6 (pid: 4059739)
error_file: <N/A>
traceback : Signal 6 (SIGABRT) received by PID 4059739
========================================================
May I ask where the problem is in the code implementation?
py-spy dump --pid <...>
Check whether it is stuck in forward or backward.
Are you using ZeRO-2/ZeRO-3?
py-spy cannot be used for now because of container permission issues, but debugging shows it gets stuck during forward. ZeRO-2 is enabled; the debug configuration is:
{
    "name": "debug_cl",
    "type": "debugpy",
    "request": "launch",
    "module": "torch.distributed.run",
    "console": "integratedTerminal",
    "justMyCode": false,
    "env": {
        "CUDA_VISIBLE_DEVICES": "0,1",
        "PYTHONPATH": "./"
    },
    "args": [
        "--master_port", "29510",
        "--nproc_per_node", "2",
        "swift/cli/main.py", // your main program file
        "sft",
        "--sft_type", "lora",
        "--model_type", "internvl2-4b",
        "--custom_train_dataset_path", "/mnt/data/msrvtt/llm_data/msrvtt_cot_3500.jsonl",
        "--resume_from_checkpoint", "/mnt/code/swift_all/output/internvl2-4b/v2-20240827-161031/checkpoint-39827",
        "--resume_only_model", "True",
        "--save_strategy", "epoch",
        "--num_train_epochs", "10",
        "--save_total_limit", "10000",
        "--ddp_find_unused_parameters", "true",
        "--max_length", "4096",
        "--lora_rank", "8",
        "--lora_alpha", "32",
        "--lora_dropout_p", "0.05",
        "--lora_target_modules", "ALL",
        "--gradient_checkpointing", "true",
        "--batch_size", "1",
        "--weight_decay", "0.01",
        "--learning_rate", "5e-5",
        "--save_steps", "3000",
        "--gradient_accumulation_steps", "1",
        "--max_grad_norm", "0.5",
        "--warmup_ratio", "0.03",
        "--dtype", "bf16",
        "--deepspeed", "default-zero2"
    ]
},
Do you know which line it gets stuck on?
Could you try upgrading ms-swift and see whether that helps?
Describe the feature
Support SFT training with multiple loss functions, for example a contrastive loss.
Paste any useful information
During SFT, besides the cross-entropy loss, it is sometimes necessary to compute a contrastive loss, pair loss, and so on for a specific token. Could such a capability be integrated? (A rough illustration of the idea follows below.)
Additional context
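For illustration only, here is a minimal sketch (not part of ms-swift) of how a Trainer subclass could combine cross-entropy with an auxiliary pair loss on a single token position. The class name MultiLossTrainer, the pair_label input field, and the hyperparameters (aux_weight, token_offset, pos_token_id, margin) are all hypothetical and would need to be adapted to the actual dataset and tokenizer:

import torch
import torch.nn.functional as F
from transformers import Trainer


class MultiLossTrainer(Trainer):
    """Sketch only: cross-entropy plus an auxiliary pair loss on one token position."""

    def __init__(self, *args, aux_weight=1.0, token_offset=-9, pos_token_id=29896, **kwargs):
        super().__init__(*args, **kwargs)
        self.aux_weight = aux_weight      # weight of the auxiliary term
        self.token_offset = token_offset  # position of the relevance token in the sequence
        self.pos_token_id = pos_token_id  # vocab id whose probability serves as the relevance score

    def compute_loss(self, model, inputs, return_outputs=False):
        # `pair_label` is a hypothetical per-sample 0/1 field supplied by the dataset.
        pair_label = inputs.pop('pair_label', None)
        outputs = model(**inputs)
        ce_loss = outputs.loss
        if pair_label is None:
            return (ce_loss, outputs) if return_outputs else ce_loss

        # Probability of the "relevant" token at the chosen position, one scalar per sample.
        token_logits = outputs.logits[:, self.token_offset, :]
        scores = token_logits.softmax(dim=-1)[:, self.pos_token_id]

        pos_scores = scores[pair_label == 1]
        neg_scores = scores[pair_label == 0]
        if pos_scores.numel() == 0 or neg_scores.numel() == 0:
            # No positive/negative pair in this (local) batch: skip the auxiliary term.
            aux_loss = scores.new_zeros(())
        else:
            # Margin ranking loss over every positive/negative pair in the local batch.
            margin = 0.2
            diff = pos_scores.unsqueeze(1) - neg_scores.unsqueeze(0)
            aux_loss = F.relu(margin - diff).mean()

        loss = ce_loss + self.aux_weight * aux_loss
        return (loss, outputs) if return_outputs else loss

This sketch keeps the pair loss within the local batch; gathering scores across GPUs, as in the code above, additionally requires that every tensor passed to all_gather have an identical shape on every rank.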