LuxDL / LuxLib.jl

Backend for Lux.jl
MIT License

feat: add enzyme reverse rules for `fused_dense!` #151

Closed. avik-pal closed this 1 month ago

avik-pal commented 2 months ago

fixes #148

TODO

Test case

```julia
using CUDA, LuxLib, Enzyme, NNlib, Zygote

function fused_dense!(y, act, weight, x, b)
    op = LuxLib.internal_operation_mode((y, weight, x, b))
    LuxLib.Impl.fused_dense!(y, op, act, weight, x, b)
    return
end

# CPU case
y = zeros(Float32, 2, 2)
weight = rand(Float32, 2, 2)
x = rand(Float32, 2, 2)
b = rand(Float32, 2)

fused_dense!(y, gelu, weight, x, b)

dy = rand(Float32, 2, 2)
dweight = zeros(Float32, 2, 2)
dx = zeros(Float32, 2, 2)
db = zeros(Float32, 2)

act = x -> gelu(x)

begin
    dx .= 0
    db .= 0
    dweight .= 0
    Enzyme.autodiff(
        Reverse, fused_dense!, Duplicated(y, copy(dy)), Const(act),
        Duplicated(weight, dweight), Duplicated(x, dx), Duplicated(b, db)) # Works
    @show nothing, dweight, dx, db
end

begin
    _, pb_f = Zygote.pullback(fused_dense_bias_activation, act, weight, x, b)
    @show pb_f(dy)
end

# GPU case
y = zeros(Float32, 2, 2) |> cu
weight = rand(Float32, 2, 2) |> cu
x = rand(Float32, 2, 2) |> cu
b = rand(Float32, 2) |> cu

fused_dense!(y, gelu, weight, x, b)

dy = rand(Float32, 2, 2) |> cu
dweight = zeros(Float32, 2, 2) |> cu
dx = zeros(Float32, 2, 2) |> cu
db = zeros(Float32, 2) |> cu

act = gelu

begin
    dx .= 0
    db .= 0
    dweight .= 0
    Enzyme.autodiff(
        Reverse, fused_dense!, Duplicated(y, copy(dy)), Const(act),
        Duplicated(weight, dweight), Duplicated(x, dx), Duplicated(b, db)) # Fails
end

begin
    _, pb_f = Zygote.pullback(fused_dense_bias_activation, act, weight, x, b)
    @show pb_f(dy)
end
```
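
For reference, the gradients that both the Enzyme and Zygote calls above should agree on can be written out by hand for `y = act.(weight * x .+ b)`. The sketch below is not LuxLib's implementation: `dense_reference` and `dense_reference_pullback` are hypothetical helper names, `tanh` stands in for `gelu` so that the derivative is easy to write, and it uses `Float64` on the CPU with only Base Julia. It checks two entries against central finite differences.

```julia
# Hand-written reference for y = act.(weight * x .+ b), using only Base.
# Assumption: `tanh` stands in for `gelu`; these helpers are illustrative only.
act(z) = tanh(z)
dact(z) = 1 - tanh(z)^2

# Forward pass: returns the output and the pre-activation needed by the pullback.
function dense_reference(weight, x, b)
    z = weight * x .+ b
    return act.(z), z
end

# Reverse pass: maps the output cotangent dy to (dweight, dx, db).
function dense_reference_pullback(dy, z, weight, x)
    dz = dy .* dact.(z)          # chain rule through the activation
    dweight = dz * x'            # same shape as weight
    dx = weight' * dz            # same shape as x
    db = vec(sum(dz; dims=2))    # bias is broadcast over columns, so sum over them
    return dweight, dx, db
end

weight = [0.1 -0.2; 0.3 0.4]
x = [0.5 -0.6; 0.7 0.8]
b = [0.05, -0.05]
dy = [1.0 0.5; -0.5 2.0]

y, z = dense_reference(weight, x, b)
dweight, dx, db = dense_reference_pullback(dy, z, weight, x)

# Central finite-difference check on one weight entry and one bias entry.
loss(w, xx, bb) = sum(dy .* first(dense_reference(w, xx, bb)))
h = 1e-6
wp = copy(weight); wp[1, 2] += h
wm = copy(weight); wm[1, 2] -= h
fd_w12 = (loss(wp, x, b) - loss(wm, x, b)) / (2h)
bp = copy(b); bp[2] += h
bm = copy(b); bm[2] -= h
fd_b2 = (loss(weight, x, bp) - loss(weight, x, bm)) / (2h)

@show dweight[1, 2] fd_w12
@show db[2] fd_b2
```

The same three quantities are what `dweight`, `dx`, and `db` accumulate in the `Enzyme.autodiff` calls, and what `pb_f(dy)` returns after the leading `nothing` for the activation.
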
avik-pal commented 2 months ago
```
ERROR: a
No forward mode derivative found for __nv_fast_expf
 at context:   %4 = call float @__nv_fast_expf(float %3) #12, !dbg !21

Stacktrace:
 [1] #exp_fast
   @ /mnt/.julia/packages/CUDA/Tl08O/src/device/intrinsics/math.jl:130
 [2] sigmoid_fast
   @ /mnt/.julia/packages/NNlib/f92hx/src/activations.jl:830
 was thrown during kernel execution on thread (1, 1, 1) in block (1, 1, 1).
Stacktrace not available, run Julia on debug level 2 for more details (by passing -g2 to the executable).
```

Better error message using Enzyme#main

codecov[bot] commented 1 month ago

Codecov Report

Attention: Patch coverage is 0% with 59 lines in your changes missing coverage. Please review.

Project coverage is 78.02%. Comparing base (121a2fe) to head (e125c92). Report is 4 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/impl/dense.jl | 0.00% | 59 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #151      +/-   ##
==========================================
+ Coverage   68.72%   78.02%   +9.30%
==========================================
  Files          37       38       +1
  Lines        1896     1957      +61
==========================================
+ Hits         1303     1527     +224
+ Misses        593      430     -163
```
