LLM360 / amber-train

Pre-training code for Amber 7B LLM
Apache License 2.0

Aligning logits with labels through two shifts? #6

Open · Xlun opened 3 months ago

Xlun commented 3 months ago

During data preparation in main.py:

def collate_fn(examples, device):
    token_ids = torch.tensor(
        [example['token_ids'] for example in examples], device=device)
    return {'input_ids': token_ids[:, :-1], 'labels': token_ids[:, 1:]}

def train_chunk(.......):
    ..........
    batch = collate_fn(
        examples=examples[i:i+per_device_batch_size], device=fabric.device)
    input_ids, labels = batch['input_ids'], batch['labels']
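
For context, a minimal self-contained check (a hypothetical 4-token example on CPU; the integers are toy token ids) of what collate_fn produces shows that the labels are already the input ids shifted left by one, i.e. next-token targets:

import torch

def collate_fn(examples, device):
    token_ids = torch.tensor(
        [example['token_ids'] for example in examples], device=device)
    return {'input_ids': token_ids[:, :-1], 'labels': token_ids[:, 1:]}

batch = collate_fn([{'token_ids': [10, 11, 12, 13]}], device=torch.device('cpu'))
print(batch['input_ids'])  # tensor([[10, 11, 12]])
print(batch['labels'])     # tensor([[11, 12, 13]])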

During the loss computation in modeling_llama.py:

class LlamaForCausalLM(LlamaPreTrainedModel):
    ....................
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
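
Taken on its own, this internal shift matches the stock HuggingFace convention, where the labels passed to the model are identical to input_ids and the model creates the next-token targets itself. A minimal sketch of that intended pairing (toy ids; the input tokens stand in for the logits produced at each position):

import torch

labels = torch.tensor([[10, 11, 12, 13]])  # HF convention: labels == input_ids
logits_at = labels                         # logit made at position i, as a stand-in

shift_logits = logits_at[:, :-1]  # predictions made at positions [[10, 11, 12]]
shift_labels = labels[:, 1:]      # targets                       [[11, 12, 13]]
# Each prediction is paired with the token that immediately follows it.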

Why is the shift that aligns predictions with ground-truth values applied once when preparing sample inputs for the model, and then applied a second time during the loss computation inside the model?
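
To make the concern concrete, here is a minimal sketch (toy token ids, one example per batch; the input tokens again stand in for the logits produced at each position) tracing the index alignment through both shifts:

import torch

token_ids = torch.tensor([[10, 11, 12, 13, 14]])  # one toy example: t0..t4

# First shift, in collate_fn:
input_ids = token_ids[:, :-1]  # [[10, 11, 12, 13]]
labels    = token_ids[:, 1:]   # [[11, 12, 13, 14]] -- already next-token targets

# logits[:, i, :] is the prediction the model makes at input position i:
logits_at = input_ids

# Second shift, in LlamaForCausalLM:
shift_logits = logits_at[:, :-1]  # predictions made at positions [[10, 11, 12]]
shift_labels = labels[:, 1:]      # paired targets                [[12, 13, 14]]

# The prediction made after seeing token 10 is paired with target 12,
# i.e. two tokens ahead instead of one -- the apparent double shift.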