microsoft / LLMLingua

To speed up LLM inference and enhance LLMs' perception of key information, LLMLingua compresses the prompt and KV-Cache, achieving up to 20x compression with minimal performance loss.
https://llmlingua.com/
MIT License
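For context, typical use of LLMLingua's prompt compression looks roughly like the sketch below. It assumes the project's documented PromptCompressor interface; the prompt text and the target_token budget are illustrative only:

    from llmlingua import PromptCompressor

    # Illustrative long prompt (context, demonstrations, and a question).
    long_prompt = "... a long context with demonstrations and a question ..."

    # Instantiate the compressor with its default model.
    llm_lingua = PromptCompressor()

    # Compress the prompt toward a target token budget (value is illustrative).
    result = llm_lingua.compress_prompt(
        long_prompt,
        instruction="",
        question="",
        target_token=200,
    )

    print(result["compressed_prompt"])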

Remove Duplicate Declaration of Loss Function #38

Closed · Speuce closed this 8 months ago

Speuce commented 8 months ago

I noticed there is an unnecessary duplicate declaration of loss_fct here.

Relevant code:

        # First declaration of the loss function; it is never used before being re-declared below
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        shift_logits = response.logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., past_length + 1 : end].contiguous()
        # Flatten the tokens
        active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
        active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
        active_labels = shift_labels.view(-1)[active]
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")  # duplicate declaration
        loss = loss_fct(active_logits, active_labels)

As you can see, loss_fct is not used before it is declared a second time, so the first declaration can be safely removed. A corrected version of the excerpt is sketched below.
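For reference, this is how the excerpt would look with the redundant first declaration dropped (a minimal sketch; everything else is kept exactly as in the original snippet):

        shift_logits = response.logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., past_length + 1 : end].contiguous()
        # Flatten the tokens
        active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
        active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
        active_labels = shift_labels.view(-1)[active]
        # Single declaration of the loss function, right before its only use
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(active_logits, active_labels)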

iofu728 commented 8 months ago

Hi @Speuce,

Thanks for your help. I'll merge this PR.