FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Usage of OneHotMatrix for input to neural network is very slow. #1355

Closed: racinmat closed this issue 3 years ago

racinmat commented 4 years ago

Multiplying a dense layer's weight matrix by a OneHotMatrix could be optimized, e.g. by:

using LinearAlgebra
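function Base.:*(A::AbstractMatrix, B::Flux.OneHotMatrix)
    # A * B: column j of the result is a copy of column ix of A,
    # where ix is the hot index of column j of B
    m = size(A,1)
    Y = similar(A, m, size(B,2))
    for (j,ohv) in enumerate(B.data)    # B.data: one OneHotVector per column
        ix = ohv.ix                      # row index of the hot entry
        for i in 1:m
            @inbounds Y[i,j] = A[i,ix]
        end
    end
    Y
end
function Base.:*(A::AbstractMatrix, B::Adjoint{Bool,<: Flux.OneHotMatrix})
    # A * B': column j of A is added into column ix of the result,
    # where ix is the hot index of column j of the underlying one-hot matrix
    m = size(A,1)
    Y = similar(A, m, size(B,2))
    Y .= 0                               # accumulator: several j can share one ix
    BT = B'                              # unwrap the adjoint
    for (j,ohv) in enumerate(BT.data)
        ix = ohv.ix
        for i in 1:m
            @inbounds Y[i,ix] += A[i,j]
        end
    end
    Y
end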

Should I make a PR for this?

DhairyaLGandhi commented 4 years ago

Yes, that would be awesome. Somewhat like Zeros, we can have a method that dispatches to fetch the correct columns for the hot neuron.

racinmat commented 4 years ago

Could you elaborate on what exactly you have in mind?

CarloLucibello commented 3 years ago

@racinmat did you submit a PR for this?

racinmat commented 3 years ago

Not yet, but I plan to do it.

ToucheSir commented 3 years ago

I think this is resolved by https://github.com/FluxML/Flux.jl/pull/1448 and https://github.com/FluxML/Flux.jl/pull/1424?

racinmat commented 3 years ago

In the end, it looks like it has not been resolved, and I'll make a new PR for it. Here is a benchmark: https://github.com/racinmat/flux_benchmarks/blob/master/results_slurm_1.md

DhairyaLGandhi commented 3 years ago

What do the "tricks" imply there?

racinmat commented 3 years ago

That's the faster implementation (https://github.com/racinmat/flux_benchmarks/blob/master/0_12_7_tricks/main.jl#L6-L28), although I guess I should modify it for the PR so that it dispatches only on a two-dimensional OneHotArray, right?

ToucheSir commented 3 years ago

The problem with the linked implementations is that they are decidedly not GPU-friendly. Given that the current code path is literally a property lookup and vectorized index, I think it would be fruitful to profile that first and see where the bottlenecks are.

racinmat commented 3 years ago

That's true, we would need to dispatch to a different implementation for CuArrays, right? It would definitely be fruitful to see why the current implementations are that slow.

ToucheSir commented 3 years ago

Well, that's the thing. The current implementation should be more than fast enough, since it theoretically does less work than the custom, CPU-only, dense-array version with a loop. That it isn't warrants an investigation. Ideally, we'd like to avoid writing custom kernels for something so trivial.

DhairyaLGandhi commented 3 years ago

Right, maybe we can find out where the bottleneck is in the current implementation instead.

racinmat commented 3 years ago

I might have found a culprit: https://github.com/FluxML/Flux.jl/blob/master/src/onehot.jl#L229 dispatches on OneHotVector, not OneHotMatrix, so multiplication by a OneHotMatrix falls back to https://github.com/JuliaLang/julia/blob/v1.6.3/stdlib/LinearAlgebra/src/matmul.jl#L151-L154

CarloLucibello commented 3 years ago

I think multiplication by a one-hot matrix can be expressed as a gather operation, for which we have CUDA kernels.
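A minimal, stdlib-only sketch of the gather view, assuming the one-hot matrix is described by its vector of hot row indices (as in `OneHotMatrix([2, 4, 1, 3], 5)` further down this thread). The helper `onehot_dense` is hypothetical and only builds a dense `Bool` reference: multiplying by the one-hot matrix just selects columns of `A`, which is what `NNlib.gather(A, idx)` computes for a matrix `A` and an index vector `idx`.

```julia
using LinearAlgebra

# Hypothetical helper: dense Bool reference for a one-hot matrix whose
# column j has its single `true` in row idx[j].
function onehot_dense(idx::AbstractVector{<:Integer}, nlabels::Integer)
    B = falses(nlabels, length(idx))
    for (j, k) in enumerate(idx)
        B[k, j] = true
    end
    return B
end

A   = Float32[1 2 3 4 5;
              6 7 8 9 10]          # 2×5, one column per label
idx = [2, 4, 1, 3]                 # hot row of each of the 4 columns
B   = onehot_dense(idx, 5)         # 5×4 one-hot matrix

# Gather: column j of A*B is column idx[j] of A, so no multiplications are needed.
Y_gather = A[:, idx]
Y_dense  = A * B                   # generic dense matmul, for comparison

@assert Y_gather == Y_dense
```

The indexing form does no arithmetic at all, which is why it should beat a generic matmul fallback on any backend that supports vectorized indexing.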

ToucheSir commented 3 years ago

> I might have a culprit, https://github.com/FluxML/Flux.jl/blob/master/src/onehot.jl#L229 dispatches on OneHotVector, not OneHotMatrix, and multiplication by OneHotMatrix defaults to https://github.com/JuliaLang/julia/blob/v1.6.3/stdlib/LinearAlgebra/src/matmul.jl#L151-L154

The OneHotMatrix path uses https://github.com/FluxML/Flux.jl/blob/master/src/onehot.jl#L222, not the base fallback.

Edit: this (and I) were wrong; see the post below.

ToucheSir commented 3 years ago

I just noticed the benchmark script doesn't use interpolation for the global variables. Here are the results on my local machine:

julia> @btime $x*$y;
  6.848 μs (3 allocations: 40.02 KiB)

julia> @btime fast_mul($x, $y);
  3.010 μs (2 allocations: 39.14 KiB)

julia> @btime $x*$y';
  738.998 μs (8 allocations: 39.47 KiB)

julia> @btime fast_mul($x, $y');
  4.405 μs (2 allocations: 39.14 KiB)

Here, fast_mul is just the mul implementation linked above, written as a standalone function.

Edit: mea culpa, here is the timing for x*y' after treating it like a wrapper type in https://github.com/FluxML/Flux.jl/blob/master/src/onehot.jl#L31-L33:

julia> @btime $x*$y';
  30.439 μs (19 allocations: 42.25 KiB)

DhairyaLGandhi commented 3 years ago

Seems fine to me. Maybe we need to make the wrapper types transparent, but that's a Julia compiler thing. We should, however, make sure not to hit the generic mul.

DhairyaLGandhi commented 3 years ago

Best not to hack gather into mul.

racinmat commented 3 years ago

So how do we want to solve it? Should I play with the OneHotVector multiplication and extend it to multiplication by a OneHotMatrix? By the way, I found out that currently we can't multiply an adjoint vector by a one-hot matrix:

julia> v = [1, 2, 3, 4, 5]
5-element Vector{Int64}:
 1
 2
 3
 4
 5

julia> b2 = Flux.OneHotMatrix([2, 4, 1, 3], 5)
5×4 OneHotMatrix(::Vector{Int64}) with eltype Bool:
 ⋅  ⋅  1  ⋅
 1  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  1
 ⋅  1  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅

julia> v' * b2
ERROR: LoadError: MethodError: *(::Adjoint{Int64, Vector{Int64}}, ::OneHotArray{Int64, 5, 1, 2, Vector{Int64}}) is ambiguous. Candidates:
  *(A::AbstractMatrix{T} where T, B::Union{OneHotArray{var"#s157", L, N, var"N+1", I}, Base.ReshapedArray{Bool, var"N+1", var"#s1571", MI} where {var"#s1571"<:(OneHotArray{var"#s157", L, var"#s156", var"#s155", I} where {var"#s156", var"#s155"}), MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}} where {var"#s157", N, var"N+1", I}) where L in Flux at E:\Projects\others_code\Flux.jl\src\onehot.jl:223
  *(x::Adjoint{T, var"#s832"} where {T, var"#s832"<:(AbstractVector{T} where T)}, A::AbstractMatrix{T} where T) in LinearAlgebra at C:\Users\Azathoth\AppData\Local\Programs\Julia-1.6.0\share\julia\stdlib\v1.6\LinearAlgebra\src\matmul.jl:133
Possible fix, define
  *(::Adjoint{T, var"#s832"} where {T, var"#s832"<:(AbstractVector{T} where T)}, ::Union{Base.ReshapedArray{Bool, 2, var"#s1571", MI} where {var"#s157", I, var"#s1571"<:(OneHotArray{var"#s157", L, var"#s156", var"#s155", I} where {var"#s156", var"#s155"}), MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, OneHotArray{var"#s157", L, N, 2, I} where {var"#s157", N, I}}) where L
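For reference, the result the ambiguous call should produce: a row vector times a one-hot matrix is again a gather, `v' * B == (v[idx])'`. A stdlib-only sketch, with a dense `Bool` matrix standing in for `b2` (an assumption; no Flux types are used):

```julia
v   = [1, 2, 3, 4, 5]
idx = [2, 4, 1, 3]                  # same hot rows as b2 above

# Dense stand-in for b2: 5×4 with B[idx[j], j] == true.
B = falses(5, 4)
for (j, k) in enumerate(idx)
    B[k, j] = true
end

dense  = v' * B                     # generic adjoint-vector × matrix product
gather = v[idx]'                    # select the entries directly, no arithmetic

@assert dense == gather             # both are the row [2 4 1 3]
```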

CarloLucibello commented 3 years ago

Which cases are we benchmarking? I can't find the script.


racinmat commented 3 years ago

Sorry, that was my attempt to modify the https://github.com/FluxML/Flux.jl/blob/master/test/onehot.jl#L28-L42 test case to test multiplication by a OneHotMatrix instead of a OneHotVector, because I noticed Flux's tests cover multiplication by a OneHotVector, but not by matrices.

racinmat commented 3 years ago

But getting back to my benchmarks and benchmarking the original code:

julia> @btime onecold($y);
  108.923 ns (1 allocation: 896 bytes)

julia> idx = onecold(y);

julia> @btime $x[:, $idx];
  6.500 μs (2 allocations: 39.14 KiB)

It seems the vectorized indexing is twice as slow as the fast_mul I proposed.

As for $x*$y', I don't know how to implement it using only onecold, without the summation, and it is almost 6x slower than the proposed version. I recognize the proposed version is GPU-unfriendly, but I don't know how to write it in a GPU-friendly way other than dispatching to a GPU-friendly version for GPU arrays and keeping the GPU-unfriendly one for the CPU. Or is the speedup not worth a specialized implementation?

CarloLucibello commented 3 years ago

> But getting back to my benchmarks and benchmarking the original code:

I mean these: where are these benchmarks?

racinmat commented 3 years ago

Sorry, it's these two scripts: https://github.com/racinmat/flux_benchmarks/blob/master/0_12_7/main.jl and https://github.com/racinmat/flux_benchmarks/blob/master/0_12_7_tricks/main.jl
I compare more versions there, but these two are the latest Flux with and without the multiplication optimized for the CPU.

racinmat commented 3 years ago

And it seems vectorized indexing dispatches to https://github.com/JuliaLang/julia/blob/v1.6.0/base/abstractarray.jl#L1167-L1171, so I have no idea what could be profiled or improved there.

CarloLucibello commented 3 years ago

On my laptop, with this script:

using Flux, CUDA, LinearAlgebra, BenchmarkTools, NNlib, NNlibCUDA
using Flux: onehotbatch

function mul0(A::AbstractMatrix, B::Flux.OneHotMatrix)
  A * B
end

function mul1(A::AbstractMatrix, B::Flux.OneHotMatrix)
  m = size(A,1)
  Y = similar(A, m, size(B,2))
  for (j, ix) in enumerate(B.indices)
    for i in 1:m
      @inbounds Y[i,j] = A[i,ix]
    end
  end
  Y
end

function mul2(A::AbstractMatrix, B::Flux.OneHotMatrix)
  NNlib.gather(A, B.indices)
end

bs = 128;
Din = 100;
Dout = Din;

A = rand(Float32, Dout, Din);
oh = onehotbatch(rand(1:Din, bs), 1:Din);

@assert mul0(A,oh) == mul1(A,oh) == mul2(A,oh)

println("# mul0")
@btime mul0($A, $oh);
println("# mul1")
@btime mul1($A, $oh);
println("# mul2")
@btime mul2($A, $oh);

I get

# mul0
  11.135 μs (3 allocations: 51.22 KiB)
# mul1
  2.341 μs (2 allocations: 50.08 KiB)
# mul2
  3.337 μs (2 allocations: 50.08 KiB)

Notice that your implementation (mul1) is very similar to the CPU implementation of gather; the performance difference is likely due to the use of @inbounds.

CarloLucibello commented 3 years ago

With those sizes I see the same timings for mul0 and mul2 on CPU and GPU, but for larger sizes mul2 gives a noticeable speedup:


bs = 512;
Din = 1000;
Dout = Din;

A = rand(Float32, Dout, Din);
oh = onehotbatch(rand(1:Din, bs), 1:Din);

@assert mul0(A,oh) == mul1(A,oh) == mul2(A,oh)

println("# mul0")
@btime mul0($A, $oh);
println("# mul1")
@btime mul1($A, $oh);
println("# mul2")
@btime mul2($A, $oh);

gA, goh = A |> gpu, oh |> gpu;

println("# gpu mul0")
@btime mul0($gA, $goh);
println("# gpu mul1")
@btime mul1($gA, $goh);
println("# gpu mul2")
@btime mul2($gA, $goh);

# mul0
  404.782 μs (3 allocations: 1.96 MiB)
# mul1
  138.125 μs (2 allocations: 1.95 MiB)
# mul2
  144.124 μs (2 allocations: 1.95 MiB)
# gpu mul0
  11.010 μs (65 allocations: 2.86 KiB)
# gpu mul1
  7.243 s (3073541 allocations: 500.25 MiB)
# gpu mul2
  4.128 μs (36 allocations: 1.58 KiB)

DhairyaLGandhi commented 3 years ago

Well, mul1 isn't really meant to be run on GPUs, so it's unfair to test it there; otherwise, we try to remain generic.

racinmat commented 3 years ago

If mul2 is significantly faster on both CPU and GPU, is there a good reason not to add it? It seems quite generic, and the speedup seems quite significant to me. I would like to have such fast multiplications in Flux. Or is there a better place for them?

CarloLucibello commented 3 years ago

We have a bug in the gradient of mul0 (our current implementation of one-hot multiplication): with arguments on the GPU, it returns CPU arrays:


julia> gradient(A -> sum(mul2(A, goh)), gA)[1] |> typeof
CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}

julia> gradient(A -> sum(mul0(A, goh)), gA)[1] |> typeof
Matrix{Float32} (alias for Array{Float32, 2})
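Since `A -> sum(A * B)` is linear in `A`, its gradient has a closed form that both mul0 and mul2 should reproduce (and, given the bug above, it should come back as the same kind of array as `A`): it equals `ones(m, n) * B'`, so entry `(i, k)` counts how often label `k` occurs in the batch. A stdlib-only sketch of that reference value, with a dense `Bool` matrix standing in for the one-hot batch (all names are illustrative):

```julia
m, nlabels = 3, 5
idx = [2, 4, 1, 2]                  # label 2 appears twice, labels 3 and 5 never
n   = length(idx)

B = falses(nlabels, n)              # dense one-hot stand-in, 5×4
for (j, k) in enumerate(idx)
    B[k, j] = true
end

# d/dA sum(A*B) = ones(m, n) * B'   (upstream gradient of sum is all ones)
grad_dense = ones(Float32, m, n) * B'

# Same thing by counting: every row of the gradient holds the label counts.
counts      = [count(==(k), idx) for k in 1:nlabels]
grad_counts = repeat(permutedims(Float32.(counts)), m)

@assert grad_dense == grad_counts
```

Any fast path, gather-based or not, can be checked against this reference on both CPU and GPU inputs.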

Also, it is very slow on the GPU:

bs = 100;
Din = 10;
Dout = Din;

A = rand(Float32, Dout, Din);
oh = onehotbatch(rand(1:Din, bs), 1:Din);

@assert mul0(A,oh) == mul1(A,oh) == mul2(A,oh)

println("# mul0")
@btime mul0($A, $oh);
println("# mul1")
@btime mul1($A, $oh);
println("# mul2")
@btime mul2($A, $oh);

gA, goh = A |> gpu, oh |> gpu;

println("# gpu mul0")
@btime mul0($gA, $goh);
println("# gpu mul1")
@btime mul1($gA, $goh);
println("# gpu mul2")
@btime mul2($gA, $goh);

grad0 = gradient(A -> sum(mul0(A, oh)), A)[1]
gradg0 = gradient(A -> sum(mul0(A, goh)), gA)[1]
@assert Array(gradg0) ≈ grad0

grad2 = gradient(A -> sum(mul2(A, oh)), A)[1]
gradg2 = gradient(A -> sum(mul2(A, goh)), gA)[1]
@assert grad2 ≈ grad0
@assert Array(gradg2) ≈ grad2

println("# grad mul0")
@btime gradient(A -> sum(mul0(A, $oh)), $A)[1]
# println("# grad mul1") # errors out since mutates
# @btime gradient(A -> sum(mul1(A, oh)), A)[1]
println("# grad mul2")
@btime gradient(A -> sum(mul2(A, $oh)), $A)[1]

println("# grad gpu mul0")
@btime gradient(A -> sum(mul0(A, $goh)), $gA)[1]
# println("# grad mul1") # errors out since mutates
# @btime gradient(A -> sum(mul1(A, oh)), A)[1]
println("# grad gpu mul2")
@btime gradient(A -> sum(mul2(A, $goh)), $gA)[1]

# mul0
  1.196 μs (2 allocations: 4.94 KiB)
# mul1
  623.688 ns (1 allocation: 4.06 KiB)
# mul2
  1.736 μs (1 allocation: 4.06 KiB)
# gpu mul0
  11.530 μs (64 allocations: 2.84 KiB)
# gpu mul1
  13.958 ms (6305 allocations: 1.03 MiB)
# gpu mul2
  4.422 μs (32 allocations: 1.52 KiB)
# grad mul0
  20.340 μs (29 allocations: 14.73 KiB)
# grad mul2
  45.649 μs (521 allocations: 43.64 KiB)
# grad gpu mul0
  14.573 ms (6172 allocations: 1009.88 KiB)
# grad gpu mul2
  50.587 μs (146 allocations: 7.94 KiB)

CarloLucibello commented 3 years ago

> If the mul2 is significantly faster for both CPU and GPU, is there a good reason against adding it?

No, we should do it; it also fixes the GPU bug above. Would you like to file a PR?

racinmat commented 3 years ago

Yes, I'll make a PR.

DhairyaLGandhi commented 3 years ago

Yeah, this mixes the meaning of mul with that of gather. Fixing the performance of GPU adjoints should be the fix in this case.

CarloLucibello commented 3 years ago

> Yeah, this mixes the meaning of mul with that of gather. Fixing the performance of GPU adjoints should be the fix in this case.

This comment doesn't make sense. Have you seen the current definition of the one-hot mul? Would you say it mixes the meaning of getindex and mul? For one-hot matrices, the efficient implementations are index/gather operations, and the one that is faster, correct, and supports the GPU should be selected.

DhairyaLGandhi commented 3 years ago

I agree on choosing the correct definition. But indexing and gather are different enough (one can gather over different dimensions, whereas matmul has a standard definition) that I think it's best to keep the mul.

ToucheSir commented 3 years ago

While the discussion thus far has been about x*y and a previous PR implemented fast paths for x'*y, the bigger gap in @racinmat's benchmarks is actually x*y'. This isn't caught by the dispatch at https://github.com/FluxML/Flux.jl/blob/master/src/onehot.jl#L222 and also can't be expressed with gather because indices are "repeated" in the transposed one-hot matrix, but can be approximated with scatter:

julia> @btime $x*$y';
  720.190 μs (8 allocations: 39.47 KiB)

julia> @btime fast_mul($x, $y');
  4.242 μs (2 allocations: 39.14 KiB)

julia> @btime Flux.NNlib.scatter!(+, zeros(Float32, 100, 100), $x', $y.indices);
  46.070 μs (503 allocations: 71.97 KiB)

Where fast_mul is adapted from the linked benchmark:

function fast_mul(A::AbstractMatrix, B::Adjoint{Bool,<: Flux.OneHotArray})
    m = size(A,1)
    Y = fill!(similar(A, m, size(B,2)), zero(eltype(A)))
    for (j, ix) in enumerate(parent(B).indices)
        for i in 1:m
            @inbounds Y[i,ix] += A[i,j]
        end
    end
    Y
end
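Written out in index form, the contrast between the two products is as follows. The notation is introduced here for illustration only: the one-hot y is L×N with hot rows idx_1, …, idx_N, A is a dense m×L matrix, x is a dense m×N matrix, and [·] is 1 when the condition holds and 0 otherwise.

$$
(A\,y)_{:,j} = A_{:,\,\mathrm{idx}_j}, \qquad j = 1,\dots,N
$$

$$
(x\,y^\top)_{:,k} = \sum_{j=1}^{N} [\mathrm{idx}_j = k]\; x_{:,j} = \sum_{j\,:\,\mathrm{idx}_j = k} x_{:,j}, \qquad k = 1,\dots,L
$$

The first is a pure gather: each output column is a copy of one input column. In the second, several j can share the same k, so an output column is a sum over a variable number of input columns (possibly none). That is exactly what fast_mul accumulates with `Y[i,ix] += A[i,j]` (its A playing the role of x) and what a scatter with `+` computes.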

It's not altogether clear to me why fast_mul is, well, so much faster. Collecting x' into a separate variable or using (.+) instead of + doesn't seem to speed up scatter at all. Moreover, even if I remove the @inbounds from fast_mul, it's still ~4x faster:

julia> @btime fast_mul_checked($x, $y');
  11.549 μs (2 allocations: 39.14 KiB)

So unless we can diagnose why scatter[!] is slower, this seems like a good candidate for inclusion in NNlib.

DhairyaLGandhi commented 3 years ago

Ah, thanks Brian! I was blanking on the exact case regarding gather; it was so simple in hindsight. Approximating with scatter seems like it would lead to similar corner cases. I think the best answer is to see whether the current implementation (and the adjoint) is running into a generic fallback somewhere.

ToucheSir commented 3 years ago

For the current definition of scatter? Not that I can tell, unfortunately. The LLVM IR and native assembly differ, but not substantially, and many of the same paths are invoked (e.g. vectorized setindex). I'm afraid I don't have a good enough sense of the performance breakdown to understand where the difference comes from.

The current definition of *, on the other hand, definitely falls back to the generic matmul in LinearAlgebra. Since the one-hot array is technically sparse, I don't think we can lean on BLAS for this either. Thankfully, the hand-written implementation is pretty concise and could easily be translated into a GPU kernel as well (there is no conditional logic). I only wish we could merge both routines in https://github.com/racinmat/flux_benchmarks/blob/master/0_12_7_tricks/main.jl#L7-L28 somehow, but after looking at the inner loops, any similarities are likely misleading.

CarloLucibello commented 3 years ago

We can immediately implement

function Base.:*(A::AbstractMatrix, B::Adjoint{Bool, <:Flux.OneHotMatrix})
  NNlib.scatter(+, A, parent(B).indices; dstsize = (size(A,1), size(B,2)))
end

to obtain a noticeable CPU and GPU speedup. Since the mul1 and fast_mul implementations here are essentially implementations of gather and scatter, with very small differences from the ones we have in NNlib, the fact that they are a bit faster means there is room to improve the CPU implementations of NNlib.gather and NNlib.scatter. So let's also add the scatter version for the one-hot adjoint mul, and bring the performance discussion to NNlib so that the benefits can be much more widespread.
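A stdlib-only sketch of what the scatter-based method computes, written as a hypothetical `scatter_add` helper whose `dstsize` keyword mirrors the one passed to `NNlib.scatter` above (the helper is illustrative, not NNlib's code). Fixing the destination size matters: if the last label never occurs in a batch, an output sized from `maximum(idx)` would be too narrow to match `A * B'`.

```julia
# Hypothetical helper: dst[:, idx[j]] += src[:, j] for every column j.
function scatter_add(src::AbstractMatrix, idx::AbstractVector{<:Integer};
                     dstsize = (size(src, 1), maximum(idx)))
    dst = zeros(eltype(src), dstsize)
    for (j, k) in enumerate(idx)
        @views dst[:, k] .+= src[:, j]   # accumulate, since labels can repeat
    end
    return dst
end

x   = Float32[1 2 3 4;
              5 6 7 8]              # m×N with m = 2, N = 4
idx = [2, 1, 2, 3]                  # labels of y; label 4 never occurs
nlabels = 4

y = falses(nlabels, length(idx))    # dense one-hot stand-in, 4×4
for (j, k) in enumerate(idx)
    y[k, j] = true
end

out = scatter_add(x, idx; dstsize = (size(x, 1), nlabels))
@assert out == x * y'                          # 2×4, matches the dense product
@assert size(scatter_add(x, idx)) == (2, 3)    # default size drops label 4
```

Here the column for label 2 is the sum of columns 1 and 3 of `x`, and the column for label 4 stays zero only because `dstsize` reserves it.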