JuliaGPU / Metal.jl

Metal programming in Julia

M1/M2: Large matrix multiplications can contain NaNs #381

Open chengchingwen opened 3 days ago

chengchingwen commented 3 days ago

MWE:

julia> using Metal, LinearAlgebra

julia> a = Metal.randn(10000, 10000);

julia> b = Metal.randn(10000, 10000);

julia> c = a * b';

julia> for i in 1:10
           C = Metal.zeros(Float32, size(a))
           mul!(C, a, b')
           @assert C ≈ c "$i"
       end
ERROR: AssertionError: 1
Stacktrace:
 [1] top-level scope
   @ ./REPL[58]:4

julia> for i in 1:10
           C = Metal.zeros(Float32, size(a))
           mul!(C, a, b')
           @assert C ≈ c "$i"
       end
ERROR: AssertionError: 8
Stacktrace:
 [1] top-level scope
   @ ./REPL[58]:4

julia> for i in 1:10
           @assert a * b' ≈ c "$i"
       end
ERROR: AssertionError: 3
Stacktrace:
 [1] top-level scope
   @ ./REPL[59]:2

julia> for i in 1:10
           @assert a * b' ≈ c "$i"
       end
ERROR: AssertionError: 8
Stacktrace:
 [1] top-level scope
   @ ./REPL[59]:2
chengchingwen commented 3 days ago

Adding wait_completed on matmul!'s command buffer does not help.

christiangnrd commented 3 days ago

Adding Metal.@sync to the mul! also does not help. ~However, I cannot reproduce when calling MPS.matmul! directly.~
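Roughly, a minimal sketch of what I mean:

for i in 1:10
    C = Metal.zeros(Float32, size(a))
    # wait for the GPU work submitted by mul! to finish before comparing
    Metal.@sync mul!(C, a, b')
    @assert C ≈ c "$i"
end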

maleadt commented 2 days ago

I cannot reproduce at all on Metal.jl#master using an M3 Pro, but it does seem reproducible on an M1 Pro.

I wonder if this is a problem with mapreduce, since you're calling isapprox on GPU arrays. Can you test if calling @assert Array(C) ≈ Array(c) makes things pass? It does here, at least.
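That is, something like this sketch, which moves the comparison to the CPU and bypasses the GPU mapreduce kernel:

for i in 1:10
    C = Metal.zeros(Float32, size(a))
    mul!(C, a, b')
    # copy both results to host memory so isapprox runs on the CPU
    @assert Array(C) ≈ Array(c) "$i"
end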

tgymnich commented 2 days ago

I can reproduce the issue on M1 master. It also looks like all the tasks run on the same queue.

chengchingwen commented 2 days ago

The issue was found on an M2 Max. The MWE only fails if the arrays are large enough. It seems the subsequent kernel is launched before the matmul has finished. Is it possible that the mapreduce is not checking the availability of the input arrays?
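
One way to test that would be to force an explicit synchronization between the matmul and the check, e.g. this sketch (assuming synchronize() blocks until all previously submitted work has completed):

for i in 1:10
    C = Metal.zeros(Float32, size(a))
    mul!(C, a, b')
    synchronize()  # block until the matmul has actually finished
    @assert C ≈ c "$i"
end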

p.s. I'm about to board the plane to JuliaCon so I won't be able to test it soon.

maleadt commented 2 days ago

> I wonder if this is a problem with mapreduce, since you're calling isapprox on GPU arrays. Can you test if calling @assert Array(C) ≈ Array(c) makes things pass? It does here, at least.

It also reproduces when comparing on the CPU, just much less frequently, so this isn't a mapreduce issue.

maleadt commented 2 days ago

Looks like a bunch of NaNs in the second matrix.

christiangnrd commented 2 days ago

My current MWE is:

using Metal, LinearAlgebra; begin
    n = 10000
    a = mtl(randn(Float32,n,n))
    b = mtl(randn(Float32,n,n))
    C = Metal.zeros(Float32, size(a))
    for i in 1:10
        C = Metal.zeros(Float32, size(a))
        mul!(C,a,b)
        @assert !any(isnan.(C)) "$i"
    end
end

I define C outside the loop so I can access it afterwards. When I had C .= ... in the loop instead of C = ..., the assertion only ever failed at iteration 1. I suspect it has to do with the location in memory of the array.
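
For illustration, a buffer-reusing variant would look something like this sketch (not necessarily my exact code; the point is that the output array keeps the same memory location across iterations):

C = Metal.zeros(Float32, n, n)
for i in 1:10
    fill!(C, 0f0)  # reuse the same allocation instead of creating a fresh array
    mul!(C, a, b)
    @assert !any(isnan.(C)) "$i"
end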

maleadt commented 2 days ago

> I cannot reproduce when calling MPS.matmul! directly

I can:

using Metal, LinearAlgebra

function main(T=Float32, N=10000)
    a = Metal.rand(T, N, N)
    b = Metal.rand(T, N, N)
    c = a * b'
    synchronize()

    for i in 1:100
        println("Iteration $i")
        d = Metal.zeros(T, size(a))
        MPS.matmul!(d, a, b, #=alpha=#true, #=beta=#false,
                    #=transpose_a=#false, #=transpose_b=#true)
        @assert !any(isnan.(Array(d))) "NaN in iteration $i"

        # XXX: this redundant check is needed, or the failure never occurs
        @assert !any(isnan.(d))
    end
end

isinteractive() || main()

The need for a secondary kernel is very weird.

tgymnich commented 2 days ago

It is not MPS related:

using GPUArrays
using LinearAlgebra: MulAddMul  # non-exported helper passed to generic_matmatmul!

for i in 1:10
    C = Metal.zeros(Float32, size(a))
    GPUArrays.generic_matmatmul!(C, a, b, MulAddMul())
    @assert C ≈ c "$i"
end
maleadt commented 2 days ago

> GPUArrays.generic_matmatmul!(C, a, b, MulAddMul())

I don't see how that's related; it's an entirely different kernel. Does it contain NaNs in similar places? The generic matmatmul kernel, while being extraordinarily slow, doesn't introduce NaNs here.

tgymnich commented 2 days ago

Just wanted to confirm that it's MPS rather than the synchronisation between kernel launches.

tgymnich commented 2 days ago

I've been seeing the NaN issues with large arrays for a long time in #145

MLX seems fine:

import mlx.core as mx

a = mx.random.normal((10000, 10000))
b = mx.random.normal((10000, 10000))
c = a @ b.T

for i in range(0,10):
    C = a @ b.T
    assert(mx.allclose(C,c))