pytorch / torchchat

Run PyTorch LLMs locally on servers, desktop and mobile
BSD 3-Clause "New" or "Revised" License

Granite code support #1336

Open gabe-l-hart opened 4 weeks ago

gabe-l-hart commented 4 weeks ago

Dependencies

This PR is part of a sequence in support of adding Granite Code. It depends on merging the following PRs:

Issues

Closes #1262

Description

This PR adds support for Granite Code in 3B and 8B sizes. Given current limitations with exporting tokenizers, these models will only work in the Python environment with this PR.

Discussion

Usage

To test these models, I ran them both with the aliases and by pointing directly at the checkpoint/tokenizer:

# Run with alias
python torchchat.py generate granite-code \
  --prompt "Write a python function to sort numbers and strings with numeric prefixes"

# Run with direct reference to artifacts
python torchchat.py generate \
  --prompt "Write a python function to sort numbers and strings with numeric prefixes" \
  --checkpoint-path $HOME/models/ibm-granite/granite-3b-code-instruct-128k/model.pth \
  --tokenizer-path $HOME/models/ibm-granite/granite-3b-code-instruct-128k/tokenizer.json \
  --params-path torchchat/model_params/Granite-3B-Code.json

Open Questions

There are several outstanding issues, beyond the upstream tokenizers PR, that need to be solved before this PR is ready for full review:

pytorch-bot[bot] commented 4 weeks ago

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1336

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 10918a124869bbcdef1cdd8a19ac2a0b8d6605c3 with merge base 6895a18b994bf910c3d6d6c9d55c93504448ec90: 💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

gabe-l-hart commented 4 weeks ago

Also, I used the following script to convert a pre-existing HF snapshot. It's similar to the if __name__ == "__main__" block in convert_hf_checkpoint.py:

convert_existing_checkpoint.py

```py
#!/usr/bin/env python
"""
Simple script to convert an existing HF snapshot into torchchat format
"""

# Standard
import argparse
from pathlib import Path

# Local
from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint, convert_hf_checkpoint_to_tune


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("checkpoint_dir", help="Directory containing HF checkpoint")
    parser.add_argument("--name", "-n", default=None, help="Name to use for the model")
    parser.add_argument("--torchtune", "-t", action="store_true", default=False, help="Convert to torchtune format")
    args = parser.parse_args()
    if args.torchtune:
        convert_hf_checkpoint_to_tune(model_dir=Path(args.checkpoint_dir), model_name=args.name)
    else:
        convert_hf_checkpoint(model_dir=Path(args.checkpoint_dir), model_name=args.name)


if __name__ == "__main__":
    main()
```
gabe-l-hart commented 3 weeks ago

I confirmed that it was falling back to the llama2 chat formatter because it wasn't using tiktoken. I've added basic jinja2 chat template support when using the HF tokenizer.
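A minimal sketch of the formatter selection described above, assuming tiktoken maps to the Llama 3 formatter and that an HF tokenizer either does or does not ship a chat template. `TokenizerKind`, `pick_chat_formatter`, and the returned names are hypothetical illustrations, not torchchat's API:

```python
from enum import Enum


class TokenizerKind(Enum):
    # Hypothetical labels for the tokenizer backends discussed in this PR
    TIKTOKEN = "tiktoken"
    SENTENCEPIECE = "sentencepiece"
    HF = "hf"


def pick_chat_formatter(kind: TokenizerKind, has_chat_template: bool) -> str:
    """Return the (hypothetical) name of the chat formatter to use.

    tiktoken implies the Llama 3 format; an HF tokenizer with a chat
    template renders that template; everything else falls back to Llama 2.
    """
    if kind is TokenizerKind.TIKTOKEN:
        return "llama3"
    if kind is TokenizerKind.HF and has_chat_template:
        return "hf_template"
    # The fallback that Granite Code was silently hitting before the fix
    return "llama2"


print(pick_chat_formatter(TokenizerKind.HF, has_chat_template=True))   # hf_template
print(pick_chat_formatter(TokenizerKind.HF, has_chat_template=False))  # llama2
```

The intent of this shape is that the Llama 2 branch becomes a last resort rather than the default for every non-tiktoken tokenizer.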

mikekgfb commented 3 weeks ago

A pointer to this PR and the example commands from the PR description would make a good starting point for docs/new_model.md, which, together with some explanatory text, could (at least partially?) address #1038 / #1041.

# wget artifacts here
# Run with direct reference to artifacts
python torchchat.py generate \
  --prompt "Write a python function to sort numbers and strings with numeric prefixes" \
  --checkpoint-path $HOME/models/ibm-granite/granite-3b-code-instruct-128k/model.pth \
  --tokenizer-path $HOME/models/ibm-granite/granite-3b-code-instruct-128k/tokenizer.json \
  --params-path torchchat/model_params/Granite-3B-Code.json

Explain how to add to model list...

# Run with alias
python torchchat.py generate granite-code \
  --prompt "Write a python function to sort numbers and strings with numeric prefixes"

If added to .ci/scripts/run-docs new_model, it might also make a test case for the features used in Granite.

gabe-l-hart commented 2 weeks ago

@Jack-Khuu I'm a bit stumped trying to get the 8B model working. I'm trying to mentally diff the Attention implementation in torchchat vs transformers to see if I can find anything that would indicate something behaving differently with Grouped Query Attention.

I'm not really following the different way that the torchchat version is manipulating the tensors for tensor parallel inference (need to do some background reading there), but this feels like it's got to be close to the root of the issue. The only other place that I could imagine things going wrong is in the unpacking of the unified wqkv here. Any insight you can offer would be much appreciated!
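For reference when reading the wqkv unpacking, here is a minimal sketch of how a fused projection splits under grouped query attention. It assumes the fused weight is stacked as [Q; K; V] along the output dimension; `wqkv_split_sizes` and `split_wqkv` are hypothetical helpers, and plain lists of rows stand in for tensors. The GQA-specific detail is that K and V are sized by `n_kv_heads`, not `n_heads`, so a slice computed with the wrong head count goes unnoticed whenever the two are equal:

```python
from typing import List, Sequence, Tuple

Row = Sequence[float]


def wqkv_split_sizes(n_heads: int, n_kv_heads: int, head_dim: int) -> Tuple[int, int, int]:
    """Number of output rows belonging to Q, K and V in a fused wqkv weight.

    Q has one slice per query head; under GQA, K and V only have one slice
    per KV head, so they are smaller than Q.
    """
    return (n_heads * head_dim, n_kv_heads * head_dim, n_kv_heads * head_dim)


def split_wqkv(rows: List[Row], n_heads: int, n_kv_heads: int, head_dim: int):
    """Split a fused weight assumed to be stacked as [Q; K; V] row-wise."""
    q_size, k_size, v_size = wqkv_split_sizes(n_heads, n_kv_heads, head_dim)
    expected = q_size + k_size + v_size
    if len(rows) != expected:
        raise ValueError(f"expected {expected} rows, got {len(rows)}")
    wq = rows[:q_size]
    wk = rows[q_size:q_size + k_size]
    wv = rows[q_size + k_size:]
    return wq, wk, wv


# Illustrative sizes only: 4 query heads sharing 1 KV head, head_dim 2
print(wqkv_split_sizes(4, 1, 2))  # (8, 2, 2)
```

With real tensors, the same three sizes could be passed to `torch.Tensor.split(..., dim=0)`.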

Results with 3B

```
?> python torchchat.py generate granite-code-3b --prompt "Write a python hello world function"
NumExpr defaulting to 16 threads.
PyTorch version 2.6.0.dev20241002 available.
lm_eval is not installed, GPTQ may not be usable
W1108 13:18:36.747000 52813 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
Using device=mps
Loading model...
Time to load model: 3.86 seconds
-----------------------------------------------------------
Write a python hello world function

`​``python
def say_hello():
    print("hello world")
`​``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Generated 19 tokens
Time for inference 1: 1.6639 sec total
Time to first token: 0.4289 sec with parallel prefill.
Total throughput: 12.0199 tokens/sec, 0.0832 s/token
First token throughput: 2.3316 tokens/sec, 0.4289 s/token
Next token throughput: 15.3844 tokens/sec, 0.0650 s/token
Bandwidth achieved: 86.74 GB/s
*** This first iteration will include cold start effects for dynamic import, hardware caches. ***
========================================
Average tokens/sec (total): 12.02
Average tokens/sec (first token): 2.33
Average tokens/sec (next tokens): 15.38
```

**NOTE** (because I feel compelled): The above snippet uses [zero-width-spaces](https://en.wikipedia.org/wiki/Zero-width_space) to escape the triple backticks inside the code blocks, so copy-paste at your own peril!
Results with 8B

```
?> python torchchat.py generate granite-code-8b -p "Write a python hello world function"
usage: torchchat [-h] {chat,generate,browser,export,download,list,remove,where,server,eval} ...
torchchat: error: unrecognized arguments: -p Write a python hello world function
(torchchat2) ghart@Mac [torchchat GraniteCodeSupport ?]$ python torchchat.py generate granite-code-8b --prompt "Write a python hello world function"
NumExpr defaulting to 16 threads.
PyTorch version 2.6.0.dev20241002 available.
lm_eval is not installed, GPTQ may not be usable
W1108 13:13:21.744000 51816 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
Using device=mps
Loading model...
Time to load model: 11.67 seconds
-----------------------------------------------------------
Write a python hello world function function function function function function function function function function function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Generated 11 tokens
Time for inference 1: 7.5729 sec total
Time to first token: 4.8976 sec with parallel prefill.
Total throughput: 1.5846 tokens/sec, 0.6311 s/token
First token throughput: 0.2042 tokens/sec, 4.8976 s/token
Next token throughput: 4.1117 tokens/sec, 0.2432 s/token
Bandwidth achieved: 26.17 GB/s
*** This first iteration will include cold start effects for dynamic import, hardware caches. ***
========================================
Average tokens/sec (total): 1.58
Average tokens/sec (first token): 0.20
Average tokens/sec (next tokens): 4.11
```
Jack-Khuu commented 2 weeks ago

Thanks for the details @gabe-l-hart

I'll try to give it a gander this weekend. It's weird that 3B works but 8B doesn't. I assume they use the same template, so that at least clears that part.

byjlw commented 1 week ago

Looks like this has been open for several weeks now. Yeah, the template thing is super hacky right now and I knew it was going to hang up our ability to add new models. In general we need to make a smoother path for adding new models with different architectures, templates and storage locations.

It's been on @varunfb and @Jack-Khuu 's plate for a while but they've been swamped with other work. Fortunately it's planning season and the design for this is on the list. @gabe-l-hart would love to get your feedback on how best to support folks like yourself.

gabe-l-hart commented 1 week ago

Thanks @byjlw! I definitely understand juggling priorities. The path to adding new models in the model_params and model_config is relatively straightforward (could use a doc, but TBH I never read docs anyway, so easy-to-read code is always best). The real challenge has come up around places where the models differ from the llama series models. In particular, Granite Code uses the llama architecture, but uses several optional bits that the Meta Llama models don't (e.g. HF tokenizers, tied embeddings). Getting these pieces to work has been a decently steep learning curve (fun though!). I think the thing that would be most helpful would be some kind of compatibility matrix doc that shows architectures that have support, sub-features within architectures, and which "layers" they're supported in (e.g. python, c++, executorch). This would help a lot in figuring out where to dig in to add new model support.

For the specific issues for Granite Code, the place I'm a bit stuck is figuring out why the 8B model is flopping while the 3B model works just fine. My gut says it has something to do with the alternate attention mechanism in TC, but I'm not versed deeply enough in attention to spot it quickly. The only architectural difference between 3B and 8B is the use of grouped query attention, so it's either something there or some incompatibility between the attention implementations in transformers and TC that's only exercised by the specific weights of the 8B. Any help and/or expert bug spotting would be much appreciated!
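The point that GQA is the only difference can be made concrete with a small sketch of how query heads map onto shared KV heads. With `n_kv_heads == n_heads` (plain multi-head attention) the mapping is the identity, so a bug in the grouping logic can stay invisible until a GQA model is loaded. The helpers below are hypothetical, and strings stand in for per-head tensors:

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def group_size(n_heads: int, n_kv_heads: int) -> int:
    """How many query heads share each KV head."""
    if n_kv_heads <= 0 or n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be a positive multiple of n_kv_heads")
    return n_heads // n_kv_heads


def kv_head_for_query_head(q_head: int, n_heads: int, n_kv_heads: int) -> int:
    """Index of the KV head that query head `q_head` attends with."""
    return q_head // group_size(n_heads, n_kv_heads)


def expand_kv(kv_heads: Sequence[T], n_heads: int) -> List[T]:
    """Repeat each KV head so there is one per query head (interleaved order)."""
    rep = group_size(n_heads, len(kv_heads))
    return [head for head in kv_heads for _ in range(rep)]


print(expand_kv(["k0", "k1"], 4))  # ['k0', 'k0', 'k1', 'k1']
```

The expansion is interleaved (`k0, k0, k1, k1`), matching the `q_head // group_size` mapping; tiling instead (`k0, k1, k0, k1`) would still produce the right shapes but pair query heads with the wrong KV heads. Interleaving is what `torch.repeat_interleave` along the head dimension produces.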

gabe-l-hart commented 1 week ago

I just rebased on main and it now looks like even the 3b model is producing only a single token as output in chat mode. Will try to get to the bottom of it.

mikekgfb commented 1 week ago

> I just rebased on main and it now looks like even the 3b model is producing only a single token as output in chat mode. Will try to get to the bottom of it.

Have you tried bisecting the 3B failure? Even if the change was legit and necessary, the type of change that breaks the 3B model might give insight into how to "fix" both the 3B and 8B models?
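Bisecting is a binary search over history for the first commit where a check fails. A minimal sketch, assuming a linear oldest-to-newest commit list and a hypothetical `is_bad` predicate (for example, a script that runs the 3B model in chat mode and checks whether it emits more than one token):

```python
from typing import Callable, Sequence


def first_bad_commit(commits: Sequence[str], is_bad: Callable[[str], bool]) -> str:
    """Binary-search an oldest-to-newest commit list for the first bad commit.

    Assumes commits[0] is known good, commits[-1] is known bad, and that once
    a commit is bad every later commit is bad too (the same assumption that
    `git bisect` makes).
    """
    lo, hi = 0, len(commits) - 1  # invariant: commits[lo] good, commits[hi] bad
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if is_bad(commits[mid]):
            hi = mid
        else:
            lo = mid
    return commits[hi]
```

Against real history, `git bisect start`, `git bisect good <rev>` / `git bisect bad <rev>`, and `git bisect run <script>` perform the same search, treating a zero exit code from the script as good and most non-zero codes as bad.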

mikekgfb commented 1 week ago

> The real challenge has come up around places where the models differ from the llama series models. In particular, Granite Code uses the llama architecture, but uses several optional bits that the Meta Llama models don't (e.g. HF tokenizers, tied embeddings). Getting these pieces to work has been a decently steep learning curve (fun though!).

I'm a bit surprised by this because ChatGPT had this to say (understanding that I'm quoting ChatGPT about an IBM model to an IBMer, so skating on seriously thin ice!!!):

> what tokenization scheme does the ibm granite model use

> The IBM Granite models, including its base and instruction-tuned variants, utilize the Llama2 tokenizer for tokenization. This choice aligns with the models’ architectural similarity to Meta's Llama2 series, such as the Granite-7b model, which follows the Llama2-7B architecture and employs similar tokenization strategies. These tokenizers are designed to handle diverse data sources, including programming languages and natural language, ensuring compatibility and efficiency in tasks like code synthesis and language understanding

So in theory, SentencePiece should do the trick? Is it the pre- and post-processing with regexps? (I think I saw some discussion about regexps in one of your PRs or issues?)

In any event, it's cool that we have HF tokenizers because they are a proper superset of SentencePiece+TikToken. (I think @lessw2020 and @kwen2501 had also added some HF tokenizer support for distributed if I remember correctly?)

gabe-l-hart commented 1 week ago

> Have you tried bisecting the 3B fail? Even if the change was legit and necessary, the type of change that would break the 3B model might give insight in how to "fix" both the 3B and 8B models?

That's on my todo list for my next chunk of uninterrupted dev time! I'm hoping that will be today.

> I'm a bit surprised by this because ChatGPT had this to say (understanding that I'm quoting ChatGPT about an IBM model to an IBMer, so skating on seriously thin ice!!!):

Heh, as I'm sure you know, IBM is a big place, so I'm definitely doing a lot of learning myself in this space. My info from the models team is that we've been using the StarCoder tokenizer up until now (including Granite Code and the Granite 3.0 series). When first trying to understand how best to support that in torchchat, I was missing a lot of knowledge about SentencePiece, so I was working off of the tokenizer_config.json in HF. I suspect it would be possible to reverse-convert from tokenizers back to SentencePiece for this config, but I haven't done that work yet since I was already halfway down the rabbit hole of tokenizers support. We can certainly look into that as an alternative approach if the preference is to avoid the complexity of the C++ tokenizer buildout.

gabe-l-hart commented 1 week ago

@Jack-Khuu @mikekgfb @byjlw I figured out where the issues were coming from. It was two things:

  1. The logic was always inserting a BOS token at the beginning of the sequence, which the 3B model was sometimes OK with but the 8B never was
    • To solve this, I added tokenizer_prepend_bos as a parameter in TransformerArgs and ModelArgs. It seemed a little clunky to plumb it through multiple blobs, but this got things working for both models with raw generation
  2. The chat template logic was not robust beyond llama2 and llama3 templating. Solving this resulted in a fair bit of refactoring:
    • Refactor the Llama2ChatFormatter and Llama3ChatFormatter to encapsulate all logic in a single abstract method encode_dialog_prompt
    • Remove all formatter-specific logic from the primary generation loop in def chat
    • Add the HFTokenizerChatFormatter
    • Plumb the ability to use the chat template with jinja2 through HFTokenizer
      • NOTE: jinja2 was already a transitive dependency, so I just formalized it
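The two fixes above can be sketched together, under loud assumptions: `ModelArgsSketch`, `ByteTokenizer`, and the formatter classes below are simplified, hypothetical stand-ins for TransformerArgs/ModelArgs, the real tokenizers, and Llama2ChatFormatter/HFTokenizerChatFormatter. The Llama 2 style wrapping is abbreviated, and a `str.format` pattern stands in for the jinja2 chat template. What the sketch shows is the shape of the refactor: one abstract `encode_dialog_prompt` per formatter, with BOS insertion controlled by a `tokenizer_prepend_bos` flag instead of being hard-coded.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List

Message = Dict[str, str]  # {"role": ..., "content": ...}


@dataclass
class ModelArgsSketch:
    # Hypothetical stand-in for TransformerArgs/ModelArgs; only the new flag.
    # Defaulting to True (the old always-prepend behavior) is an assumption.
    tokenizer_prepend_bos: bool = True


class ByteTokenizer:
    """Toy tokenizer (hypothetical): one token per UTF-8 byte, BOS id 256."""

    bos_id = 256

    def encode(self, text: str, bos: bool) -> List[int]:
        return ([self.bos_id] if bos else []) + list(text.encode("utf-8"))


class ChatFormatter(ABC):
    def __init__(self, tokenizer: ByteTokenizer, prepend_bos: bool = True) -> None:
        self.tokenizer = tokenizer
        self.prepend_bos = prepend_bos

    @abstractmethod
    def encode_dialog_prompt(self, dialog: List[Message]) -> List[int]:
        """Return the token ids for the full prompt: the only formatter hook."""

    def _encode(self, text: str) -> List[int]:
        # BOS is only added when the model config asks for it
        return self.tokenizer.encode(text, bos=self.prepend_bos)


class Llama2StyleFormatter(ChatFormatter):
    # Abbreviated [INST] wrapping; system prompts and EOS handling omitted.
    def encode_dialog_prompt(self, dialog: List[Message]) -> List[int]:
        parts = []
        for message in dialog:
            if message["role"] == "user":
                parts.append(f"[INST] {message['content'].strip()} [/INST]")
            else:
                parts.append(f" {message['content'].strip()} ")
        return self._encode("".join(parts))


class TemplateFormatter(ChatFormatter):
    # Stand-in for a jinja2 chat template: one str.format pattern per message.
    def __init__(self, tokenizer: ByteTokenizer, template: str, prepend_bos: bool = True) -> None:
        super().__init__(tokenizer, prepend_bos)
        self.template = template

    def encode_dialog_prompt(self, dialog: List[Message]) -> List[int]:
        text = "".join(self.template.format(**message) for message in dialog)
        return self._encode(text)


args = ModelArgsSketch(tokenizer_prepend_bos=False)
formatter = TemplateFormatter(ByteTokenizer(), "<|{role}|>{content}", prepend_bos=args.tokenizer_prepend_bos)
prompt_ids = formatter.encode_dialog_prompt([{"role": "user", "content": "hi"}])
```

Because the generation loop would then only call `encode_dialog_prompt`, supporting a new template style means adding a subclass rather than another branch in `chat`.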

To get to the bottom of all of this, I also tweaked the logging a bit. There was already a hard-coded logging config call in cli, so I just added the ability to parse the LOG_LEVEL env var to set it. I also added a fair number of additional log lines and uncommented some that were there but commented out.
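A minimal sketch of reading the log level from `LOG_LEVEL`, assuming an `INFO` default and a fallback to that default for unrecognized names (both assumptions; `configure_logging` is a hypothetical name, not the actual cli code):

```python
import logging
import os


def configure_logging(default: str = "INFO") -> int:
    """Set the root log level from the LOG_LEVEL env var (e.g. LOG_LEVEL=debug).

    Returns the numeric level that was requested.
    """
    name = os.environ.get("LOG_LEVEL", default).upper()
    level = logging.getLevelName(name)
    if not isinstance(level, int):
        # getLevelName returns the string "Level <name>" for unknown names
        level = logging.getLevelName(default.upper())
    # Note: basicConfig is a no-op if the root logger already has handlers
    logging.basicConfig(level=level)
    return level
```

With something like this in place, prefixing a command with `LOG_LEVEL=debug` would turn on the debug lines without editing code.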

NOTE: Many of the existing log lines were using f-strings, which cause the string to be interpolated regardless of whether the logger/level is enabled. I switched all of these to use lazy interpolation with %-style formatting so that it's safe to have them uncommented without a performance hit.
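The difference is easy to demonstrate with the standard library alone. In this small sketch (the `Expensive` class is purely illustrative), an f-string converts its argument before `debug()` is even called, while %-style arguments are only formatted if the record is actually handled:

```python
import logging


class Expensive:
    """Counts how many times it gets converted to a string."""

    def __init__(self) -> None:
        self.calls = 0

    def __str__(self) -> str:
        self.calls += 1
        return "expensive value"


logger = logging.getLogger("lazy_demo")
logger.setLevel(logging.WARNING)  # DEBUG messages are disabled for this logger

obj = Expensive()
logger.debug(f"eager: {obj}")  # f-string: str(obj) runs before debug() is called
logger.debug("lazy: %s", obj)  # %-style: skipped entirely because DEBUG is off
print(obj.calls)  # 1
```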

Finally, I was getting lost trying to make sure I didn't break anything in the chat templating, so I bit the bullet and added some basic unit tests. They only cover the chat formatting, but they're a place to start. I did not go any further with unit testing, including not adding pytest as a dependency or adding any CI steps to invoke the tests. If you're interested, I'd be happy to push on unit testing, but I didn't want to lump that conversation into this PR.
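Since pytest was deliberately not added as a dependency, the standard library's `unittest` is enough for tests of this kind. A minimal sketch, with a hypothetical, simplified `format_llama2_style` function standing in for a real formatter:

```python
import unittest


def format_llama2_style(user_message: str) -> str:
    """Hypothetical single-turn formatter used as the unit under test."""
    return f"[INST] {user_message.strip()} [/INST]"


class TestChatFormatting(unittest.TestCase):
    def test_wraps_user_message(self):
        self.assertEqual(format_llama2_style("hi"), "[INST] hi [/INST]")

    def test_strips_surrounding_whitespace(self):
        self.assertEqual(format_llama2_style("  hi \n"), "[INST] hi [/INST]")
```

Saved as a `test_*.py` file, this could be run with `python -m unittest discover`, so no CI changes or new dependencies are needed to start.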

byjlw commented 6 days ago

Thanks @gabe-l-hart! Yeah, a lot of this feedback resonates with me, and resolving it has already made it onto our H1 roadmap: making it easy to have model-specific templates, adding test infra and guidelines around tests, and abstracting the code to be more modular so that there is a specific core module with well-defined APIs that the CLI and API can use. We will also figure out the release strategy and publish two or three specific pip packages.

We'll be able to share the details soon and will post them as RFCs on GH so everyone can comment and contribute.