turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

AssertionError: Insufficient space in device allocation #168

Closed · Double-bear closed this issue 8 months ago

Double-bear commented 9 months ago

Hello, I am trying to run exllamav2 with a Llama2-70B model on A800 80GB GPUs. Since the model I trained has a context length of 16k, during model loading I set both max_seq_len and max_input_len to 16384 and max_attention_size to max_seq_len squared. The GPU split is [75]*8. However, on a single machine with 8 cards I get the following error:

AssertionError: Insufficient space in device allocation
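
For reference, the loading setup is roughly the following (a minimal sketch assuming exllamav2's standard Python API; the model path is a placeholder):

```python
from exllamav2 import ExLlamaV2, ExLlamaV2Config

config = ExLlamaV2Config()
config.model_dir = "/path/to/llama2-70b-16k"  # placeholder path
config.prepare()

# 16k context, with input length and attention size raised to match
config.max_seq_len = 16384
config.max_input_len = 16384
config.max_attention_size = 16384 ** 2

model = ExLlamaV2(config)
model.load(gpu_split=[75] * 8)  # fails here with the AssertionError above
```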

The values from the allocation check are as follows:

current_idx:  0 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  1275069184 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 79716941824

current_idx:  0 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 79414935552

current_idx:  0 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78005633024

current_idx:  0 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 77703626752

current_idx:  0 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76294324224

current_idx:  0 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 75992317952

current_idx:  0 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74583015424

current_idx:  0 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74281009152

current_idx:  0 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72871706624

current_idx:  0 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72569700352

current_idx:  1 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  0 allocation_bytes[current_idx] 79716941824

current_idx:  1 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78307639296

current_idx:  1 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78005633024

current_idx:  1 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76596330496

current_idx:  1 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76294324224

current_idx:  1 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74885021696

current_idx:  1 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74583015424

current_idx:  1 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 73173712896

current_idx:  1 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72871706624

current_idx:  1 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 71462404096

current_idx:  2 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  1275069184 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 79716941824

current_idx:  2 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 79414935552

current_idx:  2 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78005633024

current_idx:  2 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 77703626752

current_idx:  2 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76294324224

current_idx:  2 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 75992317952

current_idx:  2 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74583015424

current_idx:  2 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74281009152

current_idx:  2 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72871706624

current_idx:  2 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72569700352

current_idx:  3 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  0 allocation_bytes[current_idx] 79716941824

current_idx:  3 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78307639296

current_idx:  3 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78005633024

current_idx:  3 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76596330496

current_idx:  3 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76294324224

current_idx:  3 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74885021696

current_idx:  3 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74583015424

current_idx:  3 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 73173712896

current_idx:  3 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72871706624

current_idx:  3 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 71462404096

current_idx:  4 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  1275069184 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 79716941824

current_idx:  4 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 79414935552

current_idx:  4 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78005633024

current_idx:  4 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 77703626752

current_idx:  4 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76294324224

current_idx:  4 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 75992317952

current_idx:  4 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74583015424

current_idx:  4 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74281009152

current_idx:  4 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72871706624

current_idx:  4 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72569700352

current_idx:  5 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  0 allocation_bytes[current_idx] 79716941824

current_idx:  5 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78307639296

current_idx:  5 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78005633024

current_idx:  5 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76596330496

current_idx:  5 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76294324224

current_idx:  5 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74885021696

current_idx:  5 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74583015424

current_idx:  5 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 73173712896

current_idx:  5 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72871706624

current_idx:  5 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 71462404096

current_idx:  6 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  1275069184 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 79716941824

current_idx:  6 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 79414935552

current_idx:  6 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78005633024

current_idx:  6 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 77703626752

current_idx:  6 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76294324224

current_idx:  6 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 75992317952

current_idx:  6 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74583015424

current_idx:  6 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74281009152

current_idx:  6 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72871706624

current_idx:  6 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72569700352

current_idx:  7 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  0 allocation_bytes[current_idx] 79716941824

current_idx:  7 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78307639296

current_idx:  7 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 78005633024

current_idx:  7 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76596330496

current_idx:  7 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 76294324224

current_idx:  7 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74885021696

current_idx:  7 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 74583015424

current_idx:  7 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 73173712896

current_idx:  7 len(allocation_bytes):  8

footprint:  1409302528 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 72871706624

current_idx:  7 len(allocation_bytes):  8

footprint:  302006272 dev_scratch:  2617246208 dev_scratch_attn:  68719476864 allocation_bytes[current_idx] 71462404096

current_idx:  8 len(allocation_bytes):  8

In this case, if I don't want to quantize the model, are there any other solutions?

turboderp commented 9 months ago

You should definitely be able to fit the FP16 version of Llama2-70B within 8x80GB. The problem in this case seems to be that it's reserving 68 GB per GPU for the attention weights matrix (2 buffers of 64 attention heads, 16384**2 elements each, 2 bytes per element).
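
To make the arithmetic concrete, a quick back-of-the-envelope check (my own calculation, not code from the library) reproduces the figure seen as dev_scratch_attn in the log above:

```python
buffers = 2        # two attention weight buffers
heads = 64         # attention heads in Llama2-70B
seq = 16384        # max_input_len
bytes_per_el = 2   # fp16

attn_scratch = buffers * heads * seq * seq * bytes_per_el
print(attn_scratch)  # 68719476736 bytes, ~68.7 GB (matches dev_scratch_attn up to alignment)
```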

Installing flash-attn should reduce that allocation somewhat, but another option is simply to reduce max_input_len and max_attention_size (e.g. back to the defaults of 2048 and 2048**2). You will still be able to run inference on 16384-token sequences; they will simply be processed in chunks, which substantially reduces the VRAM required for temporary buffers and attention weights.
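
In terms of the config, that change is just the following (same assumed API as in the sketch above):

```python
# Keep max_seq_len at 16384, but size the temporary buffers for 2048-token chunks
config.max_input_len = 2048
config.max_attention_size = 2048 ** 2

# A 16384-token prompt is then processed as eight 2048-token chunks during prefill.
# Per-GPU attention scratch drops to roughly 2 * 64 * 2048**2 * 2 bytes, about 1.07 GB,
# down from ~68.7 GB above.
```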

It's a tradeoff of course, with attention on eight 2048-token chunks taking a bit more time than a single 16384-token chunk (though the output is the same), but without flash-attn available it does save 67 GB of VRAM per GPU, so it's probably worth considering. Of course with 8x80GB available you could probably go a lot higher, but there are diminishing returns at some point, whereas the attention matrix keeps scaling quadratically.
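
To illustrate how quickly that attention buffer grows with the chunk size (same back-of-the-envelope formula as above):

```python
for chunk in (2048, 4096, 8192, 16384):
    gib = 2 * 64 * chunk ** 2 * 2 / 2 ** 30
    print(f"{chunk:>6} tokens -> {gib:5.1f} GiB of attention scratch per GPU")
# 2048 -> 1.0 GiB, 4096 -> 4.0 GiB, 8192 -> 16.0 GiB, 16384 -> 64.0 GiB
```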

Double-bear commented 9 months ago

> You should definitely be able to fit the FP16 version of Llama2-70B within 8x80GB. The problem in this case seems to be that it's reserving 68 GB per GPU for the attention weights matrix (2 buffers of 64 attention heads, 16384**2 elements each, 2 bytes per element).
>
> Installing flash-attn should reduce that allocation somewhat, but another option is simply to reduce max_input_len and max_attention_size (e.g. back to the defaults of 2048 and 2048**2). You will still be able to run inference on 16384-token sequences; they will simply be processed in chunks, which substantially reduces the VRAM required for temporary buffers and attention weights.
>
> It's a tradeoff of course, with attention on eight 2048-token chunks taking a bit more time than a single 16384-token chunk (though the output is the same), but without flash-attn available it does save 67 GB of VRAM per GPU, so it's probably worth considering. Of course with 8x80GB available you could probably go a lot higher, but there are diminishing returns at some point, whereas the attention matrix keeps scaling quadratically.

Thank you so much! I will try it again.