zkyredstart / LLaVA-NPU

Run LLaVA on an NPU device, e.g., Ascend 910B
Apache License 2.0

# LLaVA-NPU: Running LLaVA on an NPU Device, e.g., Ascend 910B

In this project, we port LLaVA from CUDA devices to NPU devices. If you find our project helpful, please give it a star. Thanks for your support!

### Change history

- [2024-09-02]: We now support common visual projectors, e.g., LDPv2 and Resampler. Try it!
- [2024-08-27]: We added the SigLIP encoder to LLaVA-NPU.
- [2024-08-25]: We reproduced the results on MMBench. Weights are coming soon.
- [2024-08-15]: We created the project and uploaded the source code.


### Installation

<1> Install LLaVA.

```bash
pip install -e ".[train]"
```

### Train and Test

<1> You can train and test LLaVA in the same way as in the official repo. If you are in China, you can download the models from ModelScope.

<2> **Important!** LLaVA-NPU does not support LoRA tuning or ZeRO-3 offload, so please use full fine-tuning. We train LLaVA on 8 Ascend 910B NPUs, each with 65 GB of memory.

<3> Training details. The hyperparameters used in pretraining and visual instruction tuning are as follows.

1. Pretraining

| Model | Global batch size | Learning rate | Epochs | Max length (tokens) | Weight decay |
| --- | ---: | ---: | ---: | ---: | ---: |
| LLaVA-v1.5-7B | 256 | 1e-3 | 1 | 2048 | 0 |

2. Visual instruction tuning

| Model | Global batch size | Learning rate | Epochs | Max length (tokens) | Weight decay |
| --- | ---: | ---: | ---: | ---: | ---: |
| LLaVA-v1.5-7B | 64 | 1e-5 | 1 | 2048 | 0 |

<4> Model performance comparison.

| Model | Image encoder | Language model | Projector | MMBench | MM-Vet |
| --- | --- | --- | --- | ---: | ---: |
| LLaVA-v1.5-7B (official) | CLIP | Vicuna-7B | MLP | 66.7 | 31.1 |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | MLP | 67.7 | 32.2 |
| LLaVA-v1.5-7B (ours) | SigLIP | Vicuna-7B | MLP | 66.4 | - |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | Adaptive Pool | 64.6 | - |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | Resampler | 63.1 | 27.1 |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | LDPv2 | 65.7 | 28.9 |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | TokenPacker | 63.1 | - |
| LLaVA-v1.5-7B (ours) | CLIP | Vicuna-7B | C-Abstract | 65.1 | 31.8 |

### Core code

<1> LLaVA-NPU changes the flash-attention implementation for the NPU. The code can be found [here](llava/train/llama_npu_monkey_patch.py).

<2> We modified the evaluation code of LLaVA. The code is [here](llava/eval).

<3> We support multi-head attention in the style of `nn.MultiheadAttention` on the NPU device, built on `torch_npu.npu_fusion_attention`:
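```python
import math

import torch
import torch.nn as nn
import torch_npu  # Ascend PyTorch adapter, provides npu_fusion_attention


class MultiheadfusionAttention(nn.Module):
    """MultiHeadAttention implementation on the NPU device.

    Inputs use the "SBH" layout, i.e. (seq_len, batch, d_model), the same
    sequence-first layout as the default of nn.MultiheadAttention.
    """

    def __init__(self, d_model, h):
        super().__init__()
        assert d_model % h == 0, "d_model must be divisible by the number of heads"
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        # Input projections for query, key and value
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, attn_mask=None, dropout=1.):
        # `dropout` is passed to the kernel as keep_prob, so 1.0 means no dropout.
        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [proj(x) for proj, x in zip(self.linear_layers, (query, key, value))]
        scale = 1 / math.sqrt(self.d_k)
        # 2) Fused attention kernel; element 0 of the returned tuple is the attention output
        attn_output = torch_npu.npu_fusion_attention(
            query, key, value, self.h,
            input_layout="SBH",
            pse=None,
            padding_mask=None,
            atten_mask=attn_mask,
            scale=scale,
            keep_prob=dropout,
            pre_tockens=65536,
            next_tockens=0,
            inner_precise=0,
        )[0]
        # 3) Output projection; return a tuple like nn.MultiheadAttention
        #    (the fused kernel does not materialize attention weights)
        return self.output_linear(attn_output), None


# Usage: the module and inputs must live on the NPU; fp16 is a safe dtype for the fused kernel
attention = MultiheadfusionAttention(1024, 8).npu().half()
inputs = torch.rand(2304, 1, 1024, dtype=torch.float16).npu()  # (seq_len, batch, d_model)
outputs = attention(inputs, inputs, inputs)[0]
```

### Acknowledgement

<1> We would like to express our sincere thanks to [this repo](https://github.com/HelloWorldBeginner/LLaVA/tree/main) for its NPU implementation, on which our project is based!

<2> Many thanks to LLaVA for its great contribution to the MLLM community!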
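A note on the training-details tables above: the global batch sizes count samples per optimizer step across all NPUs. Under data-parallel training, global batch size = number of NPUs × per-device batch size × gradient-accumulation steps, so the same global batch can be reached with a smaller per-device batch when memory is tight. In a Hugging Face `Trainer`-based script such as LLaVA's, the last two factors are set with `per_device_train_batch_size` and `gradient_accumulation_steps`. The sketch below assumes the 8-NPU setup described above; the helper `grad_accum_steps` and the per-device batch sizes are hypothetical illustrations, not settings taken from our training scripts.

```python
# Hypothetical helper (not part of LLaVA-NPU): derive gradient-accumulation
# steps for a target global batch size under data-parallel training.
def grad_accum_steps(global_batch: int, num_devices: int, per_device_batch: int) -> int:
    samples_per_micro_step = num_devices * per_device_batch
    if global_batch % samples_per_micro_step != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by "
            f"{num_devices} devices x {per_device_batch} per-device batch"
        )
    return global_batch // samples_per_micro_step


NUM_NPUS = 8  # matches the 8 x Ascend 910B setup described above

# Global batch sizes come from the training-details tables; the per-device
# batch sizes are illustrative assumptions, not our actual settings.
stages = {
    "pretraining": (256, 16),
    "visual instruction tuning": (64, 4),
}

for stage, (global_batch, per_device) in stages.items():
    steps = grad_accum_steps(global_batch, NUM_NPUS, per_device)
    print(f"{stage}: per-device batch {per_device}, gradient accumulation {steps}")
```

For example, with a hypothetical per-device batch of 16, the pretraining global batch of 256 needs 2 accumulation steps (8 × 16 × 2 = 256).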