pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License
4.35k stars, 440 forks

Update torchtune generation to be more flexible #1970

Closed: RylanC24 closed this 2 weeks ago

RylanC24 commented 2 weeks ago

Summary: The existing softmax sampling trick implementation in the torchtune generator is not flexible enough to handle vocab-pruned models (i.e., models where the number of logits produced does not match the size of the embedding layer).

This is an unnecessary limitation and is easy to fix if we simply create the q tensor to match the size of the logits tensor instead of the embedding layer.

Differential Revision: D65480353
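For context, the "softmax sampling trick" mentioned in the summary draws a token as argmax(probs / q) with each q_i ~ Exp(1), which selects index i with probability p_i, exactly like a categorical sample. The sketch below uses a hypothetical helper (not torchtune's code) that builds q with `torch.empty_like(probs)`, so q's shape always follows the logits, which is the change this PR proposes:

```python
import torch


def sample_softmax_trick(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: draw one token per row from softmax(logits / temperature).

    If q_i ~ Exp(1) independently, argmax_i(p_i / q_i) picks index i with
    probability p_i, so no explicit multinomial call is needed.
    """
    probs = torch.softmax(logits / temperature, dim=-1)
    # q is shaped like probs, so it matches however many logits the model emits
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True)


# Example: a vocab-pruned head emitting 10 logits, regardless of embedding size.
torch.manual_seed(0)
tokens = sample_softmax_trick(torch.randn(2, 10))
print(tokens.shape)  # torch.Size([2, 1])
```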

pytorch-bot[bot] commented 2 weeks ago

Helpful Links

See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1970

Note: Links to docs will display an error until the docs builds have been completed.

No Failures

As of commit 344e99f28553eeaa7ffcbf72669e68f1cac1471b with merge base 7bfb3336446f0d874ab5d4595249839b735b7076: Looks good so far! There are no failures yet.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot commented 2 weeks ago

This pull request was exported from Phabricator. Differential Revision: D65480353

SalmanMohammadi commented 2 weeks ago

Hey @RylanC24! Thanks for opening this : )

It looks like the main change is to set the default path to sample q using

    probs = torch.nn.functional.softmax(logits, dim=-1)

    # if q is None, we use the default softmax sampling trick
    if q is None: # <---- q is now None by default
        q = torch.empty_like(probs).exponential_(1)

Is that right? If so that makes sense to me at a high level.
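A minimal, self-contained sketch of a sampling step with the `q=None` default discussed above. The function name, the `top_k` masking, and the argmax-based final step are assumptions modelled on the quoted snippet, not a copy of `torchtune/generation/_generation.py`:

```python
from typing import Optional

import torch


def sample(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    q: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hypothetical sampler: q is optional and, by default, sized to the logits."""
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        # Keep only the top_k logits; anything below the k-th largest value is masked out.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v[..., -1:]
        logits = logits.masked_fill(logits < pivot, float("-inf"))
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # if q is None, we use the default softmax sampling trick, shaped like probs
    if q is None:
        q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True)
```

With this shape, the default path never needs to know the embedding size; only callers that pass their own `q` are responsible for matching the logits' shape.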

Out of curiosity, what's your use case here? Are you adding this change to use with the generate.py recipe? FWIW, we'll eventually be deprecating this recipe (I think) in favour of the dev/generate_v2.py recipe, which is significantly neater and uses this proposed behaviour by default, since it calls sample directly without going through generate_next_token. I think this change makes sense to fix the existing generation utils, though.

cc @joecummings

RylanC24 commented 2 weeks ago

@SalmanMohammadi yes, that's right. The use-case is a subtle one but comes up anytime you want to trim the embedding and/or output layers to remove unnecessary tokens (e.g., if the output space is constrained and we don't want to keep 128k x 2048 dimensional vectors in our model). The issue comes up when you want to map this trimmed output space back to the original (so we can still use the same tokenizer). In this situation the dimension of the output logits will not match the dimension of the embedding layer, leading to an error when we try to divide the logits by q (which was previously set to the size of the embedding layer).
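To make the failure mode concrete, here is a small sketch with made-up sizes: a 16-token original vocabulary, a pruned head that keeps 10 of those tokens, and a hypothetical `kept_ids` lookup table for mapping back to the original tokenizer ids. An embedding-sized q cannot broadcast against the pruned probabilities, while a q built from the probabilities works:

```python
import torch

torch.manual_seed(0)

EMBED_VOCAB = 16  # hypothetical: rows in the original embedding table / tokenizer vocab
# hypothetical: original token ids that survived pruning, in output-head order
kept_ids = torch.tensor([0, 2, 3, 5, 7, 8, 11, 12, 14, 15])

# The pruned output head produces one logit per kept token, not one per embedding row.
logits = torch.randn(2, kept_ids.numel())
probs = torch.softmax(logits, dim=-1)

# Old behaviour (sketch): q sized to the embedding table, so shapes (2, 10) vs (2, 16).
q_embed_sized = torch.empty(2, EMBED_VOCAB).exponential_(1)
try:
    _ = probs / q_embed_sized
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# New behaviour: q shaped like probs, so the division always broadcasts.
q = torch.empty_like(probs).exponential_(1)
pruned_idx = torch.argmax(probs / q, dim=-1, keepdim=True)

# Map positions in the pruned output space back to ids the original tokenizer understands.
token_ids = kept_ids[pruned_idx]
print(mismatch_raised, token_ids.shape)  # True torch.Size([2, 1])
```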

RylanC24 commented 2 weeks ago

@SalmanMohammadi forgot to add that yes, the new generator shouldn't have this issue but this fix will allow us to patch the old one in the meantime :-)

SalmanMohammadi commented 2 weeks ago

> @SalmanMohammadi yes, that's right. The use-case is a subtle one but comes up anytime you want to trim the embedding and/or output layers to remove unnecessary tokens (e.g., if the output space is constrained and we don't want to keep 128k x 2048 dimensional vectors in our model). The issue comes up when you want to map this trimmed output space back to the original (so we can still use the same tokenizer). In this situation the dimension of the output logits will not match the dimension of the embedding layer, leading to an error when we try to divide the logits by q (which was previously set to the size of the embedding layer).

Thanks! I have a couple points:

1) This fix won't actually work when we have an rng, right? I'm not sure I see an immediate neat solution here, though; is there a way to infer the size of the output space? rng is just used for PPO, so it'd be a very rare interaction.

2) How annoying would it be to add a test for this? We have some tests in tests/torchtune/generation/test_generation.py which build some dummy models. Would it be simple enough to create another dummy model fixture which has the embedding replaced with a trimmed embedding, and ensures that we can correctly generate without any issues?
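A rough sketch of what such a test could look like, kept self-contained so it does not depend on the vocab-pruned model definitions. `TrimmedDummyModel` and the local `sample` function are hypothetical stand-ins, not the existing fixtures in tests/torchtune/generation/test_generation.py; the only property that matters is that the model emits fewer logits than its embedding has rows:

```python
import torch
from torch import nn


class TrimmedDummyModel(nn.Module):
    """Hypothetical dummy model: 16-row embedding, but only 10 output logits."""

    def __init__(self, embed_vocab: int = 16, out_vocab: int = 10, dim: int = 8):
        super().__init__()
        self.tok_embeddings = nn.Embedding(embed_vocab, dim)
        self.output = nn.Linear(dim, out_vocab, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.output(self.tok_embeddings(tokens))  # [bsz, seq_len, out_vocab]


def sample(logits: torch.Tensor) -> torch.Tensor:
    # Stand-in for the generation utility under test; q is sized to the logits.
    probs = torch.softmax(logits, dim=-1)
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True)


def test_generate_with_trimmed_output():
    torch.manual_seed(0)
    model = TrimmedDummyModel()
    tokens = torch.tensor([[1, 4, 9]])
    with torch.no_grad():
        for _ in range(5):
            logits = model(tokens)[:, -1]  # logits for the last position
            next_tok = sample(logits)
            assert next_tok.shape == (1, 1)
            assert 0 <= next_tok.item() < model.output.out_features
            tokens = torch.cat([tokens, next_tok], dim=-1)
    assert tokens.shape == (1, 8)
```

In a real test, the local `sample` would be replaced by the torchtune generation function being tested; with an embedding-sized q, that call would raise a shape-mismatch error before any assertion runs.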

codecov-commenter commented 2 weeks ago

Codecov Report

Attention: Patch coverage is 0% with 6 lines in your changes missing coverage. Please review.

Project coverage is 25.18%. Comparing base (9eced21) to head (6cae056). Report is 9 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| torchtune/generation/_generation.py | 0.00% | 6 Missing |
Additional details and impacted files

```diff
@@             Coverage Diff             @@
##             main    #1970       +/-   ##
===========================================
- Coverage   68.40%   25.18%   -43.22%
===========================================
  Files         311      311
  Lines       16973    17038       +65
===========================================
- Hits        11610     4291     -7319
- Misses       5363    12747     +7384
```

RylanC24 commented 2 weeks ago

> Thanks! I have a couple points:
>
> 1) This fix won't actually work when we have an rng, right? I'm not sure I see an immediate neat solution here, though; is there a way to infer the size of the output space? rng is just used for PPO, so it'd be a very rare interaction.

Yes, it won't work when an rng is used, but I figured these were both pretty niche use cases that are unlikely to clash. Since there's already a plan to migrate to the new generator, where this won't be an issue, I think the risk of ignoring this corner case for the time being is minimal. wdyt?
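On the rng question, one possible direction (sketched purely as an assumption; this PR does not implement it) is to draw q only after the logits exist and pass the generator through to `Tensor.exponential_`, which accepts a `generator` keyword. The helper name is hypothetical, and drawing q at a different point changes when random numbers are consumed, so seeded outputs may not match the current rng path exactly:

```python
from typing import Optional

import torch


def sample_step(logits: torch.Tensor, rng: Optional[torch.Generator] = None) -> torch.Tensor:
    """Hypothetical sampling step that sizes q from the logits, with or without an rng."""
    probs = torch.softmax(logits, dim=-1)
    # q follows the logits' shape; the optional generator keeps draws reproducible (e.g. for PPO)
    q = torch.empty_like(probs).exponential_(1, generator=rng)
    return torch.argmax(probs / q, dim=-1, keepdim=True)


logits = torch.randn(4, 10)  # e.g. a pruned head with 10 outputs
first = sample_step(logits, rng=torch.Generator().manual_seed(123))
second = sample_step(logits, rng=torch.Generator().manual_seed(123))
print(torch.equal(first, second))  # True
```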

> 2) How annoying would it be to add a test for this? We have some tests in tests/torchtune/generation/test_generation.py which build some dummy models. Would it be simple enough to create another dummy model fixture which has the embedding replaced with a trimmed embedding, and ensures that we can correctly generate without any issues?

This is doable but would be a bit annoying, since the vocab-pruned model types are not defined in the torchtune repo. The existing tests should validate that the normal generation use cases are not affected, and I've verified with our vocab-pruned model definitions that it works as expected. Again, since this is really just a stopgap fix until the new generator is released, maybe we can forgo the additional tests?

SalmanMohammadi commented 2 weeks ago

Yeah makes sense to me. I'll verify it works OK with compile in a follow up :)