xai-org / grok-1

Grok open release
Apache License 2.0

Convert to pytorch model to use transformers from huggingface #202

Open wenmengzhou opened 6 months ago

wenmengzhou commented 6 months ago

Can someone convert this JAX model to a PyTorch model implemented in transformers?

wweevv-johndpope commented 6 months ago

I asked Claude 3 (it's good, but your mileage may vary): how are you going to load this thing? It takes many 80 GB GPUs. https://gist.github.com/johndpope/0aa7b2709bf04c44626c019feb798cfe

LagPixelLOL commented 6 months ago

Edit: There's a better implementation: https://huggingface.co/keyfan/grok-1-hf https://github.com/LagPixelLOL/grok-1-pytorch

davidearlyoung commented 6 months ago

Found this: https://huggingface.co/hpcai-tech/grok-1

Sequential-circuits commented 6 months ago

"For now, a 8x80G multi-GPU machine is required)."

Jeffwan commented 6 months ago

Is there a way to support distributed training? We do not have that many 80G cards.

davidearlyoung commented 6 months ago

Is there a way to support distributed training? We do not have that many 80G cards.

It might be possible, for example, with Hugging Face's built-in DDP/Accelerate tooling: https://huggingface.co/blog/pytorch-ddp-accelerate-transformers. The original model was likely run distributed at xAI.

There may be other platforms and/or libraries that could be used for distributed computing for neural networks outside of the Hugging Face example I mentioned. It's a light interest of mine, though I'm not highly experienced with running LLMs in distributed setups.
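A minimal sketch of the Accelerate pattern from that blog post, assuming the keyfan/grok-1-hf conversion linked earlier in this thread and a placeholder one-example dataset; nothing here is grok-specific, and for a model of this size the same `prepare` call would realistically be backed by FSDP or DeepSpeed rather than plain DDP:

```python
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

accelerator = Accelerator()  # run via `accelerate launch train.py` to get one process per GPU

model_id = "keyfan/grok-1-hf"  # assumed checkpoint; any causal-LM conversion would work the same way
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, trust_remote_code=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Tiny placeholder dataset; real training data would go here.
batch = tokenizer(["Hello Grok"], return_tensors="pt")
batch["labels"] = batch["input_ids"].clone()
train_dataloader = DataLoader([{k: v[0] for k, v in batch.items()}], batch_size=1)

# Accelerate wraps the model for distributed training and shards the dataloader across ranks.
# Note: plain DDP keeps a full copy of the model on every GPU, so grok-1 would realistically
# need the FSDP or DeepSpeed backends that Accelerate can also be configured to use.
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

model.train()
for step_batch in train_dataloader:
    outputs = model(**step_batch)
    accelerator.backward(outputs.loss)
    optimizer.step()
    optimizer.zero_grad()
```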

It seems there are now quantizations of Grok-1 that take its size down to roughly 120 GB of file and memory usage, instead of 630+ GB. (That's not counting the additional memory used during forward-pass inference; it's a rough calculation of model size on disk and in memory at load time, just before inference or training, but it's useful for getting started.) This moves things into the realm of high-end consumer hardware, which could allow more accessible inference with likely decent enough performance and accuracy out of the model. We've seen this done before with very large models (Falcon from TII, for example).
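For illustration, here is one way to load a PyTorch conversion with on-the-fly 4-bit quantization via transformers and bitsandbytes, which cuts the footprint in roughly this direction. The exact checkpoint name is an assumption (it's one of the conversions linked above), and the final size depends on the quantization scheme:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit NF4
    bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls still run in bf16
)

model = AutoModelForCausalLM.from_pretrained(
    "hpcai-tech/grok-1",        # assumed: PyTorch conversion linked earlier in this thread
    quantization_config=bnb_config,
    device_map="auto",          # let accelerate spread the quantized layers across available GPUs
    trust_remote_code=True,     # the converted checkpoints ship custom modeling code
)
```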

As for training Grok, the training process will take additional memory. How much, I wouldn't know; it's something someone with more experience training LLMs could likely answer.

Side note: it's possible to train or finetune an already pre-trained model in quantized form.
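A sketch of that QLoRA-style approach: freeze a 4-bit base model and train small LoRA adapters with peft. The checkpoint name and the target module names are assumptions and would need to match the actual grok-1 implementation:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
    "hpcai-tech/grok-1",  # assumed checkpoint; same 4-bit load as in the previous sketch
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
    device_map="auto",
    trust_remote_code=True,
)

model = prepare_model_for_kbit_training(model)  # cast norms and enable input grads for k-bit training

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # guessed attention projection names
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the adapter weights are trainable
```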

Overall, with a quality quantization, the model could likely be distributed across multiple compute nodes in quantized form for inference and/or training, at a resource requirement level that may be more attainable for you.

Exploration of what the model can run on, and how, is still ongoing. As before, someone will likely figure out how to run it decently on lower-end hardware than what it started with, which could then stabilize into something the rest of us can use for real work rather than prototype experiments.

Jeffwan commented 6 months ago

@davidearlyoung Thanks for all the details. Do you know whether distributed inference works or not? We have some A100-40G cards and would prefer not to sacrifice quality by using quantization. We are thinking about whether it's possible to put it on two machines with both TP and PP enabled.

davidearlyoung commented 6 months ago

@Jeffwan

Do you know whether distributed inference works or not?

Moving the conversation on. I just want to make sure that I understand you correctly:

We have some A100-40G cards ...

... and would prefer not to sacrifice quality by using quantization.

  • I think what you are trying to say is that you do not intend to use quantized forms of grok-1 on your A100-40G cards. Let me know if I understood that correctly.

We are thinking about whether it's possible to put it on two machines with both TP and PP enabled.

  • Are you trying to say that you and your team think that it is possible to load grok on two machines using Tensor Parallelism and Pipeline Parallelism?
  • I'm also assuming that you will have one or more A100-40Gs on each machine. Let me know if I understood what you were saying there as well.

I have taken a few brief looks into distributed computing for LLMs over the years. From what I understand (and I could be wrong to some degree), if you want to run the full model on GPU for inference with PyTorch, in theory you will need enough VRAM across your distributed compute system to hold the model's on-disk size, plus additional VRAM for runtime overhead. How much extra is needed can vary based on a lot of factors. (I think roughly 20% on top for basic inference might be enough.)

Here is some quick napkin math:

Old model memory calculator: https://huggingface.co/spaces/hf-accelerate/model-memory-usage (This is a bit aged. Still might be useful.)
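A tiny helper expressing that estimate (model size on disk plus roughly 20% overhead, divided by per-GPU memory, rounded up); the 318 GB figure is the BF16 weight size discussed further down in this thread:

```python
import math

def estimate_min_gpus(model_size_gb: float, gpu_mem_gb: float, overhead: float = 0.2) -> int:
    """Model size on disk plus ~20% runtime overhead, divided by per-GPU memory, rounded up."""
    return math.ceil(model_size_gb * (1 + overhead) / gpu_mem_gb)

print(estimate_min_gpus(318, 40))  # BF16 grok-1 (~318 GB) on A100-40G -> 10 cards
print(estimate_min_gpus(318, 80))  # on A100-80G -> 5 cards
```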

There is also some multi-GPU training and inference documentation written from the perspective of the transformers library.

A lot of what I've learned about LLM basics has been through Hugging Face, which is why most of the links I've shared point there.
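For reference, a sketch of the transformers/accelerate "big model" multi-GPU loading pattern: layers are placed GPU by GPU (naive pipeline placement, not tensor parallelism), with per-device memory caps so there is headroom for activations and the KV cache. The checkpoint name and memory budgets are assumptions:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "keyfan/grok-1-hf"  # assumed checkpoint from earlier in this thread

# Cap usable memory per GPU so accelerate leaves headroom on each A100-40G.
max_memory = {i: "38GiB" for i in range(torch.cuda.device_count())}
max_memory["cpu"] = "200GiB"   # optional CPU offload budget for layers that do not fit on GPU

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",         # layer-by-layer placement across the visible GPUs
    max_memory=max_memory,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# The first layers usually land on cuda:0, so inputs go there.
inputs = tokenizer("The weather in San Francisco is", return_tensors="pt").to("cuda:0")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0], skip_special_tokens=True))
```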

Jeffwan commented 6 months ago

@davidearlyoung I really appreciate your informative analysis! Thanks a lot!

I personally do not know if distributed inference works for grok-1 in pytorch.

Yes! That's my question to the community as well. I see lots of people don't have that many A100-80Gs to run this model, which is why I am curious whether multi-node inference works. (It definitely needs TP or PP, or TP+PP.)

I think what you are trying to say is that you do not intend to use quantized forms of grok-1 on your A100-40G cards. Let me know if I understood that correctly.

yes

xAI's JAX Grok in BF16 at roughly 318 GB (I can't speak for it directly since I'm not a JAX user, but assuming it's similar to PyTorch in memory requirements):

A100-40G: (318 + (318 × 0.2)) / 40 = 381.6 / 40 = 9.54 (round up to 10)
A100-80G: (318 + (318 × 0.2)) / 80 = 381.6 / 80 = 4.77 (round up to 5)

This is exactly the situation we are facing. If we do not use quantized forms of grok-1, say we want the bf16 version, then there's no way to fit it into a single machine with 8 × A100-40G. Distributed inference is essentially a required technique in our environment.

I did some research and noticed that frameworks like TensorRT-LLM and vLLM have some support for TP and PP, but they have some performance limitations. However, I have not tried them with grok-1 and just want feedback from the community on the best practice or recommended way to run distributed inference. (The question may not even be valid if those frameworks don't support this model.)
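For what it's worth, this is roughly what the vLLM route would look like if a given version supports this conversion; the checkpoint name is an assumption, and `pipeline_parallel_size` only exists in newer vLLM releases and needs a Ray cluster for multi-node:

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="keyfan/grok-1-hf",   # assumed checkpoint; grok-1 support depends on the vLLM version
    tensor_parallel_size=8,     # shard every layer across the 8 GPUs within one node
    pipeline_parallel_size=2,   # split layers across 2 nodes (newer vLLM only, Ray cluster required)
    dtype="bfloat16",
    trust_remote_code=True,
)

outputs = llm.generate(["Hello, Grok."], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```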

ImtiazSajwani commented 4 months ago

Has anyone seen a system OOM? I have 1 TB of system memory and 8 GPUs, and while using the PyTorch model the kernel kills the process with an OOM message. The JAX one seems to work fine with 1 TB of memory; it is using BF16.
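A guess at what might help with that host-RAM OOM: force bf16 and incremental loading so the weights stream onto the GPUs instead of being materialized (possibly as an fp32 copy) in system memory first. The checkpoint name is an assumption:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "keyfan/grok-1-hf",           # assumed checkpoint
    torch_dtype=torch.bfloat16,   # avoid an fp32 copy in host RAM (~1.2 TB for ~314B parameters)
    low_cpu_mem_usage=True,       # load checkpoint shards incrementally instead of all at once
    device_map="auto",            # place weights directly onto the GPUs as they are loaded
    trust_remote_code=True,
)
```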