pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Fix generation for bsz > 1 #1250

Open · joecummings opened this issue 1 month ago

joecummings commented 1 month ago

Our modules only support generation in one of two cases: batch_size = 1, or every sample in the batch has the same length. The main culprit is this line of code: https://github.com/pytorch/torchtune/blob/288ff4435b0cf17325b5c3b112f6859a6cdf0ea2/torchtune/modules/transformer.py#L167

For a batch that looks like the following:

My, name, is, Joe
Hello, world, <PAD>, <PAD>
Bye, <PAD>, <PAD>, <PAD>

A proper mask would look like:

1 0 0 0
1 1 0 0 
1 1 1 0
1 1 1 1

1 0 0 0
1 1 0 0
0 0 0 0
0 0 0 0

1 0 0 0
0 0 0 0
0 0 0 0
0 0 0 0

of size [b x s x s], which is [3 x 4 x 4]
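
For concreteness, here is a minimal sketch of how such a mask could be built from a right-padded batch of token IDs. The name build_causal_padding_mask and the pad_id argument are illustrative, not existing torchtune APIs; it assumes a boolean convention where True means "may attend", and it reproduces the [3 x 4 x 4] example above:

```python
import torch

def build_causal_padding_mask(tokens: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    """Combine a causal mask with a padding mask into a [b, s, s] bool mask."""
    seq_len = tokens.shape[1]
    # [s, s] lower-triangular causal mask, shared across the batch
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=tokens.device)
    )
    # [b, s] padding mask: True where the token is real
    real = tokens != pad_id
    # Query i may attend to key j iff j <= i and both tokens are real;
    # rows for pad positions come out all-zero, as in the example above.
    return causal[None, :, :] & real[:, None, :] & real[:, :, None]
```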

This will be a fairly involved change that touches several utils and modules. The general changes needed will be:

This was originally found and reported by @iankur.

SalmanMohammadi commented 1 month ago

Was this the kind of thing you had in mind? https://github.com/pytorch/torchtune/blob/1129f9e3a246628c991c246d81dbead62d3437a3/torchtune/modules/rlhf/_generation.py

Granted, there are a couple of changes I've been meaning to make (only generating the full mask once and extending it for each new token in the batch), and you'll probably have a more intelligent way of generating the masks themselves.
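
For illustration, the "generate the full mask once" idea might look like the following sketch: build the causal mask up front for the maximum total length, then slice out the single row each new token needs per step. Shapes and names here are hypothetical, not the actual _generation.py code:

```python
import torch

bsz, prompt_len, max_new_tokens = 3, 4, 2
total_len = prompt_len + max_new_tokens

# Build the causal mask once, for the longest sequence we will ever see.
full_mask = torch.tril(
    torch.ones(total_len, total_len, dtype=torch.bool)
).expand(bsz, -1, -1)

# At decoding step t the new token sits at position p = prompt_len + t,
# and only needs its own row of the mask against all earlier positions.
for t in range(max_new_tokens):
    p = prompt_len + t
    step_mask = full_mask[:, p : p + 1, : p + 1]
    print(step_mask.shape)  # [bsz, 1, p + 1]
```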

joecummings commented 1 month ago

Was this the kind of thing you had in mind? 1129f9e/torchtune/modules/rlhf/_generation.py

Yep, this is pretty much it! I take it that you're not utilizing the KV Cache for this generation though, right?

SalmanMohammadi commented 1 month ago

Yep, this is pretty much it! I take it that you're not utilizing the KV Cache for this generation though, right?

Nah. It was also on my TODO list of possible optimizations, and I briefly spoke to Rafi about it, but we agreed it would be kind of a pain in the ass to set up caching for custom masks.

joecummings commented 3 weeks ago

Left padded:

My, name, is, Joe
<PAD>, <PAD>, Hello, world
<PAD>, <PAD>, <PAD>, Bye

Left padded mask:

1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1

1 0 0 0
0 1 0 0
0 0 1 0 
0 0 1 1

1 0 0 0
0 1 0 0
0 0 1 0
0 0 0 1
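
A hedged sketch of building this left-padded variant; the helper name is again hypothetical. The only twist versus the right-padded case is giving every position, including pads, a 1 on the diagonal, so the softmax over an otherwise fully-masked query row stays well defined:

```python
import torch

def build_left_padded_mask(tokens: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    """[b, s, s] bool mask for left-padded inputs, matching the example above."""
    seq_len = tokens.shape[1]
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=tokens.device)
    )
    real = tokens != pad_id
    mask = causal[None, :, :] & real[:, None, :] & real[:, :, None]
    # Let every position, including pads, attend to itself so that no
    # query row is entirely masked out.
    diag = torch.eye(seq_len, dtype=torch.bool, device=tokens.device)
    return mask | diag[None, :, :]
```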

SalmanMohammadi commented 3 weeks ago

Our modules only work with generation under two conditions: batch_size = 1 or every single sample in a batch has the same length. The main culprit is this line of code:

I assume batched generation in the Eleuther eval recipe satisfies the latter? I've just got iterative decoding + KV caching working for my batched RLHF generation utils; I'm seeing >10x speedups w/o compile (PPO go brrrr). Can chat about it later today if it's of interest.
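
For context, here is a hedged sketch of what that kind of loop can look like: prefill the KV cache with the whole prompt, then decode one token per step, slicing out the single mask row each position needs. The forward signature (tokens, mask, input_pos) mirrors torchtune's TransformerDecoder, but everything else (names, greedy sampling, simplified position handling for left padding) is illustrative rather than the actual utils:

```python
import torch

@torch.no_grad()
def generate_batched(model, prompt, mask, max_new_tokens):
    """Illustrative batched decoding with a KV cache.

    prompt: [b, s] left-padded token IDs
    mask:   [b, total_len, total_len] precomputed padding-aware causal mask
    Assumes model(tokens, mask=..., input_pos=...) reads/writes its KV cache.
    """
    bsz, prompt_len = prompt.shape
    # Prefill: run the whole prompt through once to populate the cache.
    input_pos = torch.arange(prompt_len)
    logits = model(prompt, mask=mask[:, :prompt_len, :prompt_len], input_pos=input_pos)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy, [b, 1]
    generated = [next_token]

    for i in range(max_new_tokens - 1):
        pos = prompt_len + i
        # With a cache, each step only needs this position's mask row.
        step_mask = mask[:, pos : pos + 1, : pos + 1]
        logits = model(next_token, mask=step_mask, input_pos=torch.tensor([pos]))
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_token)

    return torch.cat([prompt] + generated, dim=1)
```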