Open skerit opened 4 months ago
Oh weighting is possible, but you'll need to add a custom cross entropy loss function ie via removing the LM Head, and putting a custom one
Oh, is that different than implementing a new compute_loss
method in a Trainer class?
@danielhanchen Sorry to make you look at some newbie trainer code, but this custom trainer of mine works locally, but always OOMs on Google Collab, when the non-custom trainer does work.
dataset = load_dataset("json", data_files="drive/MyDrive/Unsloth/dataset.jsonl", split = "train")
def generate_and_tokenize_pieces(sample):
all_input_ids = []
all_attention_masks = []
all_weight_ranges = []
current_length = 0
for item in sample['pieces']:
tokenized = tokenizer(item['text'], return_tensors='pt')
# Get tensor, remove batch dimension
input_ids = tokenized.input_ids.squeeze()
# Get tensor, remove batch dimension
attention_mask = tokenized.attention_mask.squeeze()
start_idx = current_length
end_idx = start_idx + len(input_ids) - 1
all_input_ids.append(input_ids)
all_attention_masks.append(attention_mask)
all_weight_ranges.append((start_idx, end_idx, item['weight']))
# Update current length
current_length = end_idx + 1
concatenated_input_ids = torch.cat(all_input_ids, dim=0) if all_input_ids else torch.tensor([], dtype=torch.long)
concatenated_attention_masks = torch.cat(all_attention_masks, dim=0) if all_attention_masks else torch.tensor([], dtype=torch.long)
expanded_weight_ranges = torch.tensor([], dtype=torch.long)
# Convert the weight ranges
for start_idx, end_idx, weight in all_weight_ranges:
# Turn the weight into an integer
weight = int(weight * 100)
expanded_weight_ranges = torch.cat([expanded_weight_ranges, torch.tensor([weight] * (end_idx - start_idx + 1))])
# If there are no weight ranges, we return a tensor of ones
if len(expanded_weight_ranges) == 0:
expanded_weight_ranges = torch.ones_like(concatenated_input_ids)
# Pad all the tensors to the same length (max_seq_length)
concatenated_input_ids = torch.cat([concatenated_input_ids, torch.zeros(max_seq_length - concatenated_input_ids.size(0), dtype=torch.long)])
concatenated_attention_masks = torch.cat([concatenated_attention_masks, torch.zeros(max_seq_length - concatenated_attention_masks.size(0), dtype=torch.long)])
expanded_weight_ranges = torch.cat([expanded_weight_ranges, torch.zeros(max_seq_length - expanded_weight_ranges.size(0), dtype=torch.long)])
return {
"input_ids": concatenated_input_ids,
"attention_mask": concatenated_attention_masks,
"labels": concatenated_input_ids.clone(),
"weights": expanded_weight_ranges,
}
tokenized_train_dataset = dataset.map(generate_and_tokenize_pieces, remove_columns=["pieces"])
# My naive custom Trainer class with a custom weighted loss computation
class WeightedLossTrainer(transformers.Trainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def compute_loss(self, model, inputs, return_outputs=False):
# Pop off my custom weights property
weights = inputs.pop("weights")
# Get the labels
labels = inputs.get("labels")
# This always takes a long, long time and OOMs the GPU
outputs = model(**inputs)
logits = outputs.get("logits")
batch_size, seq_len, num_classes = logits.shape
total_weighted_loss = 0.0
total_weights = 0.0
for batch_idx in range(batch_size):
for seq_idx in range(seq_len):
weight = weights[batch_idx, seq_idx]
if weight > 0: # Only consider tokens that have a weight > 0
token_logits = logits[batch_idx, seq_idx]
token_label = labels[batch_idx, seq_idx]
token_loss = F.cross_entropy(token_logits.unsqueeze(0), token_label.unsqueeze(0), reduction='none')
weighted_token_loss = token_loss * (weight / 100)
total_weighted_loss += weighted_token_loss.item()
total_weights += (weight / 100)
# Compute the mean loss.
mean_loss = total_weighted_loss / total_weights if total_weights > 0 else 0.0
mean_loss = torch.tensor(mean_loss, dtype=torch.float32, device=logits.device, requires_grad=True)
return (mean_loss, outputs) if return_outputs else mean_loss
training_args = transformers.TrainingArguments(
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_steps = 5,
max_steps = 60,
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 5,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
remove_unused_columns=False,
)
trainer = WeightedLossTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = tokenized_train_dataset,
args = training_args,
)
trainer_stats = trainer.train()
Is this because I'm bypassing some kind of Unsloth optimization, or is what I'm doing just ... wrong?
You need to use autocasting ie
with torch.cuda.amp.autocast(dtype = torch.bfloat16):
model(...)
Mistral has a new finetuner repository where you can assign weights to specific messages, and those will be taken into account when the loss is calculated. I wanted to implement something similar for SFTTrainer, because my dataset contains information that doesn't really make a lot of sense to punish the model for not knowing. But switching completely to
DataCollatorForCompletionOnlyLM
is also not possible.My problem is that it's not working at all :sweat_smile:
I might be misunderstanding what Unsloth is doing to the existing trainer. Is it as simple to just create a new trainer class, let it inherit from
SFTTrainer
with a customcompute_loss
function and expect it to run with Unsloth, or is that a no go?Here's an example dataset to illustrate what I'm trying to achieve