Closed: jeromeku closed this issue 3 months ago.
I observed a similar difference in the logits. There might be some slight precision issue in the kernels.
@casper-hansen
Running the above script gives a max abs error of 0.1848.
Also interested in running scattermoe with pretrained models per your PR #2.
I have a working implementation of converting pretrained models to scattermoe, but the precision is a major concern (or it may just be a misstep in my implementation).
Working on the pretrained models right now.
I am observing the same thing; the likely cause is the float32 router weights in the original implementation vs. the casting to bfloat16 I did in my modification.
@shawntan thanks for the insight. This could potentially be verified quickly, right? Define and load the router in fp32. Also, it seems there's no need to cast back to bf16, since the router only decides which experts are activated; the experts' actual computations stay in bf16 anyhow. A conclusion on the discrepancy would be really beneficial.
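A minimal sketch of that check, just to make the idea concrete (the class and attribute names below are illustrative, not scattermoe's actual API): keep the gate weights in fp32, cast the hidden states up only for routing, and hand the top-k indices and weights to the bf16 experts.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FP32Router(nn.Module):
    """Illustrative top-k gate kept in float32 while the rest of the model is bf16."""
    def __init__(self, hidden_dim: int, num_experts: int, top_k: int):
        super().__init__()
        # Router weights stay in float32 even if the surrounding model is bf16.
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False).float()
        self.top_k = top_k

    def forward(self, hidden_states: torch.Tensor):
        # Cast activations up only for routing; the experts still run in bf16.
        logits = self.gate(hidden_states.float())
        weights = F.softmax(logits, dim=-1)
        routing_weights, expert_idx = torch.topk(weights, self.top_k, dim=-1)
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        # expert_idx needs no cast; the weights can be cast back down at the end.
        return routing_weights.to(hidden_states.dtype), expert_idx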
I checked, the discrepancy is about 0.01 in float32. I'm still studying the problem.
@shawntan
Can you point me to the script reproducing that result? Happy to help debug.
import torch
import transformers
import gc

from mixtral.modeling_mixtral import MixtralModel, MixtralForCausalLM, MixtralSparseMoeBlock
from mixtral.configuration_mixtral import MixtralConfig

MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"

if __name__ == "__main__":
    device = torch.device('cuda:0')
    dtype = torch.float32

    config = transformers.MixtralConfig.from_pretrained(MODEL_NAME)
    num_experts = config.num_local_experts

    x = torch.randn(8, 2048, 4096).to(dtype=dtype, device=device)

    # scattermoe-based block vs. the reference HF block.
    sm_mlp = MixtralSparseMoeBlock(config).to(dtype=dtype, device=device)
    hf_mlp = transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock(config).to(dtype=dtype, device=device)

    # Copy the HF weights into the scattermoe block, stacking each expert's
    # w2 and concatenated w3/w1 matrices into the fused ParallelExperts parameters.
    hf_state_dict = hf_mlp.state_dict()
    for n, p in sm_mlp.named_parameters():
        if n in hf_state_dict:
            p.data[:] = hf_state_dict.pop(n)
        else:
            prefix, suffix = n.split('moe_mlp')
            for i in range(num_experts):
                if suffix == ".output_experts.weight":
                    w2_param_name = prefix + "experts.%d.w2.weight" % i
                    p.data[i, :, :] = hf_state_dict.pop(w2_param_name)
                else:
                    w1_param_name = prefix + "experts.%d.w1.weight" % i
                    w3_param_name = prefix + "experts.%d.w3.weight" % i
                    out_dim, in_dim = hf_state_dict[w1_param_name].size()
                    p.data[i, :out_dim, :] = hf_state_dict.pop(w3_param_name)
                    p.data[i, out_dim:, :] = hf_state_dict.pop(w1_param_name)

    hf_out = hf_mlp(x)[0]
    sm_out = sm_mlp(x)[0]
    diff = torch.abs(hf_out - sm_out)
    print("max diff:", diff.max())
I think the issue was the way I did the gating inside GLUMLP; your w3 and w1 concatenation order should be swapped. Sorry about that.
Sorry for the erratic behaviour; I was trying to fix some really weird behaviour when running on lm-eval, and I still haven't figured it out. At first I thought the NaN issues I was getting were due to the precision problems in the kernel, so I was trying all manner of fixes to keep the numerical errors as low as possible. It turns out turning off tf32 brings the error close to 0, but leaving it on isn't too bad either; it is a known effect of tf32 after all.
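For reference, tf32 can be turned off globally with the standard PyTorch flags (nothing scattermoe-specific):

# Run float32 matmuls and convolutions at full precision instead of TF32.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False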
The issue seems to be weirder than that.
Asking here in case anyone knows: when running parallelized, everything happening on cuda:0 goes well, but on handing over to cuda:1 it's like the kernel never executes and the empty buffer I initialised gets passed back to me. If I switch it to torch.zeros instead of torch.empty, I get a zero buffer back.
@shawntan I have a bit of experience with things not going as expected when moving from device 0 to 1. Usually it's due to a lack of device guarding or not passing everything to the right device. However, these issues are incredibly hard to actually pin down, and adding print statements usually results in incorrectly reported values. That's my 2 cents on this issue.
Can you try using a with block in some of your modeling code to make sure?
with torch.cuda.device(hidden_states.device):
For reference, here is how AutoGPTQ calls its Triton kernels: https://github.com/AutoGPTQ/AutoGPTQ/blob/74212f501396827772bf9d915b079bd6a419bd92/auto_gptq/nn_modules/triton_utils/kernels.py#L347-L374
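A minimal sketch of that pattern with a toy Triton kernel (the copy kernel is purely illustrative, not scattermoe's; the point is only the torch.cuda.device guard): the output allocation and the kernel launch both happen under the input tensor's device rather than the current default device.

import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)

def device_guarded_copy(x: torch.Tensor) -> torch.Tensor:
    # Guard the allocation and the launch so they happen on x's device,
    # not on whatever device happens to be current (often cuda:0).
    with torch.cuda.device(x.device):
        y = torch.empty_like(x)
        n = x.numel()
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
        _copy_kernel[grid](x, y, n, BLOCK_SIZE=1024)
    return y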
It works! Thanks @casper-hansen , great tip. I'll keep that in mind in future.
@shawntan

Issue

Testing scattermoe.mlp.GLUMLP against the reference HF MixtralSparseMLP seems to give significant discrepancies in output.

Repro

1. Took MixtralSparseMoeBlock from transformers.models.mixtral.modeling_mixtral and instantiated it with a basic config (AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")).
2. Created a GLUMLP by stacking the weights of each expert into the expert and output_expert ParallelExperts weights.
3. Computed routing_weights and expert_idx using the reference HF implementation.
4. Passed hidden_states, routing_weights, and expert_idx to the experts layer of the reference MixtralSparseMoeBlock and to GLUMLP.

For batch size=1, sequence length=5, and hidden dim=4096, this gives errors on the order of 0.1 - 0.5 for float16 / float32.

Env:
torch: 2.2.2
triton: 2.2.0
A6000

Full script: