Closed Blaizzy closed 3 months ago
Usually the best way to debug these sorts of issues is to have a reference implementation (like the HF one) and the MLX implementation side-by-side.
Then write some wrapper to step through both implementations layer by layer and compare the activations. Keep zooming that in until you find the first place that the activations are substantially different. That should pin-point the part of the model file that is mismatched.
If you get to a point where the ops seem identical but produce different results given the same inputs, that suggests a deeper issue in MLX. But in most cases the problem is a mismatch in config settings or operations in the high-level implementation.
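A rough sketch of the layer-by-layer comparison described above, assuming the per-layer activations from both implementations have already been captured as NumPy arrays. The function name, the `(name, array)` list format, and the tolerance are illustrative assumptions, not part of MLX or transformers.

```python
import numpy as np

def first_divergent_layer(ref_acts, test_acts, rtol=1e-3):
    """Return (name, rel_err) for the first layer that differs by more than rtol, else None.

    ref_acts / test_acts: lists of (layer_name, ndarray) in forward order.
    """
    for (name, ref), (_, test) in zip(ref_acts, test_acts):
        ref = np.asarray(ref, dtype=np.float64)
        test = np.asarray(test, dtype=np.float64)
        # Norm-based relative error, so the value does not grow with tensor size.
        rel_err = np.linalg.norm(ref - test) / (np.linalg.norm(ref) + 1e-12)
        if rel_err > rtol:
            return name, float(rel_err)
    return None

# Toy usage: layer "b" is deliberately perturbed.
rng = np.random.default_rng(0)
a = rng.normal(size=(4, 8))
b = rng.normal(size=(4, 8))
ref = [("a", a), ("b", b)]
test = [("a", a + 1e-7), ("b", b * 1.5)]
result = first_divergent_layer(ref, test)
print(result)
```

Once the first divergent layer is found, the same check can be repeated on that layer's sub-modules to zoom in further.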
Thanks!
I will do that and let you know.
For box/seg: are you sure you use our extended tokenizer with total vocab of 256000 + 1024 + 128?
For the « sorry… » it’s from safety tuning we had to do on the mix model. I don’t like it, and we tried hard to not over-trigger, but it’s not perfect. I think if you have some discrepancies in the model, this answer might appear more often than it should. I recommend doing what @awni said. Ideally comparing with our reference implementation in jax, but i believe the HF implementation was verified against it, so should be good too.
about 4bit, i have no idea.
Thanks a lot @lucasb-eyer!
How do I get the extended tokenizer? I don't see anything standing out in the huggingface implementation.
After investigating the model as @awni suggested, I found a big difference in the vision model encoder activations between the two models: it was over 100K. This was caused by using nn.GELU(approx="fast") on MLX. When I changed to approx="precise" or implemented the transformers FastGELUActivation, the difference came down to ~3, with precise having slightly lower scores overall. This is strange because the original implementation uses approx="fast".
However, despite the lower activation difference between the two it still refuses to answer most natural language questions.
when I changed to approx="precise" or implemented the transformers FastGELUActivation the difference came down to ~3
What's that number mean? ~3 would be a large value for the max-abs-diff between two activations. Or is it like a sum over all the diffs?
Note there are three GELU options in MLX:
none: the full GELU using ERF
precise: a slightly faster approximation
fast: also slightly faster, less good approximation
none and precise correspond to the options in PyTorch. I'm not entirely sure what is meant by approx="fast" in your case but it might not match MLX's fast, so that is good to double check.
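To make the difference between the variants concrete, here is a small NumPy sketch comparing the exact ERF-based GELU with a tanh approximation and a sigmoid approximation. Mapping these to MLX's none, precise and fast options is an assumption to verify against the MLX source, and the sigmoid constant 1.702 in particular is assumed rather than taken from MLX.

```python
import math
import numpy as np

def gelu_erf(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # Tanh approximation (same form as the FastGELUActivation discussed in this thread).
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

def gelu_sigmoid(x, c=1.702):
    # Sigmoid approximation x * sigmoid(c * x); the constant c is an assumption.
    return x / (1.0 + np.exp(-c * x))

x = np.linspace(-6.0, 6.0, 2001)
tanh_err = np.abs(gelu_tanh(x) - gelu_erf(x)).max()
sigmoid_err = np.abs(gelu_sigmoid(x) - gelu_erf(x)).max()
print(f"max abs error vs exact: tanh={tanh_err:.2e}, sigmoid={sigmoid_err:.2e}")
```

The sigmoid form deviates noticeably more from the exact GELU than the tanh form, so mixing the two up between implementations can show up as an activation mismatch.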
It's the sum over all the diffs.
np.abs(ref_model_layer - target_model_layer).sum()
Yes, I did. The original implementation of SigLIP (PaliGemma's vision model) uses fast, and this setup works fine on a nearly identical model we currently support called NanoLlava, but for PaliGemma the fast approximation creates a big divergence in the activations.
By approx='fast', I mean the MLX configuration:
nn.GELU(approx='fast')
I even copied the transformers GELU activation into NumPy to compare, but I get results similar to the precise approximation in MLX.
I'm not quite following the gelu story. But I think the safest call is to find the GELU implementation of the reference implementation (presumably the Jax code) and use that. You can check the MLX GELU implementations here to see if one matches.
I did exactly that.
Here are the implementations I tried, all of which are identical to the ones used in transformers and JAX:
class FastGELUActivation(nn.Module):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
    """

    def __call__(self, input: mx.array) -> mx.array:
        return 0.5 * input * (1.0 + mx.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))

class FastGELUActivation(nn.Module):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
    """

    def __call__(self, input: mx.array) -> mx.array:
        return 0.5 * input * (1.0 + mx.tanh(np.sqrt(2 / np.pi) * (input + 0.044715 * (input ** 3))))
Transformers: https://github.com/huggingface/transformers/blob/3d7d3a87a0bf4d0bb9346beb9419b1d76b5b988f/src/transformers/activations.py#L81 JAX: https://github.com/google/jax/blob/e93f36aa7c5cf329b517cd652777eb14ca35e8c0/jax/_src/nn/functions.py#L424
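As a sanity check, the two formulas above are algebraically the same, since x * c * (1 + 0.044715 x^2) = c * (x + 0.044715 x^3) with c = sqrt(2/pi) ≈ 0.7978845608. A minimal NumPy sketch (used here in place of MLX purely for illustration) confirms they agree up to float rounding:

```python
import numpy as np

x = np.linspace(-6.0, 6.0, 1001, dtype=np.float32)

# First form: constant written out, cubic term factored as x * (1 + 0.044715 * x^2).
y1 = 0.5 * x * (1.0 + np.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
# Second form: sqrt(2/pi) * (x + 0.044715 * x^3).
y2 = 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * (x ** 3))))

max_abs_diff = float(np.abs(y1 - y2).max())
print("max abs diff between the two forms:", max_abs_diff)
```

So any remaining mismatch against the reference is unlikely to come from the choice between these two ways of writing the tanh approximation.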
Yet the sum of abs-diff is still around 2.39 to 3.77 on the vision path, and the model still refuses a lot.
From the start until the first MLP everything is close to 0. The only part where they start to differ significantly is the MLP, with around 0.15 on the first vision encoder layer.
A layer before that is 0.08.
Here is my implementation of the MLP:
class MLP(nn.Module):
    def __init__(self, config: VisionConfig):
        super().__init__()
        self.activation_fn = FastGELUActivation()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = self.fc2(x)
        return x
Here is the transformers implementation:
class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]  # uses same FastGELUActivation
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
I'm not sure what I am missing here.
Let me go for a walk 🚶🏾♂️...
Any luck getting to the bottom of this?
FWIW it's expected there are some numerical differences in the MLX / PyTorch versions. Rather than looking at the sum (which is hard to reason about since it depends on the number of weights), maybe check something like a relative difference. Like ((x - y).abs() / x.abs()).max() or some variation of that. It should be pretty small, especially for float32.
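A minimal NumPy sketch of that kind of relative-difference check. The helper name and the small `eps` guard against division by zero are assumptions added for illustration:

```python
import numpy as np

def max_relative_diff(x, y, eps=1e-6):
    """Max elementwise |x - y| / |x|; eps is an assumed guard against division by zero."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    return float((np.abs(x - y) / (np.abs(x) + eps)).max())

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 1152)).astype(np.float32)
y = x * np.float32(1.0 + 1e-6)  # simulate a tiny relative perturbation

sum_abs = float(np.abs(x - y).sum())
max_rel = max_relative_diff(x, y)
print("sum of abs diffs:", sum_abs)
print("max relative diff:", max_rel)
```

The summed difference grows with the number of elements, while the relative value stays on the order of the perturbation, which is why it is easier to reason about.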
Not yet,
Yesterday, I tried using the huggingface VLM class in my implementation but that didn't change the results.
Let me check the relative distance and let you know.
@awni here are the results:
Language Model (Embedding output)
Relative Distance (using norms): 0.0
Max Absolute Relative Difference: 0.0
Are Matrices Close (np.allclose): True
Vision Model (Patch_embedding output):
Relative Distance (using norms): 5.3614764e-07
Max Absolute Relative Difference: 0.0685524
Are Matrices Close (np.allclose): False
Vision Model (Embeddings Layer output):
Relative Distance (using norms): 2.7392096e-07
Max Absolute Relative Difference: 0.05940594
Are Matrices Close (np.allclose): False
Vision Model (Encoder Layer 1 output):
Layer 1
Relative Distance (using norms): 1.8757571e-06
Max Absolute Relative Difference: 6.3747888
Are Matrices Close (np.allclose): False
Layer 2
Relative Distance (using norms): 2.0738366e-06
Max Absolute Relative Difference: 3.7936087
Are Matrices Close (np.allclose): False
Layer 3
Relative Distance (using norms): 1.9333681e-06
Max Absolute Relative Difference: 3.062366
Are Matrices Close (np.allclose): False
Vision model (Post layerNorm output):
Relative Distance (using norms): 2.5038335e-05
Max Absolute Relative Difference: 1.6712433
Are Matrices Close (np.allclose): False
Multi-modal projector (Linear layer output):
Relative Distance (using norms): 1.9119347e-05
Max Absolute Relative Difference: 3.1223333
Are Matrices Close (np.allclose): False
For context here is the model architecture:
PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(256, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=1152, out_features=4304, bias=True)
              (fc2): Linear(in_features=4304, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
    )
  )
  (multi_modal_projector): PaliGemmaMultiModalProjector(
    (linear): Linear(in_features=1152, out_features=2048, bias=True)
  )
  (language_model): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(257216, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x GemmaDecoderLayer(
          (self_attn): GemmaSdpaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=256, bias=False)
            (v_proj): Linear(in_features=2048, out_features=256, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (rotary_emb): GemmaRotaryEmbedding()
          )
          (mlp): GemmaMLP(
            (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): GemmaRMSNorm()
          (post_attention_layernorm): GemmaRMSNorm()
        )
      )
      (norm): GemmaRMSNorm()
    )
    (lm_head): Linear(in_features=2048, out_features=257216, bias=False)
  )
)
What are the formulas for these?
Relative Distance (using norms): 2.5038335e-05
Max Absolute Relative Difference: 1.6712433
Yesterday, I tried using the huggingface VLM class in my implementation but that didn't change the results.
Does that not suggest the problem is outside the model itself?
What are the formulas for these?
def relative_diff(x1, x2):
    assert x1.shape == x2.shape, "Matrices must have the same dimensions"
    if x1.ndim > 2 or x2.ndim > 2:
        x1 = x1.reshape(-1)
        x2 = x2.reshape(-1)
    print("Relative Distance (using norms):", (np.linalg.norm(x1 - x2) / np.linalg.norm(x1)).max())
    print("Max Absolute Relative Difference:", (abs(x1 - x2) / abs(x1)).max())
    print("Are Matrices Close (np.allclose):", np.allclose(x1, x2))
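One possible reading of the numbers above, not confirmed anywhere in this thread: the elementwise ratio abs(x1 - x2) / abs(x1) can become large when a single reference value is close to zero, even if the norm-based distance is tiny. A toy NumPy illustration with made-up values:

```python
import numpy as np

x1 = np.ones(1000)
x1[0] = 1e-6   # one reference value very close to zero
x2 = x1.copy()
x2[0] += 1e-5  # tiny absolute error on that single element

# Same two formulas as in relative_diff above.
norm_rel = float(np.linalg.norm(x1 - x2) / np.linalg.norm(x1))
max_rel = float((np.abs(x1 - x2) / np.abs(x1)).max())
print("relative distance (norms):", norm_rel)
print("max elementwise relative difference:", max_rel)
```

Here the norm-based distance stays far below 1e-6 while the max elementwise ratio is about 10, so a large value of the latter on its own does not necessarily indicate a real mismatch.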
Does that not suggest the problem is outside the model itself?
Yes, but where exactly? Because the language model is Gemma-2B and I double and triple checked it before and it works fine.
Language Model (only)
python -m mlx_vlm.generate --model google/paligemma-3b-pt-224 \
--prompt "Hi"
Prompt: Hi
I'm a very good friend of yours
==========
Prompt: 6.013 tokens-per-sec
Generation: 26.163 tokens-per-sec
Ok, after some deeper debugging.
I think the issue is in the multimodal feature merging and/or masking.
I'll update you once I have it working.
@awni @lucasb-eyer
I did everything by the book but the model still doesn't behave properly.
It seems like it behaves better only when using multimodal features from the transformers model. But that doesn't make sense because I have a 1:1 copy of that in MLX.
Could you please give this a look: https://github.com/Blaizzy/mlx-vlm/pull/24
@awni this weird behaviour also happened with Idefics2 in the past.
The only thing these have in common is that they are using F32 precision.
@awni any thoughts?
I couldn't say what the issue is.. I'll try to take a deeper look in the next few days.
Thanks! Looking forward to it :)
I just started digging in, but I think the problem may actually be in the mlx_vlm's LanguageModel implementation for PaliGemma. To demonstrate this, I replaced the mlx_vlm PaliGemma's LanguageModel with the corresponding implementation from Hugging Face Transformers in the code below.
from huggingface_hub import login
import os
login(token=os.getenv('HF_TOKEN'))
model_id = "google/paligemma-3b-mix-224"
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
prompt = 'Caption: '
import mlx.core as mx
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import glob
from huggingface_hub import snapshot_download
import json
from PIL import Image
import requests
import torch
import importlib
import numpy as np
def sanitize(weights):
    sanitized_weights = {}
    for k, v in weights.items():
        if "patch_embedding.weight" in k:
            sanitized_weights[k] = v.transpose(0, 2, 3, 1)
        else:
            sanitized_weights[k] = v
    return sanitized_weights
def load_model(model_id):
    model_path = snapshot_download(
        repo_id=model_id,
        revision=None,
        allow_patterns=[
            "*.json",
            "*.safetensors",
            "*.py",
            "tokenizer.model",
            "*.tiktoken",
            "*.txt",
        ],
    )
    with open(f"{model_path}/config.json", "r") as f:
        config = json.load(f)
    weights = {}
    weight_files = glob.glob(str(f"{model_path}/*.safetensors"))
    for wf in weight_files:
        weights.update(mx.load(wf))
    weights = sanitize(weights)
    model_class = importlib.import_module(f"mlx_vlm.models.paligemma")
    model_config = model_class.ModelConfig.from_dict(config)
    model_config.vision_config = model_class.VisionConfig.from_dict(config["vision_config"])
    model_config.text_config = model_class.TextConfig.from_dict(config["text_config"])
    model = model_class.Model(model_config)
    model.load_weights(list(weights.items()))
    mx.eval(model.parameters())
    model.eval()
    return model
model_mx = load_model(model_id)
processor = AutoProcessor.from_pretrained(model_id)
prompt_tokens = mx.array(processor.tokenizer.encode(prompt))
inputs = processor(prompt, Image.open(requests.get(img_url, stream=True).raw), return_tensors="np")
pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"])
mask = mx.array(inputs["attention_mask"])
inputs_embeds = model_mx.language_model.model.embed_tokens(input_ids)
hidden_state, _, _ = model_mx.vision_tower(
    pixel_values.transpose(0, 2, 3, 1).astype(inputs_embeds.dtype),
    output_hidden_states=True,
)
image_features = hidden_state[None, :].astype(pixel_values.dtype)
image_features = model_mx.multi_modal_projector(image_features)
input_embeddings, final_attention_mask_4d = (
    model_mx._prepare_inputs_for_multimodal(
        image_features, inputs_embeds, input_ids, mask
    )
)
# # `<<< mx language
# logits, cache = model_mx.language_model(
# inputs=input_ids,
# cache=None,
# inputs_embeds=input_embeddings,
# mask=final_attention_mask_4d,
# )
# # `>>> mx language
# `<<< hf language
model_hf = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
final_attention_mask_4d = torch.from_numpy(np.array(final_attention_mask_4d, dtype=np.float32))
input_embeddings = torch.from_numpy(np.array(input_embeddings))
outputs = model_hf.language_model(
    attention_mask=final_attention_mask_4d,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=input_embeddings,
    use_cache=False,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=True,
    cache_position=None,
)
logits = outputs.logits
logits = mx.array(logits.detach().numpy())
# `>>> hf language
logits = logits[:, -1, :]
token = mx.argmax(logits, axis=-1)
print(token, processor.tokenizer.decode(token.tolist()))
The modification immediately improves the output text quality. The output from the commented-out part of the code above (mlx_vlm's version of LanguageModel) is array([12156], dtype=uint32) Sorry (the full output from mlx_vlm is "Sorry, as a base VLM I am not trained to answer this question."), while the output from the Hugging Face one is array([886], dtype=uint32) In (the full output from Hugging Face's PaliGemma is "In this image we can see a car on the road. In the background there is a wall, door, trees and sky.").
Thanks @JosefAlbers!
Found the bug and fixed it :)
@awni @lucasb-eyer it's fixed ✅
After my changes, I hadn't updated the Gemma embedding scaling to apply to all inputs (text and multimodal). It was only scaling the text embeddings.
That's why when I unit tested the language model it worked but failed with multimodal.
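A toy NumPy sketch of the kind of fix described above, assuming Gemma's convention of multiplying input embeddings by sqrt(hidden_size). The function and variable names are hypothetical and this is not the actual mlx-vlm code:

```python
import numpy as np

hidden_size = 8
rng = np.random.default_rng(0)
embed_table = rng.normal(size=(16, hidden_size))  # toy embedding matrix

def embed_inputs(input_ids, inputs_embeds=None):
    """Toy language-model input stage (hypothetical names)."""
    if inputs_embeds is None:
        h = embed_table[input_ids]  # text-only path
    else:
        h = inputs_embeds  # merged text + image embeddings supplied by the caller
    # Apply the scaling on both paths, not only on the text-only branch.
    return h * np.sqrt(hidden_size)

ids = np.array([1, 2, 3])
text_only = embed_inputs(ids)
merged = embed_inputs(ids, inputs_embeds=embed_table[ids])
print(np.allclose(text_only, merged))
```

With the scaling applied on only one branch, a text-only unit test would pass while the multimodal path received differently scaled embeddings, which matches the symptom described.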
@JosefAlbers could you share your X handle ?
I want to tag you on the release :)
Great, I'm just glad I could help!
@awni I have PaliGemma working on MLX. In most cases it works great.
But there are 4 issues that I don't see in the transformers implementation:
The last two work fine on transformers.
The architecture is correct and the model does work, but it seems like MLX is either changing the precision or doing something else that makes the model not behave 100% normally.
Resources: