predibase / lorax

Multi-LoRA inference server that scales to 1000s of fine-tuned LLMs
https://loraexchange.ai
Apache License 2.0

[Question] SGMV support for < rank 8 #241

Closed: markovalexander closed this issue 6 months ago

markovalexander commented 6 months ago

Hello.

This check only passes when the adapter's rank is at least 8 × num_shards (as can be seen here).

Does this mean that if I trained an adapter with r=8 for Llama-2-70B, I will not be able to use SGMV kernels for inference (since Llama-2-70B requires at least 2, and preferably 4, shards), and that this will lead to a significant latency increase? I do observe such an increase in load testing; I'm attaching the GPU kernel-type profiling result for tensor-parallel adapter inference.

[Screenshot, 2024-02-12 18:32:48: GPU kernel-type profiling for tensor-parallel adapter inference]

Is it possible to decrease MIN_SGMV_RANK, or should it stay fixed at 8? Does this mean that I must train adapters with rank at least 32 to use SGMV with 4 shards?

Thank you for your response!
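The condition described above can be sketched as a small predicate. This is a minimal illustration, not LoRAX's code: it assumes, as the logged expression `rank_data.rank // pg.size()` later in this thread suggests, that each shard handles `rank // num_shards` of the rank dimension and that this per-shard rank is compared against `MIN_SGMV_RANK = 8`. The helpers `can_use_sgmv` and `min_rank_for_sgmv` are hypothetical.

```python
# Hypothetical sketch of the SGMV eligibility check discussed in this issue.
MIN_SGMV_RANK = 8  # minimum rank the SGMV kernel supports, per this thread


def can_use_sgmv(rank: int, num_shards: int) -> bool:
    """Assumes each shard gets rank // num_shards of the LoRA rank dimension."""
    return rank // num_shards >= MIN_SGMV_RANK


def min_rank_for_sgmv(num_shards: int) -> int:
    """Smallest adapter rank that passes the check for a given shard count."""
    return MIN_SGMV_RANK * num_shards


for shards in (1, 2, 4):
    print(f"shards={shards}: r=8 ok? {can_use_sgmv(8, shards)}, "
          f"min rank={min_rank_for_sgmv(shards)}")
```

With 4 shards, r=8 gives a per-shard rank of 2, so the predicate fails and the smallest passing rank is 32, matching the arithmetic in the question.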

tgaddair commented 6 months ago

Hi @markovalexander, good callout. That is a limitation today, since the SGMV kernel requires at least rank 8, but we could work around this by padding any adapters up to rank 8 with 0s, removing the need for MIN_SGMV_RANK altogether. The only downside would be a (pretty minor) increase in memory for the adapters.

If that sounds reasonable to you, it should be a relatively straightforward change to make. I'll aim to pick it up sometime this week if no one gets to it first.
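The zero-padding workaround is lossless because the padded rows and columns contribute only zeros to the low-rank update B·A. A minimal pure-Python sketch, assuming PEFT's weight layout (`lora_A` is [r, in_features], `lora_B` is [out_features, r]); `pad_lora` and `matmul` are hypothetical helpers, not the change that actually went into LoRAX.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Plain-Python matrix product, used only to check that padding is lossless."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def pad_lora(lora_a: Matrix, lora_b: Matrix, target_rank: int) -> Tuple[Matrix, Matrix]:
    """Zero-pad a LoRA pair from rank r up to target_rank.

    Assumes the PEFT layout: lora_a is [r, in_features], lora_b is [out_features, r].
    New rows of A and new columns of B are zeros, so B @ A is unchanged.
    """
    rank = len(lora_a)
    if rank >= target_rank:
        return lora_a, lora_b
    extra = target_rank - rank
    in_features = len(lora_a[0])
    padded_a = [list(row) for row in lora_a] + [[0.0] * in_features for _ in range(extra)]
    padded_b = [list(row) + [0.0] * extra for row in lora_b]
    return padded_a, padded_b


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # r=2, in_features=3
b = [[1.0, 0.5], [0.0, 2.0]]            # out_features=2, r=2
pa, pb = pad_lora(a, b, target_rank=8)
print(len(pa), len(pb[0]))              # 8 8
print(matmul(pb, pa) == matmul(b, a))   # True
```

Adapter memory grows by a factor of target_rank / r, which is the minor cost mentioned above. Under the per-shard check described at the top of this issue, a sharded deployment would presumably need to pad to 8 × num_shards rather than 8; that is an inference, not something stated in the thread.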

markovalexander commented 6 months ago

@tgaddair thanks for the quick reply!

Yes, I think it should be added, because right now inefficient GPU utilisation leads to very poor performance with small adapters, which is quite unexpected.

I'm actually not sure whether the unused SGMV kernels are the main reason for the latency increase, but I don't know what else it could be. WDYT? Either way, it should definitely be tested to confirm whether this change solves the issue.

tgaddair commented 6 months ago

Hey @markovalexander, PR #256 should address this issue, so SGMV can be used even with tensor parallelism.

It's possible the lack of SGMV support contributed to the latency; this would be most noticeable when using multiple adapters per batch. In most cases, though, slow tensor parallelism is due to GPU-to-GPU communication latency when the devices are connected via PCIe rather than NVLink. Do you also notice the slowdown when running against the base model (without an adapter)?
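To illustrate why the missing kernel matters most with several adapters per batch, here is a pure-Python reference of the segmented computation that an SGMV-style kernel performs in a single fused launch. It is a sketch under assumptions, not LoRAX's implementation: the function names are hypothetical, the alpha/r scaling is omitted, and it assumes tokens are ordered so that each adapter's tokens form contiguous segments. A naive fallback needs a separate pair of matrix multiplies per segment, so its cost grows with the number of distinct adapters in the batch.

```python
from itertools import groupby
from typing import Dict, List, Tuple

Matrix = List[List[float]]
Segment = Tuple[int, int, int]  # (adapter_id, start, end)


def segments(adapter_ids: List[int]) -> List[Segment]:
    """Group consecutive tokens that share an adapter into (adapter, start, end)."""
    out: List[Segment] = []
    start = 0
    for adapter, group in groupby(adapter_ids):
        end = start + sum(1 for _ in group)
        out.append((adapter, start, end))
        start = end
    return out


def matvec(m: Matrix, v: List[float]) -> List[float]:
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def segmented_lora(x: Matrix, adapter_ids: List[int],
                   adapters: Dict[int, Tuple[Matrix, Matrix]]) -> Matrix:
    """Reference LoRA delta y_i = B_k (A_k x_i) for token i in a segment of adapter k."""
    y: Matrix = []
    for adapter, start, end in segments(adapter_ids):
        lora_a, lora_b = adapters[adapter]  # one (A, B) pair per segment
        for row in x[start:end]:
            y.append(matvec(lora_b, matvec(lora_a, row)))
    return y


print(segments([0, 0, 1, 1, 1, 0]))  # [(0, 0, 2), (1, 2, 5), (0, 5, 6)]
```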

markovalexander commented 6 months ago

Hi @tgaddair

No, no slowdowns when running the base model. Devices should be connected via NVLink.

markovalexander commented 6 months ago

Hello @tgaddair, I am experimenting with the Llama 70B model on the main branch, building the Docker image from scratch. It is still not working; I now suspect the issue is MAX_RANK_CUSTOM, which is too small for the Llama 70B model with GQA.

tgaddair commented 6 months ago

Hey @markovalexander, what is the rank of the adapter you need to use? We can definitely add more ranks for the kernel if needed.

markovalexander commented 6 months ago

The A matrix has shape [8192, 8], so r = 8, but the error is "8192 is larger than MAX_RANK_CUSTOM". MAX_RANK_CUSTOM=128, so I guess it corresponds to the attention head dimension for MHA, but Llama 70B uses GQA, so its attention "head" is larger than 128. Am I right?
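For reference, the attention projection shapes can be derived from the model config. A minimal sketch, assuming the values from the public Llama-2-70B config (`hidden_size=8192`, `num_attention_heads=64`, `num_key_value_heads=8`); `AttnConfig` and `proj_out_features` are hypothetical helpers. Under GQA the per-head dimension is still 128; what shrinks is the number of key/value heads, which is why k_proj and v_proj have 1024 output features in the repro script further down. The 8192 in the error therefore matches the hidden size rather than a head dimension.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class AttnConfig:
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int


def proj_out_features(cfg: AttnConfig) -> Dict[str, int]:
    """Output features of each attention projection under grouped-query attention."""
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    kv_dim = cfg.num_key_value_heads * head_dim  # GQA reduces K/V heads, not head_dim
    return {
        "head_dim": head_dim,
        "q_proj": cfg.num_attention_heads * head_dim,
        "k_proj": kv_dim,
        "v_proj": kv_dim,
        "o_proj": cfg.hidden_size,
    }


llama2_70b = AttnConfig(hidden_size=8192, num_attention_heads=64, num_key_value_heads=8)
print(proj_out_features(llama2_70b))
# {'head_dim': 128, 'q_proj': 8192, 'k_proj': 1024, 'v_proj': 1024, 'o_proj': 8192}
```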

tgaddair commented 6 months ago

Hey @markovalexander, that kind of error suggests to me that the LoRA weights are stored transposed, if it's comparing the hidden dimension of 8192 with the max rank of 128.

Do you have a script I can use to repro the error? Also, can you share the full error message you're seeing?
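One quick way to test the transposition hypothesis offline is to read r off the weight shapes the way PEFT lays them out. A hedged sketch: `infer_lora_rank` is a hypothetical helper, it assumes the PEFT layout (`lora_A.weight` is [r, in_features], `lora_B.weight` is [out_features, r]), and treating an r larger than a feature dimension as a transpose is a heuristic that relies on LoRA ranks being small.

```python
from typing import Sequence


def infer_lora_rank(a_shape: Sequence[int], b_shape: Sequence[int]) -> int:
    """Infer r from PEFT-layout LoRA weight shapes and flag a likely transpose."""
    rank_from_a, rank_from_b = a_shape[0], b_shape[1]
    if rank_from_a != rank_from_b:
        raise ValueError(f"rank mismatch: A gives {rank_from_a}, B gives {rank_from_b}")
    if rank_from_a > a_shape[1] or rank_from_b > b_shape[0]:
        # r should be the small dimension; if it is not, the pair is probably transposed
        raise ValueError(f"r={rank_from_a} exceeds a feature dim; weights may be transposed")
    return rank_from_a


print(infer_lora_rank((8, 8192), (1024, 8)))  # 8
```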

markovalexander commented 6 months ago

@tgaddair thanks for looking into the issue. I made a script you can use to create an adapter_model.safetensors and adapter_config.json with the same shapes and dtype as the ones I use in my work (we use the Llama-2-70B model):

import json
import os

import safetensors.torch as st
import torch

adapter_folder = "/path/to/adapter"  # PUT PATH HERE
os.makedirs(adapter_folder, exist_ok=True)

lora_config = {
    'alpha_pattern': {},
    'auto_mapping': None,
    'base_model_name_or_path': 'meta-llama/Llama-2-70b-chat-hf',
    'bias': 'none',
    'fan_in_fan_out': False,
    'inference_mode': False,
    'init_lora_weights': True,
    'layers_pattern': None,
    'layers_to_transform': None,
    'lora_alpha': 16.0,
    'lora_dropout': 0.0,
    'modules_to_save': None,
    'peft_type': 'LORA',
    'r': 8,
    'rank_pattern': {},
    'revision': None,
    'target_modules': ['q_proj', 'v_proj', 'k_proj', 'o_proj'],
    'task_type': 'CAUSAL_LM'
}

with open(os.path.join(adapter_folder, 'adapter_config.json'), 'w') as f:
    json.dump(lora_config, f)

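# Per-layer LoRA weight shapes (PEFT layout: lora_A is [r, in], lora_B is [out, r]).
# k_proj/v_proj output 1024 features because Llama-2-70B uses GQA (8 KV heads x 128).
shapes_dict = {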
    ('k_proj', 'lora_A'): torch.Size([8, 8192]),
    ('k_proj', 'lora_B'): torch.Size([1024, 8]),
    ('o_proj', 'lora_A'): torch.Size([8, 8192]),
    ('o_proj', 'lora_B'): torch.Size([8192, 8]),
    ('q_proj', 'lora_A'): torch.Size([8, 8192]),
    ('q_proj', 'lora_B'): torch.Size([8192, 8]),
    ('v_proj', 'lora_A'): torch.Size([8, 8192]),
    ('v_proj', 'lora_B'): torch.Size([1024, 8])
}

def get_state_dict(shapes_dict):
    pattern = "base_model.model.model.layers.{layer_idx}.self_attn.{layer}.{lora_layer}.weight"

    lora_state_dict = {}

    for layer_idx in range(80):  # Llama-2-70B has 80 decoder layers
        for attn_layer in ['k_proj', 'q_proj', 'v_proj', 'o_proj']:
            for lora_layer in ['lora_A', 'lora_B']:
                name = pattern.format(layer_idx=layer_idx, layer=attn_layer, lora_layer=lora_layer)
                shape = shapes_dict[(attn_layer, lora_layer)]
                tens = torch.randn(*shape, dtype=torch.bfloat16) * 1e-3
                lora_state_dict[name] = tens
    return lora_state_dict

lora_state_dict = get_state_dict(shapes_dict)
st.save_file(lora_state_dict, os.path.join(adapter_folder, "adapter_model.safetensors"))

Next, build the Docker image from main (I used this commit, which should support the SGMV kernel) and send a request with the newly created adapter.

You will see two strange things. The first is "2024-02-27T18:58:47.282778Z INFO lorax_launcher: model.py:189 Adapter sharded in 21.81 seconds", which seems too long to me, given that the adapter map was loaded in 0.02 seconds.

Secondly, you have to add

        from loguru import logger

        for rank_data in self.rank_data.values():
            logger.info(f"{rank_data.rank // pg.size() = }\t {MAX_RANK_CUSTOM}")

here, and you will see that all of these values are larger than MAX_RANK_CUSTOM, so the condition evaluates to False and SGMV is not used (see here).

2024-02-27T19:08:25.405387Z  INFO lorax_launcher: lora.py:52 rank_data.rank // pg.size() = 2048  128
2024-02-27T19:08:25.406226Z  INFO lorax_launcher: lora.py:52 rank_data.rank // pg.size() = 512   128

I also tried removing the transposes here, but that didn't work either :)

Environment details: here is how I run the Docker container:

context_length=4096
batch_size=32
waiting_served_ratio=1.2
max_waiting_tokens=20
docker run --gpus='"device=0,1,2,3"' --rm --shm-size 1g \
    -p $port:$port --network host \
    -v $volume:/data \
    -v /home/alexander/adapters:/adapters \
    $image_name:$image_tag --model-id /data/$model --num-shard 4 \
    --max-input-length $(($context_length-1)) \
    --max-total-tokens $context_length \
    --max-batch-prefill-tokens $(($batch_size*$context_length)) \
    --waiting-served-ratio $waiting_served_ratio \
    --max-waiting-tokens $max_waiting_tokens \
    --max-stop-sequences 10 \
    --cuda-memory-fraction 0.99 \
    --hostname 0.0.0.0 --port $port
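For reference, with the values set at the top of the script, the shell arithmetic in the launcher flags resolves as follows. A small sketch that only mirrors that arithmetic; `launcher_limits` is a hypothetical helper and does not call LoRAX.

```python
def launcher_limits(context_length: int, batch_size: int) -> dict:
    """Mirror the shell arithmetic used in the docker run flags above."""
    return {
        "--max-input-length": context_length - 1,
        "--max-total-tokens": context_length,
        "--max-batch-prefill-tokens": batch_size * context_length,
    }


print(launcher_limits(context_length=4096, batch_size=32))
# {'--max-input-length': 4095, '--max-total-tokens': 4096, '--max-batch-prefill-tokens': 131072}
```

In other words, a single prefill batch may hold up to 131,072 tokens, the equivalent of 32 sequences of 4,096 tokens.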

It uses 4 A100 GPUs (driver version 535.104.05, CUDA version 12.2), with no other GPU/CPU jobs on the instance. The GPUs are connected with NVLink, and we tested that it works perfectly in training scripts, but I don't know how to verify that it is used by LoRAX in my case.
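NVLink connectivity between the GPUs can be checked from the output of `nvidia-smi topo -m`, where an `NV<n>` entry means a pair is joined by n bonded NVLinks and entries such as `PIX`, `PHB`, or `SYS` describe PCIe or CPU-interconnect paths. A minimal sketch, assuming the usual matrix layout of that command (a header row of column names, one row per GPU, a legend at the bottom); `nvlink_pairs` is a hypothetical helper. This only confirms the hardware topology. Whether NCCL actually uses NVLink during LoRAX inference is a separate question; setting `NCCL_DEBUG=INFO` in the container should make NCCL log the transports it selects.

```python
import re
from typing import List, Tuple

ANSI_ESCAPE = re.compile(r"\x1b\[[0-9;]*m")  # strip terminal formatting codes, if any
NVLINK_ENTRY = re.compile(r"^NV\d+$")        # e.g. NV12 = 12 bonded NVLinks


def nvlink_pairs(topo_output: str) -> List[Tuple[str, str]]:
    """Return GPU pairs whose `nvidia-smi topo -m` matrix entry is NV<n>.

    Assumes a header row of column names, one row per GPU (row label first),
    and a legend at the bottom, which is skipped.
    """
    text = ANSI_ESCAPE.sub("", topo_output)
    lines = [line for line in text.splitlines() if line.strip()]
    if not lines:
        return []
    gpus = [name for name in lines[0].split() if name.startswith("GPU")]
    pairs: List[Tuple[str, str]] = []
    for line in lines[1:]:
        cells = line.split()
        if cells[0] not in gpus:
            continue  # NIC rows, legend lines, etc.
        row = gpus.index(cells[0])
        for col in range(row + 1, len(gpus)):  # upper triangle only, no duplicates
            entry = cells[col + 1] if col + 1 < len(cells) else ""
            if NVLINK_ENTRY.match(entry):
                pairs.append((gpus[row], gpus[col]))
    return pairs


# Usage (requires an NVIDIA driver on the host):
#   import subprocess
#   out = subprocess.run(["nvidia-smi", "topo", "-m"], capture_output=True, text=True).stdout
#   print(nvlink_pairs(out))
```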

markovalexander commented 6 months ago

@tgaddair Hello! Are there any updates on this issue?