FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Enzyme fails with MultiHeadAttention layer #2448

Open mashu opened 3 months ago

mashu commented 3 months ago

I am attaching an MWE where Zygote (Flux's default AD) works fine but Enzyme fails to compile (@wsmoses).

using Enzyme
using Flux
using CUDA

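# Build a zeroed copy of the model (numbers and arrays set to zero) to use as the shadow.
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

# Zygote-style wrapper around Enzyme reverse mode: numbers are passed as Active,
# everything else as Duplicated with a zeroed shadow that accumulates the gradient.
function gradient_ez(f, x...)
    args = []
    for x in x
        if x isa Number
            push!(args, Active(x))
        else
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    # ret[1] holds derivatives of Active arguments; Duplicated gradients live in .dval.
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end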

x = CUDA.rand(Float32, 64, 100, 512)
mha = MultiHeadAttention(64 => 64 => 64) |> gpu

# Fails: Enzyme compilation error on GPU
Δ = gradient_ez(mha) do m
    sum(first(m(x, x, x)))
end

# Works: Zygote (Flux's default AD)
Δ = Flux.gradient(mha) do m
    sum(first(m(x, x, x)))
end

mashu commented 3 months ago

Timing the same code on CUDA:

@btime CUDA.@sync Flux.gradient(mha) do m
           sum(first(m($x, $x, $x)))
       end
 11.983 ms (2583 allocations: 137.55 KiB)

whereas

@btime CUDA.@sync gradient_ez(mha) do m
           sum(first(m($x, $x, $x)))
       end

 ....
   [2] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/srACB/src/api.jl:190
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:3141
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5074
  [5] codegen
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:4481 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771
  [7] _thunk
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5809 [inlined]
  [9] (::Enzyme.Compiler.var"#560#561"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{4, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5875
 [10] JuliaContext(f::Enzyme.Compiler.var"#560#561"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{…}, Int64, Bool, Bool, UInt64, DataType}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
 [11] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
 [12] #s2027#559
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5827 [inlined]

It can be reproduced with the following packages and Julia 1.10.3:

  [6e4b80f9] BenchmarkTools v1.5.0
  [052768ef] CUDA v5.3.4
  [082447d4] ChainRules v1.66.0
  [d360d2e6] ChainRulesCore v1.23.0
  [7da242da] Enzyme v0.12.6
  [587475ba] Flux v0.14.15
  [e88e6eb3] Zygote v0.6.70
  [02a925ec] cuDNN v1.3.1

Thanks!

wsmoses commented 3 months ago

@mashu can you post the whole log?

mashu commented 3 months ago

I was convinced I had attached it earlier, but apparently I didn't, so here it is: MWA.log.gz. The following code was run as

julia --project=@. src/MWA.jl 2> MWA.log


using Enzyme
using Flux
using CUDA

_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

function gradient_ez(f, x...)
    args = []
    for x in x
        if x isa Number
            push!(args, Active(x))
        else
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end

x = CUDA.rand(Float32, 64, 100, 512)
mha = MultiHeadAttention(64 => 64 => 64) |> gpu

Flux.gradient(mha) do m
    sum(first(m(x, x, x)))
end

Δ = gradient_ez(mha) do m
    sum(first(m(x, x, x)))
end

wsmoses commented 3 months ago

Also, does this work on CPU?

mashu commented 3 months ago

@wsmoses Initially I got a compilation error with the CPU version, but after moving to a separate project (the MWE) it only fails on GPU. Having said that, I still can't figure out why it fails in my main project, as the packages are up to date and at essentially the same versions. But this GPU failure is at least reproducible.
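To check the CPU path in isolation, the MWE can be reduced to a CPU-only variant that drops CUDA and calls Enzyme directly. This is a sketch, not the reporter's script: `loss` and `dmha` are names introduced here, the input is passed as a `Const` argument instead of a captured global, and the batch size is shrunk from 512 to 8 only to keep the run short. Per the comment above, it is expected to compile in a clean environment, but it may still fail in a project with a different dependency set.

```julia
using Enzyme
using Flux

# Same zeroing helper as in the MWE: zero every number/array in the model.
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

# CPU-only inputs: a plain Array instead of a CuArray, and no `|> gpu`.
# Shape is (features, sequence length, batch); batch reduced from 512 to 8.
x = rand(Float32, 64, 100, 8)
mha = MultiHeadAttention(64 => 64 => 64)

# Same loss as the MWE; `first` keeps the layer output and drops the attention scores.
loss(m, x) = sum(first(m(x, x, x)))

# Shadow model that Enzyme accumulates the gradient of `loss` w.r.t. `mha` into.
dmha = make_zero(mha)
Enzyme.autodiff(Reverse, loss, Active, Duplicated(mha, dmha), Const(x))
```

If this runs, the gradient sits in `dmha`, which has the same structure as `mha`.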

wsmoses commented 3 months ago

GPU support is in progress, so the report is super helpful, but the failure is also expected at present.

Maybe check the current versions of packages in your project and see if it's forcing an older Enzyme?
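One way to act on this suggestion is to list the resolved versions of the relevant packages, including indirect ones such as Enzyme_jll, in each environment and compare the two. A minimal sketch using `Pkg.dependencies()`; `resolved_versions` is a hypothetical helper name, and the project paths in the usage comments are placeholders.

```julia
using Pkg

# Hypothetical helper: resolved versions of the named packages in the active
# environment, covering both direct and indirect (manifest-only) dependencies.
function resolved_versions(names::AbstractVector{<:AbstractString})
    found = Dict{String,Union{VersionNumber,Nothing}}()
    for info in values(Pkg.dependencies())
        if info.name in names
            found[info.name] = info.version  # may be `nothing` for some stdlibs
        end
    end
    return found
end

# Usage (paths are placeholders for the two projects being compared):
# pkgs = ["Enzyme", "Enzyme_jll", "GPUCompiler", "LLVM"]
# Pkg.activate("path/to/working_mwe");   display(resolved_versions(pkgs))
# Pkg.activate("path/to/broken_project"); display(resolved_versions(pkgs))
# Pkg.status(; outdated=true)  # packages not at their latest version, e.g. those marked ⌅
```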

mashu commented 3 months ago

It's the same version, ⌅ [7cc45869] Enzyme_jll v0.0.109+0, in both the working and the non-working project. It must be some indirect dependency that I can't figure out. As for the GPU part, my impression is that CPU code paths in Flux are sometimes slow and not well optimized, probably because most people use the GPU paths for any work.

wsmoses commented 3 months ago

Ah, but what's your Enzyme version (rather than Enzyme_jll, which is a dependency)?

mashu commented 3 months ago

Looks the same: v0.12.6.

Working MWE (]st):

  [6e4b80f9] BenchmarkTools v1.5.0
  [052768ef] CUDA v5.3.4
  [082447d4] ChainRules v1.66.0
  [d360d2e6] ChainRulesCore v1.23.0
  [7da242da] Enzyme v0.12.6
  [587475ba] Flux v0.14.15
  [e88e6eb3] Zygote v0.6.70
  [02a925ec] cuDNN v1.3.1

Broken project (]st):

  [6e4b80f9] BenchmarkTools v1.5.0
  [336ed68f] CSV v0.10.14
  [052768ef] CUDA v5.3.4
  [082447d4] ChainRules v1.66.0
  [d360d2e6] ChainRulesCore v1.23.0
  [a93c6f00] DataFrames v1.6.1
  [864edb3b] DataStructures v0.18.20
  [31c24e10] Distributions v0.25.108
  [7da242da] Enzyme v0.12.6
  [c2308a5c] FASTX v2.1.5
  [587475ba] Flux v0.14.15
  [41a02a25] Folds v0.2.10
  [033835bb] JLD2 v0.4.47
  [682c06a0] JSON v0.21.4
  [e6f89c97] LoggingExtras v1.0.3
  [12afc1b8] NeuralAttentionlib v0.2.13
  [0b1bfda6] OneHotArrays v0.2.5
  [3bd65402] Optimisers v0.3.3
  [d7d3b36b] ParameterSchedulers v0.4.1
  [92933f4c] ProgressMeter v1.10.0
  [2913bbd2] StatsBase v0.34.3
  [b8865327] UnicodePlots v3.6.4
  [02a925ec] cuDNN v1.3.1
  [56ddb016] Logging

mashu commented 3 months ago

Also including the log of the error that happens on the CPU side in the broken project; not sure if it helps, though: CPU.log

wsmoses commented 3 months ago

From the log, I think the simplest answer here is that we should just add a custom derivative for attention in NNlib. I assume there's already one for ChainRules (CR)?

If so, you can try our macro for importing a CR rule into Enzyme as a test to see whether anything else fails, while in the interim we can look at writing a fast native rule for it (CR rules will be slower and come with caveats).

mashu commented 3 months ago

@wsmoses Long story short, I wanted to use Enzyme because I often lack the skills to write an rrule, and there is none for MultiHeadAttention in NNlib. The longer answer is that I am currently using NeuralAttentionlib.jl, which is part of Transformers.jl; it has the layer customization I need and an rrule that makes that variant of MHA a couple of times faster on GPU. My hope was that Enzyme might do a better job than Zygote in terms of the performance of the code it produces (when no rrule is provided).
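For readers unsure what such a custom derivative looks like, here is a hand-written ChainRulesCore rule for a toy, single-head, unbatched scaled dot-product attention. This is an illustrative sketch only: `toy_attention` and `softmax_cols` are hypothetical names, not NNlib's or NeuralAttentionlib's API, and the rule omits masking, dropout, multiple heads, and batching. It shows the pattern: the forward pass keeps the softmax weights `A`, and the pullback reuses them instead of differentiating through the softmax elementwise.

```julia
using ChainRulesCore

# Numerically stable softmax over each column (dims = 1).
function softmax_cols(S::AbstractMatrix)
    E = exp.(S .- maximum(S; dims=1))
    return E ./ sum(E; dims=1)
end

# Hypothetical toy single-head attention (not NNlib's API):
# Q is d×n (queries), K is d×m (keys), V is e×m (values); the result is e×n.
function toy_attention(Q::AbstractMatrix, K::AbstractMatrix, V::AbstractMatrix)
    s = 1 / sqrt(size(Q, 1))
    A = softmax_cols(s .* (K' * Q))   # m×n attention weights, columns sum to 1
    return V * A
end

# Hand-written reverse rule; the forward pass keeps A for use in the pullback.
function ChainRulesCore.rrule(::typeof(toy_attention),
                              Q::AbstractMatrix, K::AbstractMatrix, V::AbstractMatrix)
    s = 1 / sqrt(size(Q, 1))
    A = softmax_cols(s .* (K' * Q))
    Y = V * A
    function toy_attention_pullback(dY_raw)
        dY = unthunk(dY_raw)
        dV = dY * A'                              # from Y = V * A
        dA = V' * dY
        dS = A .* (dA .- sum(A .* dA; dims=1))    # column-wise softmax backward
        dQ = s .* (K * dS)                        # from S = s * K' * Q
        dK = s .* (Q * dS')
        return NoTangent(), dQ, dK, dV
    end
    return Y, toy_attention_pullback
end
```

Zygote picks up an `rrule` like this automatically; whether Enzyme can reuse it goes through the CR-import macro mentioned above, with the caveats noted there.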

wsmoses commented 3 months ago

If you can wait a short bit (it's currently unregistered and there are a bunch of small things we should add), Reactant.jl is an execution engine (e.g. it does tons of fancy optimizations and kernel fusion) that is both Enzyme- and GPU-compatible out of the box, and it might be what you're looking for.

In the interim I'll also push on GPU support for native Enzyme here, but I'm just throwing that out there in case it's helpful.

https://github.com/EnzymeAD/Reactant.jl