chengchingwen opened 3 days ago
Adding `wait_completed` on `matmul!`'s command buffer does not help. Adding `Metal.@sync` to the `mul!` also does not help. ~~However, I cannot reproduce when calling `MPS.matmul!` directly.~~
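Concretely, the failed synchronization attempts looked roughly like this (a sketch, not the exact code from the report; the array setup is assumed, and the `wait_completed`-on-the-command-buffer variant is omitted since it requires hooking `matmul!` internals):

```julia
using Metal, LinearAlgebra

a = mtl(randn(Float32, 1000, 1000))
b = mtl(randn(Float32, 1000, 1000))
C = Metal.zeros(Float32, size(a))

# attempt 1: explicitly synchronize after the multiplication (does not help)
mul!(C, a, b)
Metal.synchronize()

# attempt 2: wrap the multiplication in Metal.@sync (also does not help)
Metal.@sync mul!(C, a, b)
```

Neither form prevents the NaNs, which is what makes the issue look like a missing dependency between command buffers rather than a host-side synchronization problem.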
I cannot reproduce at all on Metal.jl#master using an M3 Pro, but it does seem reproducible on an M1 Pro.
I wonder if this is a problem with `mapreduce`, since you're calling `isapprox` on GPU arrays. Can you test if calling `@assert Array(C) ≈ Array(c)` makes things pass? It does here, at least.
I can reproduce the issue on M1 master. It also looks like all the tasks run on the same queue.
The issue was found on an M2 Max. The MWE only fails if the array is large enough. It seems the subsequent kernel is launched before the matmul has finished. Is it possible that `mapreduce` is not checking the availability of the input arrays?
p.s. I'm about to board the plane to JuliaCon so I won't be able to test it soon.
> I wonder if this is a problem with `mapreduce`, since you're calling `isapprox` on GPU arrays. Can you test if calling `@assert Array(C) ≈ Array(c)` makes things pass? It does here, at least.
It also reproduces when comparing on the CPU, just much less likely, so this isn't a `mapreduce` issue.
Looks like a bunch of NaNs in the second matrix.
My current MWE is:
```julia
using Metal, LinearAlgebra

begin
    n = 10000
    a = mtl(randn(Float32, n, n))
    b = mtl(randn(Float32, n, n))
    C = Metal.zeros(Float32, size(a))
    for i in 1:10
        C = Metal.zeros(Float32, size(a))
        mul!(C, a, b)
        @assert !any(isnan.(C)) "$i"
    end
end
```
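Following the earlier suggestion, the NaN check can also be moved to the CPU to take GPU-side `mapreduce` out of the picture (a variant sketch of the loop above, not from the thread):

```julia
for i in 1:10
    C = Metal.zeros(Float32, size(a))
    mul!(C, a, b)
    # copy to host first, so the NaN check runs entirely on the CPU
    @assert !any(isnan, Array(C)) "NaN in iteration $i"
end
```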
I define `C` outside the loop so I can access it afterwards. When I had `C .= ...` in the loop instead of `C = ...`, it only ever happened at iteration 1. I suspect it has to do with the location in memory of the array.
> I cannot reproduce when calling `MPS.matmul!` directly
I can:
```julia
using Metal, LinearAlgebra

function main(T=Float32, N=10000)
    a = Metal.rand(T, N, N)
    b = Metal.rand(T, N, N)
    c = a * b'
    synchronize()
    for i in 1:100
        println("Iteration $i")
        d = Metal.zeros(T, size(a))
        MPS.matmul!(d, a, b, #=alpha=#true, #=beta=#false,
                    #=transpose_a=#false, #=transpose_b=#true)
        @assert !any(isnan.(Array(d))) "NaN in iteration $i"
        # XXX: this redundant check is needed, or the failure never occurs
        @assert !any(isnan.(d))
    end
end

isinteractive() || main()
```
The need for a secondary kernel is very weird.
It is not MPS related:

```julia
for i in 1:10
    C = Metal.zeros(Float32, size(a))
    GPUArrays.generic_matmatmul!(C, a, b, MulAddMul())
    @assert C ≈ c "$i"
end
```
> `GPUArrays.generic_matmatmul!(C, a, b, MulAddMul())`
I don't see how that's related; it's an entirely different kernel. Does it contain NaNs in similar places? The generic matmatmul kernel, while being extraordinarily slow, doesn't introduce NaNs here.
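To answer the "similar places" question, the NaN locations can be inspected directly (a small sketch, not from the thread; it assumes the `d` array from the `MPS.matmul!` MWE above):

```julia
# collect the indices of NaN entries on a host copy of the result
bad = findall(isnan, Array(d))
# print how many there are and where the first few sit,
# to compare locations across runs and across kernels
isempty(bad) || @show length(bad) first(bad, 10)
```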
Just wanted to confirm that it's MPS rather than the synchronisation between kernel launches.
I've been seeing the NaN issues with large arrays for a long time in #145
MLX seems fine:
```python
import mlx.core as mx

a = mx.random.normal((10000, 10000))
b = mx.random.normal((10000, 10000))
c = a @ b.T
for i in range(10):
    C = a @ b.T
    assert mx.allclose(C, c)
```