LargeWorldModel / LWM

Large World Model With 1M Context
https://largeworldmodel.github.io/
Apache License 2.0

Memory requirements #7

Open · loretoparisi opened this issue 8 months ago

loretoparisi commented 8 months ago

It would be worth providing the measured memory requirements for inference with the Text Models at 32K, 128K, 256K, 512K, and 1M token context windows, in both PyTorch and JAX.

wilson1yan commented 8 months ago

If using vLLM for inference (PyTorch model, FP16), I believe we used:

* 1 80GB A100 for 32K
* 2 80GB A100s for 128K
* 4 80GB A100s for 256K
* 8 80GB A100s for 512K

For each of the above, we served 1 model with tensor parallelism over the given number of devices. With 8 80GB A100s, I think the limit was around 650K - 700K tokens. vLLM prints the maximum number of tokens it can support (via the number of cache blocks it allocates), so it should be easy to tell what fits if you're using GPUs with different amounts of memory.
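For a rough sanity check on those numbers: the FP16 KV cache for a LLaMA-2-7B-style model works out to about 0.5 MiB per token, which is the dominant cost at these context lengths. The sketch below is back-of-the-envelope only, and the geometry constants (32 layers, 32 KV heads, head dim 128) are assumptions about the 7B text model, not figures from this thread.

```python
# Back-of-the-envelope FP16 KV-cache sizing, assuming the 7B text model keeps the
# standard LLaMA-2-7B geometry (32 layers, 32 KV heads, head_dim 128, no GQA).
# This counts only weights + KV cache and ignores activations, fragmentation, and
# the memory vLLM reserves, so real limits (like the ~650K-700K figure above) are lower.
N_LAYERS, N_KV_HEADS, HEAD_DIM, BYTES = 32, 32, 128, 2  # FP16 = 2 bytes per element

kv_per_token = 2 * N_LAYERS * N_KV_HEADS * HEAD_DIM * BYTES   # K and V: ~0.5 MiB/token
weights_gib = 7e9 * BYTES / 2**30                             # ~13 GiB of FP16 weights

for n_gpus, ctx in [(1, 32_768), (2, 131_072), (4, 262_144), (8, 524_288)]:
    kv_gib = ctx * kv_per_token / 2**30
    print(f"{n_gpus} x A100-80GB @ {ctx:>7,} tokens: "
          f"KV ~{kv_gib:4.0f} GiB + weights ~{weights_gib:.0f} GiB "
          f"(budget {n_gpus * 80} GB)")
```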

For JAX, I'm not too sure what the intermediate requirements were, but we needed a v4-256 to run inference on 1M tokens (full FP32 inference). I think more optimizations could be made (e.g. half-precision, quantization, etc.) to reduce the requirements. Even at full precision, the requirements seemed higher than I expected, and there may be some JAX / XLA optimizations to be made (e.g. keeping it from padding certain dimensions, which we originally had a lot of trouble with).
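As a generic illustration of the half-precision idea (not code from the LWM repo), a JAX parameter pytree can be cast down to bfloat16 before inference, halving parameter memory relative to FP32; the checkpoint loader below is a hypothetical placeholder.

```python
# Generic JAX sketch of the half-precision idea: cast every floating-point leaf of a
# parameter pytree to bfloat16 before inference. `load_lwm_params` is hypothetical;
# use the repo's own checkpoint-loading utilities in practice.
import jax
import jax.numpy as jnp

def to_bf16(params):
    """Cast floating-point leaves of a parameter pytree to bfloat16."""
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.bfloat16) if jnp.issubdtype(x.dtype, jnp.floating) else x,
        params,
    )

# params = load_lwm_params(...)   # hypothetical loader
# params = to_bf16(params)        # roughly halves parameter memory vs FP32
```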

blazorin commented 8 months ago

Any recommendations for running the model on smaller GPUs (e.g. T4)? It runs out of memory (JAX).

Playerrrrr commented 7 months ago

@wilson1yan Can you share the shell/bash script for setting up the inference server via vLLM for the PyTorch model in FP16?

> If using vLLM for inference (PyTorch model, FP16), I believe we used:
>
> * 1 80GB A100 for 32K
> * 2 80GB A100s for 128K
> * 4 80GB A100s for 256K
> * 8 80GB A100s for 512K
>
> For each of the above, we served 1 model with tensor parallelism over the given number of devices. With 8 80GB A100s, I think the limit was around 650K - 700K tokens. vLLM prints the maximum number of tokens it can support (via the number of cache blocks it allocates), so it should be easy to tell what fits if you're using GPUs with different amounts of memory.
>
> For JAX, I'm not too sure what the intermediate requirements were, but we needed a v4-256 to run inference on 1M tokens (full FP32 inference). I think more optimizations could be made (e.g. half-precision, quantization, etc.) to reduce the requirements. Even at full precision, the requirements seemed higher than I expected, and there may be some JAX / XLA optimizations to be made (e.g. keeping it from padding certain dimensions, which we originally had a lot of trouble with).
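The exact serving script isn't posted in this thread, but a minimal vLLM setup along the lines described above might look like the sketch below. The Hugging Face model ID and the 512K max_model_len are assumptions; adjust them to the checkpoint you serve and to the KV-cache capacity vLLM reports at startup.

```python
# Minimal vLLM sketch matching the setup described above (PyTorch weights, FP16,
# tensor parallelism across 8 A100s). Model ID and max_model_len are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="LargeWorldModel/LWM-Text-Chat-1M",  # assumed HF model ID
    dtype="float16",
    tensor_parallel_size=8,        # one model sharded over 8 GPUs
    max_model_len=524_288,         # 512K context; vLLM logs how many cache blocks fit
    gpu_memory_utilization=0.90,
)

outputs = llm.generate(
    ["Summarize the following document: ..."],
    SamplingParams(temperature=0.0, max_tokens=256),
)
print(outputs[0].outputs[0].text)
```

For an OpenAI-compatible HTTP server rather than offline generation, vLLM's `python -m vllm.entrypoints.openai.api_server` entrypoint takes roughly the same settings as flags (`--model`, `--dtype`, `--tensor-parallel-size`, `--max-model-len`).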

xloem commented 7 months ago

I'm thinking an attention kernel optimization like top-k would be appropriate here. Could a user calculate their own position_ids and pass only a subset of the tokens, perhaps making multiple passes and dropping tokens that don't impact the results?
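As a rough sketch of that idea (untested here, and not something the LWM repo provides), a Hugging Face-style causal LM forward pass accepts explicit position_ids, so a pruned subset of tokens can keep its original offsets; the model ID below is an assumption.

```python
# Illustration of the subset-of-tokens idea: feed only some tokens but keep their
# original position ids, so the model sees them at their true offsets in the long
# context. Whether this kind of pruning preserves output quality is untested here.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "LargeWorldModel/LWM-Text-Chat-1M"  # assumed Hugging Face model ID
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

ids = tok("a very long document ...", return_tensors="pt").input_ids[0]
keep = torch.arange(0, ids.numel(), 4)  # naive example: keep every 4th token

with torch.no_grad():
    out = model(
        input_ids=ids[keep].unsqueeze(0).to("cuda"),
        position_ids=keep.unsqueeze(0).to("cuda"),  # original offsets, not 0..n-1
    )
```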

MoonRide303 commented 6 months ago

Aren't those requirements a bit high for 7B with a 32K context? Mistral 7B v0.2 (32K context) works absolutely fine on consumer-grade GPUs (especially when using quantized versions, like high-quality Q6_K GGUFs).
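Part of the gap is likely the KV cache rather than the weights: Mistral 7B uses grouped-query attention with far fewer KV heads, while a LLaMA-2-7B-style model caches keys and values for every attention head. The comparison below assumes the standard published geometries for both models, which are not confirmed anywhere in this thread.

```python
# Rough FP16 KV-cache cost per 32K context: LLaMA-2-7B-style full multi-head attention
# (assumed 32 KV heads) vs Mistral 7B grouped-query attention (assumed 8 KV heads).
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem  # K and V

llama2_7b = kv_bytes_per_token(32, 32, 128)   # ~512 KiB per token
mistral_7b = kv_bytes_per_token(32, 8, 128)   # ~128 KiB per token

ctx = 32_768
print(f"LLaMA-2-7B-style @ 32K: {llama2_7b * ctx / 2**30:.0f} GiB of KV cache")
print(f"Mistral 7B (GQA) @ 32K: {mistral_7b * ctx / 2**30:.0f} GiB of KV cache")
```

Quantized weights (e.g. Q6_K) shrink the other big term, which helps explain why a 32K context fits on a consumer GPU for Mistral but was quoted at one 80GB A100 here.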