Would be great to get it added. I will mark this as a feature request; hopefully we can add it soon.
While I am trying to port DeepSeek V2, I have run into a few issues:
1. I can't run the full deepseek-ai/DeepSeek-Coder-V2-Instruct in fp16, so there is no really good way to test the ported model for it :(
2. It seems llama.cpp's implementation also caches the whole KV cache.
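For context on the KV cache point: in DeepSeek V2's Multi-head Latent Attention (MLA), per-head keys and values are reconstructed from a low-rank latent. An implementation can therefore cache either the full per-head K/V or only the compressed latent plus the RoPE key, which is shared across heads. The sketch below compares per-token, per-layer cache sizes. It assumes the hyperparameter names from the attention test config later in this thread (`kv_lora_rank`, `qk_rope_head_dim`, `qk_nope_head_dim`, `v_head_dim`); the helper functions are hypothetical.

```python
def full_kv_floats_per_token(num_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim):
    # Each head stores a key (non-RoPE part + RoPE part) and a value vector.
    return num_heads * (qk_nope_head_dim + qk_rope_head_dim + v_head_dim)


def latent_kv_floats_per_token(kv_lora_rank, qk_rope_head_dim):
    # MLA-style cache: the compressed KV latent plus one RoPE key shared by all heads.
    return kv_lora_rank + qk_rope_head_dim


# Values taken from the small attention test config in this thread (not the real model).
cfg = dict(num_heads=8, kv_lora_rank=32, qk_nope_head_dim=32, qk_rope_head_dim=32, v_head_dim=64)

full = full_kv_floats_per_token(
    cfg["num_heads"], cfg["qk_nope_head_dim"], cfg["qk_rope_head_dim"], cfg["v_head_dim"])
latent = latent_kv_floats_per_token(cfg["kv_lora_rank"], cfg["qk_rope_head_dim"])
print(full, latent, full / latent)
```

Caching the latent saves memory but means the K/V up-projection has to be applied (or folded into other weights) at decode time. Caching the full K/V avoids that work at the cost of a larger cache.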
@mzbac you can try running the lite version. Can you share your current code? I am also trying to implement the model.
I have only completed the RoPE and attention modules; the MoE module is partly complete. I will try to combine all the modules tonight and see if I can create a draft PR to share the model.
Very exciting, DeepSeek Coder V2 seems to be really good at coding.
@awni, I just noticed that the MLX MoE's topk seems to have some issues when the number of experts per token increases: it returns a different topk compared to the PyTorch implementation. I'm not sure whether this has something to do with MLX. Here is test code you can use to reproduce it. This is currently blocking the DeepSeek V2 port, since the model has a large number of experts and selects many experts per token.
import torch
import numpy as np
import unittest
import mlx.core as mx

def pytorch_greedy_select(scores, top_k):
    # Top-k over the expert dimension, without sorting the result
    return torch.topk(scores, k=top_k, dim=-1, sorted=False)

def mlx_greedy_select(scores, top_k):
    scores = mx.array(scores)
    # Partition so the top_k largest scores land in the first top_k slots (order unspecified)
    idx = mx.argpartition(-scores, top_k - 1, axis=-1)[:, :top_k]
    weight = mx.take_along_axis(scores, idx, axis=-1)
    return weight, idx

class TestMoEGatingMethods(unittest.TestCase):
    def setUp(self):
        np.random.seed(42)
        mx.random.seed(42)
        torch.manual_seed(42)
        self.batch_size = 32
        self.num_experts = 8
        self.top_k = 2  # changing this to 6 makes the test fail
        self.np_scores = np.random.rand(
            self.batch_size, self.num_experts).astype(np.float32)
        self.torch_scores = torch.from_numpy(self.np_scores)

    def test_greedy_select(self):
        mlx_weight, mlx_idx = mlx_greedy_select(self.np_scores, self.top_k)
        torch_weight, torch_idx = pytorch_greedy_select(
            self.torch_scores, self.top_k)
        # Element-wise comparison: this also checks the order of the selected experts
        np.testing.assert_equal(np.array(mlx_idx), torch_idx.numpy())
        np.testing.assert_allclose(np.array(mlx_weight), torch_weight.numpy(), rtol=1e-5)

if __name__ == '__main__':
    unittest.main()
@vovw, here is my port of DeepSeek-AI/DeepSeek-Coder-V2-Lite-Instruct. However, the MoE doesn't work and I couldn't figure out why. Looking at DeepSeek's implementation, the greedy topk selection looks the same as the Mixtral MoE, but somehow the output doesn't match. I'm not an expert on this, so I may need some help from @awni.
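One thing that can make "the same greedy top-k" produce different MoE outputs is the order of softmax and top-k in the gate. Picking the top-k logits and then taking a softmax over only those k is one option. Taking a softmax over all experts and then keeping the top-k probabilities, optionally rescaled and not renormalized, is another. Both pick the same experts, because softmax is monotonic, but they assign different weights. The NumPy sketch below compares the two. Which variant each reference model uses is not established in this thread, so treat this as something worth checking, not as a diagnosis. The function names and the `scale` parameter are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def gate_topk_then_softmax(logits, k):
    # Variant A: select the top-k logits, then softmax over only those k.
    idx = np.argpartition(-logits, k - 1, axis=-1)[:, :k]
    return softmax(np.take_along_axis(logits, idx, axis=-1)), idx


def gate_softmax_then_topk(logits, k, scale=1.0):
    # Variant B: softmax over all experts, keep the top-k probabilities
    # without renormalizing them (optionally multiplied by a scale factor).
    probs = softmax(logits)
    idx = np.argpartition(-probs, k - 1, axis=-1)[:, :k]
    return np.take_along_axis(probs, idx, axis=-1) * scale, idx


rng = np.random.default_rng(0)
logits = rng.standard_normal((4, 64))
w_a, i_a = gate_topk_then_softmax(logits, 6)
w_b, i_b = gate_softmax_then_topk(logits, 6)
print(np.array_equal(np.sort(i_a, axis=-1), np.sort(i_b, axis=-1)))
print(w_a.sum(axis=-1), w_b.sum(axis=-1))
```

The selected experts agree, but variant A's weights sum to 1 per token while variant B's sum to less than 1. Renormalizing variant B's weights recovers variant A exactly.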
@awni, I just noticed that the mlx MOE's topk seems to have some issues when the number of experts per token increases.
In general, topk should return the same indices, but they don't have to be in the same order; topk doesn't guarantee that the results are actually sorted in MLX. Though looking at the two implementations, they should both be giving sorted results by default (torch takes a sorted kwarg which defaults to True). I ran your test @mzbac and it passes for me. Maybe make sure you are on an up-to-date MLX? Run pip install -U mlx and try again.
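Since two correct top-k results may list the same experts in different orders, a comparison that does not depend on ordering can sort each row by expert index (reordering the weights to match) before comparing. Below is a minimal NumPy-only sketch. `np.argpartition` stands in for `mx.argpartition`, and a full descending `np.argsort` stands in for a sorted `torch.topk`, so neither framework is needed. The helper names are hypothetical.

```python
import numpy as np


def topk_unsorted(scores, k):
    # Like the MLX version in the test: the k largest land first, in unspecified order.
    idx = np.argpartition(-scores, k - 1, axis=-1)[:, :k]
    return np.take_along_axis(scores, idx, axis=-1), idx


def topk_sorted(scores, k):
    # Reference: full descending sort, keep the first k.
    idx = np.argsort(-scores, axis=-1)[:, :k]
    return np.take_along_axis(scores, idx, axis=-1), idx


def canonicalize(weights, idx):
    # Sort each row by expert index so results can be compared regardless of order.
    order = np.argsort(idx, axis=-1)
    return np.take_along_axis(weights, order, axis=-1), np.take_along_axis(idx, order, axis=-1)


rng = np.random.default_rng(42)
scores = rng.random((32, 64), dtype=np.float32)
w_a, i_a = canonicalize(*topk_unsorted(scores, 6))
w_b, i_b = canonicalize(*topk_sorted(scores, 6))
np.testing.assert_array_equal(i_a, i_b)
np.testing.assert_allclose(w_a, w_b, rtol=1e-6)
print("same experts selected")
```

Comparing the raw, uncanonicalized index arrays element-wise can fail even when both sides selected exactly the same experts.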
I did some research on DeepSeek's implementation. It seems the issue is not with mx.argpartition: they use an unsorted topk and sort the indices when applying the weights for the selected experts, which looks slightly different. However, I couldn't figure out what exactly the difference is.
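As long as each weight stays paired with its expert index, sorting the selected (token, expert) pairs by expert before running the experts only regroups the terms of a weighted sum. So the sort by itself should not change the output. The NumPy sketch below checks this. It uses random linear maps as stand-in experts and compares a direct per-token loop with a sort-by-expert dispatch. It is a hypothetical illustration of the idea described above, not DeepSeek's actual code, and the gating weights are arbitrary because the gating formula does not matter for this check.

```python
import numpy as np

rng = np.random.default_rng(1)
num_tokens, hidden, num_experts, k = 10, 16, 8, 3
x = rng.standard_normal((num_tokens, hidden))
# Stand-in experts: one random linear map per expert (hypothetical).
experts = rng.standard_normal((num_experts, hidden, hidden))
logits = rng.standard_normal((num_tokens, num_experts))
idx = np.argpartition(-logits, k - 1, axis=-1)[:, :k]  # unsorted top-k experts per token
w = np.take_along_axis(logits, idx, axis=-1)           # weights paired with idx


def moe_direct(x, idx, w):
    # Loop over tokens and their selected experts in whatever order top-k returned.
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for j in range(idx.shape[1]):
            out[t] += w[t, j] * (x[t] @ experts[idx[t, j]])
    return out


def moe_sorted_dispatch(x, idx, w):
    # Flatten (token, slot) pairs, sort them by expert id, run each expert once on
    # its contiguous group of tokens, then scatter-add the weighted results back.
    flat_idx = idx.reshape(-1)
    order = np.argsort(flat_idx, kind="stable")
    token_ids = np.repeat(np.arange(x.shape[0]), idx.shape[1])[order]
    sorted_experts = flat_idx[order]
    sorted_w = w.reshape(-1)[order]
    out = np.zeros_like(x)
    for e in np.unique(sorted_experts):
        sel = sorted_experts == e
        y = x[token_ids[sel]] @ experts[e]
        np.add.at(out, token_ids[sel], sorted_w[sel][:, None] * y)
    return out


print(np.allclose(moe_direct(x, idx, w), moe_sorted_dispatch(x, idx, w)))
```

If a port and the reference disagree even though both sides select the same experts, the weights (how they are computed, normalized, or scaled) are a more likely culprit than the sort.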
Do you want to send your port so far as a draft PR and we can collaborate on it?
Here is the draft PR: https://github.com/ml-explore/mlx-examples/pull/882.
Thanks, I will look into the implementation. Now we just need to figure out how KV caching is done, and the greedy topk, right?
I did some testing of the RoPE and attention implementations. They seem to produce the same output as DeepSeek's implementation, but the MoE doesn't produce the same output as DeepSeek's.
testing of the RoPE and attention implementations. They seem to produce the same output as DeepSeek's implementation
How do you test your implementation, @mzbac?
By the way, just an idea, but this specific model size (for most consumer-grade hardware) and use case (i.e. you want your coding assistant to be fast at predicting) would make it a great candidate for a speculative decoding example.
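To make the speculative decoding idea concrete: a small draft model proposes a few tokens, and the large target model verifies them, keeping the longest agreeing prefix plus one token of its own. With greedy decoding, the output is therefore identical to decoding with the target alone. Below is a minimal pure-Python sketch of the greedy variant. The toy "models" are hypothetical integer functions, not DeepSeek or MLX code. For clarity, the verification loop calls the target once per position, whereas a real implementation scores all draft tokens in a single batched forward pass.

```python
from typing import Callable, List

NextToken = Callable[[List[int]], int]  # hypothetical greedy next-token function


def greedy_decode(model: NextToken, prompt: List[int], n: int) -> List[int]:
    out = list(prompt)
    for _ in range(n):
        out.append(model(out))
    return out[len(prompt):]


def speculative_decode(target: NextToken, draft: NextToken,
                       prompt: List[int], n: int, k: int = 4) -> List[int]:
    out = list(prompt)
    while len(out) - len(prompt) < n:
        # 1. The cheap draft model proposes k tokens autoregressively.
        proposal: List[int] = []
        for _ in range(k):
            proposal.append(draft(out + proposal))
        # 2. The target checks each proposed position in order.
        accepted: List[int] = []
        for tok in proposal:
            expected = target(out + accepted)
            if expected != tok:
                accepted.append(expected)  # replace the first mismatch with the target's token
                break
            accepted.append(tok)
        else:
            accepted.append(target(out + accepted))  # all accepted: take one bonus token
        out.extend(accepted)
    return out[len(prompt):len(prompt) + n]


# Toy models over a vocabulary of 11 tokens; the draft is right most of the time.
def target(seq: List[int]) -> int:
    return (seq[-1] * 3 + 1) % 11


def draft(seq: List[int]) -> int:
    return (seq[-1] * 3 + 1) % 11 if seq[-1] % 4 else 0


print(speculative_decode(target, draft, [2], 10))
```

Each loop iteration emits at least one target-approved token, so the speedup depends on how often the draft agrees with the target. The draft model also needs to share the target's tokenizer.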
I just compare the output tensors of each module between modeling_deepseek.py and the MLX implementation, for example:
import unittest
import glob
from pathlib import Path

import numpy as np
import torch
import mlx.core as mx
import mlx.nn as nn
from safetensors.torch import save_file

from modeling_deepseek import DeepseekV2Attention as PyTorchDeepseekV2Attention, DeepseekV2Config as PyTorchDeepseekV2Config
from mlx_deepseek_v2_attention import DeepseekV2Attention as MLXDeepseekV2Attention, DeepseekV2Config as MLXDeepseekV2Config

def transfer_weights(pytorch_model, model_path: Path):
    # Save PyTorch weights to a safetensors file
    state_dict = pytorch_model.state_dict()
    save_file(state_dict, str(model_path / "model.safetensors"))

def load_mlx_model(model_path: Path, mlx_config: dict) -> nn.Module:
    """
    Load weights from safetensors file(s) and initialize the MLX model.
    """
    weight_files = glob.glob(str(model_path / "model*.safetensors"))
    if not weight_files:
        raise FileNotFoundError(f"No safetensors found in {model_path}")
    # Load weights from safetensors file(s)
    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))
    # Initialize MLX model
    mlx_model = MLXDeepseekV2Attention(
        MLXDeepseekV2Config(**mlx_config), layer_idx=0)
    # Load weights into MLX model (parameter names must match the PyTorch state_dict)
    print(weights.keys())
    mlx_model.load_weights(list(weights.items()))
    # Evaluation mode (no dropout), matching pytorch_model.eval()
    mlx_model.eval()
    return mlx_model

class TestDeepseekV2Attention(unittest.TestCase):
    def setUp(self):
        # Seed before building the model so weight init and inputs are reproducible
        torch.manual_seed(0)
        np.random.seed(0)
        self.config = {
            "hidden_size": 512,
            "num_attention_heads": 8,
            "max_position_embeddings": 2048,
            "rope_theta": 10000,
            "attention_bias": True,
            "q_lora_rank": None,
            "kv_lora_rank": 32,
            "qk_rope_head_dim": 32,
            "v_head_dim": 64,
            "qk_nope_head_dim": 32,
            "attention_dropout": 0.1,
        }
        pytorch_config = PyTorchDeepseekV2Config(**self.config)
        self.pytorch_model = PyTorchDeepseekV2Attention(
            pytorch_config, layer_idx=0)
        self.pytorch_model.eval()
        self.temp_dir = Path("temp_model")
        self.temp_dir.mkdir(exist_ok=True)
        transfer_weights(self.pytorch_model, self.temp_dir)
        self.mlx_model = load_mlx_model(self.temp_dir, self.config)

    def tearDown(self):
        for file in self.temp_dir.glob("*"):
            file.unlink()
        self.temp_dir.rmdir()

    def test_forward_pass(self):
        batch_size, seq_length = 2, 10
        pytorch_input = torch.randn(
            batch_size, seq_length, self.config["hidden_size"])
        mlx_input = mx.array(pytorch_input.numpy())
        # All-zero additive masks: no positions are masked out
        pytorch_mask = torch.zeros(batch_size, 1, seq_length, seq_length)
        mlx_mask = mx.zeros((batch_size, 1, seq_length, seq_length))
        print("PyTorch input shape:", pytorch_input.shape)
        print("MLX input shape:", mlx_input.shape)
        with torch.no_grad():
            pytorch_output, pytorch_attn_weights, pytorch_cache = self.pytorch_model(
                pytorch_input, attention_mask=pytorch_mask)
        mlx_output, mlx_attn_weights, mlx_cache = self.mlx_model(
            mlx_input, mask=mlx_mask)
        print("PyTorch output shape:", pytorch_output.shape)
        print("MLX output shape:", mlx_output.shape)
        pytorch_output_np = pytorch_output.numpy()
        mlx_output_np = np.array(mlx_output)  # convert the MLX array to NumPy
        np.testing.assert_allclose(
            pytorch_output_np, mlx_output_np, rtol=1e-4, atol=1e-4)

if __name__ == '__main__':
    unittest.main()
This is supported now. I will do a release soon to update the PyPI distribution.
I tried using the deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct model and ran into this error: ValueError: Model type deepseek_v2 not supported.
Any plans to support deepseek_v2 soon?
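A quick way to check whether an installed package knows about a model type is sketched below. It assumes the loader resolves `model_type` to a module named `<package>.models.<model_type>` (for mlx-lm, `mlx_lm.models.deepseek_v2`). That mapping is an assumption about mlx-lm's internals, not something confirmed in this thread, and the helper names are hypothetical. If the check reports False, upgrading (`pip install -U mlx-lm`) once the release mentioned above is out would be the first thing to try.

```python
import importlib.util


def module_available(name: str) -> bool:
    # find_spec locates a module without executing it; for dotted names it
    # imports the parent package and raises ModuleNotFoundError if that is missing.
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False


def supports_model_type(model_type: str, package: str = "mlx_lm") -> bool:
    # Assumed layout: one module per architecture under <package>.models.
    return module_available(f"{package}.models.{model_type}")


print(supports_model_type("deepseek_v2"))
```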