Short answer: The results you're getting are expected because the implementation isn't 100% deterministic.
Now, please forgive the following rant, but since this question comes up often I may as well explain it in detail so I have something to refer to next time.
What it comes down to is the use of atomic operations in the GEMV-oriented matmul kernels, the fact that CUDA's thread launch order is nondeterministic and the fact that floating-point addition is non-associative. Essentially (a+b)+c != a+(b+c) for floating point types, and if you compute a+b+c across threads in CUDA using atomic addition, you can't know if the result is going to be (a+b)+c or a+(b+c). The two results are often the same but they can sometimes differ by a tiny amount due to rounding of the intermediate result.
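To make that concrete, here's a minimal sketch (plain PyTorch, nothing ExLlama-specific) of FP16 addition giving different results depending purely on the order of operations:

```python
import torch

# The same three FP16 values, summed in two different orders.
a = torch.tensor(0.1, dtype=torch.float16)
b = torch.tensor(1000.0, dtype=torch.float16)
c = torch.tensor(-1000.0, dtype=torch.float16)

print((a + b) + c)  # tensor(0., dtype=torch.float16)     -- 0.1 is rounded away next to 1000
print(a + (b + c))  # tensor(0.1000, dtype=torch.float16) -- the large values cancel first
```

In a GEMV kernel that accumulates partial sums with atomic adds, which of those orders you actually get depends on which thread happens to commit its result first, and that's not something you control.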
Now, after accumulating millions and millions of those tiny errors, one forward pass might produce a distribution of [0.700, 0.104, 0.102...] while the same forward pass next time around could give you [0.702, 0.101, 0.100...]. Both would be correct to the precision of FP16, but they wouldn't be identical bit-for-bit.
Then, all you need is one multinomial sample that's right on the boundary between two tokens, say, sampling to a cutoff of 0.701 in that example, and your output sequences are now different and diverging. Or you could end up with distributions where the top two tokens are practically the same probability, and the random jitter flips their order, so even greedy sampling can give diverging results. This seems likely in your example since the divergence starts at the point of Artificial intelligence (AI) is followed by either a or the, both of which seem like they would be highly preferred tokens, quite possibly scoring roughly the same and becoming a "tipping point" in the autoregressive sequence.
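As an illustration (made-up numbers, assuming straightforward CDF-based sampling), the same random draw can land on different tokens when the distribution shifts by a fraction of a percent:

```python
import torch

# Two runs of the "same" forward pass, differing only by accumulated FP16 jitter.
probs_run1 = torch.tensor([0.700, 0.104, 0.102])
probs_run2 = torch.tensor([0.702, 0.101, 0.100])

u = 0.701  # the same uniform random draw in both runs

# Pick the first token whose cumulative probability reaches the draw.
token1 = int((probs_run1.cumsum(0) >= u).nonzero()[0])
token2 = int((probs_run2.cumsum(0) >= u).nonzero()[0])
print(token1, token2)  # 1 0 -- and the two sequences diverge from here on
```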
As for what to do about it, well, you can either do inference in FP32 while rounding to FP16 at deterministic points along the way, hoping that the FP32 error never accumulates to the point that it would flip a bit in the nearest FP16 value, or you can use deterministic (reduction-based) matmul approaches. Both of those hurt performance, so there's no immediate support for them in ExLlama.
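For reference, plain PyTorch exposes a switch along those lines for its own ops. This is not something ExLlama's custom kernels hook into; it just shows what the trade-off looks like elsewhere:

```python
import os
import torch

# cuBLAS needs a fixed workspace size for reproducible GEMMs; this must be set
# before the CUDA context is initialized.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Ask PyTorch to use deterministic kernels where they exist (and raise an error
# where they don't), trading speed for run-to-run reproducibility on the same
# hardware, toolkit and library versions.
torch.use_deterministic_algorithms(True)
```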
The third option is to question the need for determinism in the first place. Consider this:
If you're already using quantized models it's clearly not perfect precision you're after. The problem with nondeterminism can't be that it doesn't give you "the precisely correct result" because that went out the window anyway when you distilled the model down to the most important 25% of its bits.
If you get two slightly different outputs from the same input, there's no sense in which either of the two outputs is the correct one, unless you arbitrarily define an order of additions to be canonical and then write/select CUDA kernels to always follow that order. But then the resulting output is still only correct by an arbitrary definition which depends on specifics you're probably not concerned with, such as which kernel (or set of kernels) you end up launching by calling torch.matmul() in Python or cublasHgemm in C++. It can even change with your CUDA version or your hardware architecture. From the cuBLAS API docs:
By design, all cuBLAS API routines from a given toolkit version, generate the same bit-wise results at every run when executed on GPUs with the same architecture and the same number of SMs. However, bit-wise reproducibility is not guaranteed across toolkit versions because the implementation might differ due to some implementation changes.
What is it that you actually achieve by passing the same input twice instead of just duplicating the output? Consider if the implementation secretly did this to "fake" being deterministic, i.e., imagine it cached the result of each forward pass using a database of hashed inputs and whatever was the first output non-deterministically computed for each of them. Is there any conceivable way you could detect that the determinism was "faked" in this way, if you just had the outputs to go by?
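As a thought experiment only (none of this exists in ExLlama; the names are made up), such a fake could look like this:

```python
import hashlib

_cache = {}  # hypothetical: first-seen output per input

def fake_deterministic_forward(forward_pass, input_ids):
    # Key the cache on the exact input sequence.
    key = hashlib.sha256(repr(input_ids).encode()).hexdigest()
    if key not in _cache:
        # The underlying pass is nondeterministic, but only its first result is ever kept.
        _cache[key] = forward_pass(input_ids)
    return _cache[key]  # bit-identical on every repeat of the same input
```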
Personally, I would argue that good testing methodologies in this context need to be robust to noise anyway. Yes, bitwise-identical outputs are a useful proxy for functionally identical implementations, as a way to quickly verify that you didn't make any subtle mistakes by having a huge, chaotic-dynamic computation arrive at the exact same result as a reference implementation.
But it doesn't tell you anything more than that. As soon as you switch to a different GPU, or you split your tensors across multiple GPUs, or you update to a new version of PyTorch, you could get any other output from the set of possible outputs that are all correct to within the precision of FP16. If you're getting Artificial intelligence (AI) is the field of study... on your system, a user of whatever application you're building might still see Artificial intelligence (AI) is a field of computer science... with the exact same prompt, only because they're two weeks behind on nightly builds of PyTorch.
And even within a single, deterministic context, causal LM is still causal, which makes it chaotic and sensitive to minute changes in initial conditions: Try the same sequence with an extra space at the beginning and watch it take a completely different turn regardless of precision and determinism in the underlying framework.
All in all, while it would be nice to add determinism as an option, it would either degrade performance just for the sake of test cases like yours, or it would be a switchable option that doesn't actually prove anything about the correctness of the other, non-deterministic code path.
Again, sorry for the rant. But since your observation is correct and your concern is entirely valid I feel it deserves a full explanation. I'd happily entertain discussion if anyone should disagree.
Originally posted by @turboderp in https://github.com/turboderp/exllamav2/issues/232#issuecomment-1860896496