unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Implementing weighted loss function #538

Open skerit opened 4 months ago

skerit commented 4 months ago

Mistral has a new finetuning repository where you can assign weights to specific messages, and those weights are taken into account when the loss is calculated. I wanted to implement something similar for SFTTrainer, because my dataset contains information that it doesn't really make sense to punish the model for not knowing, but switching completely to DataCollatorForCompletionOnlyLM isn't an option either.

My problem is that it's not working at all :sweat_smile:

I might be misunderstanding what Unsloth is doing to the existing trainer. Is it as simple as creating a new trainer class that inherits from SFTTrainer with a custom compute_loss function and expecting it to run with Unsloth, or is that a no-go?

Here's an example dataset to illustrate what I'm trying to achieve:

{"pieces":[{"text":"### system:\n","weight":0.5},{"text":"QuirkyQuarters v1.1\n","weight":1},{"text":"\n","weight":0.1},{"text":"### parameters:\n","weight":0.5}]}
{"pieces":[{"text":"### system:\n","weight":0.5},{"text":"QuirkyQuarters v1.2\n","weight":1},{"text":"\n","weight":0.1},{"text":"### parameters:\n","weight":0.5}]}
danielhanchen commented 4 months ago

Oh, weighting is possible, but you'll need to add a custom cross entropy loss function, i.e. by removing the LM head and putting in a custom one.
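
For illustration, here is a minimal sketch of the weighting itself. It keeps the stock LM head and weights the per-token cross entropy computed from the logits (not quite the head swap described above), and the function and argument names are made up for this example:

import torch
import torch.nn.functional as F

def weighted_causal_lm_loss(logits, labels, weights):
    # Standard causal LM shift: tokens < n predict token n
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    shift_weights = weights[:, 1:].to(logits.dtype).contiguous()

    # Per-token cross entropy with no reduction; -100 labels are ignored
    per_token_loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction = "none",
        ignore_index = -100,
    )

    # Weighted mean over the tokens that carry a non-zero weight
    flat_weights = shift_weights.view(-1)
    return (per_token_loss * flat_weights).sum() / flat_weights.sum().clamp(min = 1e-8)

With the sample format above, weights would be the token-level weight vector built from the per-piece weights, with 0 marking padding and other positions that should not contribute to the loss.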

skerit commented 4 months ago

Oh, is that different from implementing a new compute_loss method in a Trainer class?

skerit commented 4 months ago

@danielhanchen Sorry to make you look at some newbie trainer code, but this custom trainer of mine works locally yet always OOMs on Google Colab, while the non-custom trainer works fine there.


import torch
import torch.nn.functional as F
import transformers

from datasets import load_dataset
from unsloth import is_bfloat16_supported

# model, tokenizer and max_seq_length are assumed to come from the usual
# FastLanguageModel.from_pretrained(...) setup earlier in the notebook
dataset = load_dataset("json", data_files="drive/MyDrive/Unsloth/dataset.jsonl", split = "train")

def generate_and_tokenize_pieces(sample):

    all_input_ids = []
    all_attention_masks = []
    all_weight_ranges = []
    current_length = 0

    for item in sample['pieces']:
        tokenized = tokenizer(item['text'], return_tensors='pt')

        # Get the input ids tensor, remove the batch dimension
        input_ids = tokenized.input_ids.squeeze()

        # Get the attention mask tensor, remove the batch dimension
        attention_mask = tokenized.attention_mask.squeeze()

        start_idx = current_length
        end_idx = start_idx + len(input_ids) - 1

        all_input_ids.append(input_ids)
        all_attention_masks.append(attention_mask)
        all_weight_ranges.append((start_idx, end_idx, item['weight']))

        # Update current length
        current_length = end_idx + 1

    concatenated_input_ids = torch.cat(all_input_ids, dim=0) if all_input_ids else torch.tensor([], dtype=torch.long)
    concatenated_attention_masks = torch.cat(all_attention_masks, dim=0) if all_attention_masks else torch.tensor([], dtype=torch.long)

    expanded_weight_ranges = torch.tensor([], dtype=torch.long)

    # Convert the weight ranges
    for start_idx, end_idx, weight in all_weight_ranges:
        # Turn the weight into an integer
        weight = int(weight * 100)
        expanded_weight_ranges = torch.cat([expanded_weight_ranges, torch.tensor([weight] * (end_idx - start_idx + 1))])

    # If there are no weight ranges, we return a tensor of ones
    if len(expanded_weight_ranges) == 0:
        expanded_weight_ranges = torch.ones_like(concatenated_input_ids)

    # Truncate to max_seq_length first, otherwise the padding below would fail with a negative size
    concatenated_input_ids = concatenated_input_ids[:max_seq_length]
    concatenated_attention_masks = concatenated_attention_masks[:max_seq_length]
    expanded_weight_ranges = expanded_weight_ranges[:max_seq_length]

    # Pad all the tensors to the same length (max_seq_length)
    concatenated_input_ids = torch.cat([concatenated_input_ids, torch.zeros(max_seq_length - concatenated_input_ids.size(0), dtype=torch.long)])
    concatenated_attention_masks = torch.cat([concatenated_attention_masks, torch.zeros(max_seq_length - concatenated_attention_masks.size(0), dtype=torch.long)])
    expanded_weight_ranges = torch.cat([expanded_weight_ranges, torch.zeros(max_seq_length - expanded_weight_ranges.size(0), dtype=torch.long)])

    return {
        "input_ids": concatenated_input_ids,
        "attention_mask": concatenated_attention_masks,
        "labels": concatenated_input_ids.clone(),
        "weights": expanded_weight_ranges,
    }

tokenized_train_dataset = dataset.map(generate_and_tokenize_pieces, remove_columns=["pieces"])

# My naive custom Trainer class with a custom weighted loss computation
class WeightedLossTrainer(transformers.Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):

        # Pop off my custom weights property
        weights = inputs.pop("weights")

        # Get the labels
        labels = inputs.get("labels")

        # This always takes a long, long time and OOMs the GPU
        outputs = model(**inputs)

        logits = outputs.get("logits")

        batch_size, seq_len, num_classes = logits.shape

        total_weighted_loss = 0.0
        total_weights = 0.0

        for batch_idx in range(batch_size):
            for seq_idx in range(seq_len):
                weight = weights[batch_idx, seq_idx]
                if weight > 0:  # Only consider tokens that have a weight > 0
                    token_logits = logits[batch_idx, seq_idx]
                    token_label = labels[batch_idx, seq_idx]
                    token_loss = F.cross_entropy(token_logits.unsqueeze(0), token_label.unsqueeze(0), reduction='none')

                    weighted_token_loss = token_loss * (weight / 100)

                    total_weighted_loss += weighted_token_loss.item()
                    total_weights += (weight / 100)

        # Compute the mean loss.
        mean_loss = total_weighted_loss / total_weights if total_weights > 0 else 0.0
        mean_loss = torch.tensor(mean_loss, dtype=torch.float32, device=logits.device, requires_grad=True)

        return (mean_loss, outputs) if return_outputs else mean_loss

training_args = transformers.TrainingArguments(
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    warmup_steps = 5,
    max_steps = 60,
    learning_rate = 2e-4,
    fp16 = not is_bfloat16_supported(),
    bf16 = is_bfloat16_supported(),
    logging_steps = 5,
    optim = "adamw_8bit",
    weight_decay = 0.01,
    lr_scheduler_type = "linear",
    seed = 3407,
    output_dir = "outputs",
    remove_unused_columns = False,
)

trainer = WeightedLossTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = tokenized_train_dataset,
    args = training_args,
)

trainer_stats = trainer.train()

Is this because I'm bypassing some kind of Unsloth optimization, or is what I'm doing just ... wrong?

danielhanchen commented 4 months ago

You need to use autocasting, i.e.

with torch.cuda.amp.autocast(dtype = torch.bfloat16):
    model(...)
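
Applied to the custom trainer above, that could look roughly like the sketch below. It assumes bf16-capable hardware (on older GPUs the dtype would be torch.float16) and reuses the weighted_causal_lm_loss helper sketched earlier in the thread. It also keeps the loss as a tensor end to end: accumulating .item() values and rebuilding a requires_grad tensor, as in the original compute_loss, detaches the loss from the computation graph, so no gradients would reach the model.

import torch
import transformers

class WeightedLossTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs = False):
        # Pop off the custom weights column so the model never sees it
        weights = inputs.pop("weights")
        labels = inputs.get("labels")

        # Run the forward pass under autocast so activations and logits stay
        # in bfloat16 instead of being materialized in float32
        with torch.cuda.amp.autocast(dtype = torch.bfloat16):
            outputs = model(**inputs)

        # Vectorized weighted loss computed on the logits (see the earlier sketch)
        loss = weighted_causal_lm_loss(outputs.logits, labels, weights)

        return (loss, outputs) if return_outputs else loss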