Open · joecummings opened 1 month ago
Was this the kind of thing you had in mind? https://github.com/pytorch/torchtune/blob/1129f9e3a246628c991c246d81dbead62d3437a3/torchtune/modules/rlhf/_generation.py
Granted, there are a couple of changes I've been meaning to make (only generating the full mask once and extending it for each token in the batch), and you'll probably have a more intelligent way of generating the masks themselves.
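For concreteness, here's a minimal sketch of that "extend the mask per token" idea, assuming a boolean [b, s, s] attention mask; `extend_causal_mask` is a hypothetical helper for illustration, not an existing torchtune API:

```python
import torch

def extend_causal_mask(mask: torch.Tensor) -> torch.Tensor:
    """Grow a [b, s, s] boolean mask to [b, s+1, s+1] for the next decoded token."""
    b, s, _ = mask.shape
    # The new token attends to everything the last position could see, plus itself.
    new_row = torch.cat(
        [mask[:, -1:, :], torch.ones(b, 1, 1, dtype=mask.dtype, device=mask.device)],
        dim=-1,
    )
    # Existing positions never attend to the (future) new token.
    new_col = torch.zeros(b, s, 1, dtype=mask.dtype, device=mask.device)
    return torch.cat([torch.cat([mask, new_col], dim=-1), new_row], dim=1)
```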
Yep, this is pretty much it! I take it that you're not utilizing the KV Cache for this generation though, right?
Nah. It was also on my TODO list of possible optimizations, and I briefly spoke to Rafi about it, but we agreed it would be kind of a pain in the ass to set up caching for custom masks.
Our modules only work with generation under two conditions: batch_size = 1 or every single sample in a batch has the same length. The main culprit is this line of code: https://github.com/pytorch/torchtune/blob/288ff4435b0cf17325b5c3b112f6859a6cdf0ea2/torchtune/modules/transformer.py#L167

For a batch that looks like the following:

Left padded:
My, name, is, Joe
<PAD>, <PAD>, Hello, world
<PAD>, <PAD>, <PAD>, Bye

A proper mask would look like:

Left padded mask:
1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1

1 0 0 0
0 1 0 0
0 0 1 0
0 0 1 1

1 0 0 0
0 1 0 0
0 0 1 0
0 0 0 1
of size [b x s x s], which is [3 x 4 x 4].
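One way to construct that mask from the padded token ids, as a sketch (the `get_causal_mask` name and `pad_id` argument here are assumptions for illustration, not necessarily what _generation.py exposes):

```python
import torch

def get_causal_mask(tokens: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Build a [b, s, s] boolean mask: causal attention over real tokens,
    with pad positions attending only to themselves (as in the matrices above)."""
    b, s = tokens.shape
    is_real = tokens != pad_id                                   # [b, s]
    causal = torch.tril(torch.ones(s, s, dtype=torch.bool, device=tokens.device))
    # A position may attend to an earlier position only if that position is a real token.
    mask = causal.unsqueeze(0) & is_real.unsqueeze(1)            # [b, s, s]
    # Keep the diagonal so every position (pads included) attends to itself.
    mask |= torch.eye(s, dtype=torch.bool, device=tokens.device).unsqueeze(0)
    return mask
```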
This will be a fairly involved change that touches several utils and modules. The general changes needed will be:

* a `mask` param in `model.forward()`

This was originally found and reported by @iankur.

I assume batched generation in the eleuther eval recipe satisfies the latter? I've just got iterative decoding + KV caching working for my batched RLHF generation utils - seeing > 10x speedups w/o compile (PPO go brrrr). Can chat about it later today if it's of interest.
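For reference, a rough sketch of what that iterative decoding + KV cache loop could look like, under some assumptions: `model` accepts `mask` and `input_pos` kwargs and already has its KV caches enabled, sampling is greedy, and `extend_causal_mask` / `get_causal_mask` are the hypothetical helpers sketched earlier in the thread:

```python
import torch

@torch.no_grad()
def generate(model, tokens, pad_id, max_new_tokens):
    # tokens: [b, s] left padded prompts.
    b, s = tokens.shape
    mask = get_causal_mask(tokens, pad_id)                      # [b, s, s]
    # Prefill: run the whole prompt once to populate the KV cache.
    input_pos = torch.arange(s, device=tokens.device).unsqueeze(0).expand(b, -1)
    logits = model(tokens, mask=mask, input_pos=input_pos)
    next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)       # greedy for simplicity
    generated = [next_tok]
    for i in range(1, max_new_tokens):
        # Grow the mask by one position; the new token only needs its own
        # row of the mask, since every earlier position is already cached.
        mask = extend_causal_mask(mask)
        pos = torch.full((b, 1), s + i - 1, device=tokens.device)
        logits = model(next_tok, mask=mask[:, -1:, :], input_pos=pos)
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_tok)
    return torch.cat(generated, dim=-1)
```

After prefill, each step feeds a single token rather than re-scoring the whole prompt, which is presumably where the reported >10x speedup comes from.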