LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Pullback over jacobian (with CUDA) #602

Closed: aksuhton closed this issue 2 months ago

aksuhton commented 2 months ago

It's huge for me that Lux v0.5.38 allows one to take pullbacks (with respect to parameters) over Jacobians (with respect to model inputs) on the CPU. With CUDA, though, there is a scalar indexing error. I'll add an MWE below and also link the previous Zygote issue: https://github.com/FluxML/Zygote.jl/issues/1505

Thank you for taking the time to look this over!

using Lux, LuxCUDA, ComponentArrays
using Zygote
using Random, LinearAlgebra, Statistics
##
## Container layer wrapping a single inner chain
Lux.@concrete struct ModelA <: Lux.AbstractExplicitContainerLayer{(:chain,)}
    chain::Lux.AbstractExplicitLayer
end
##
## Convenience constructor: a single Dense layer as the chain
function ModelA(chs::Pair{Int, Int})
    chain = Dense(chs)
    return ModelA(chain)
end
##
function (m::ModelA)(
        x::AbstractArray{T, 2}, ps, st::NamedTuple) where {T}
    ## Wrap the chain so it is callable as x -> y while carrying ps/st
    potential = StatefulLuxLayer(m.chain, ps, st)
    ## Diagonal of the Jacobian w.r.t. the inputs, reshaped back to size(x)
    ## (future -> batched_jacobian)
    m_x = reshape(diag(only(Zygote.jacobian(potential, x))), size(x))
    return m_x, potential.st
end
##
function Loss_st(x, y, model, ps, st)
    ŷ, st_ = Lux.apply(model, x, ps, st)
    l = mean(abs2.(ŷ .- y))
    return l, st_
end
##
function test_forward(dev::Function)
    ##
    x = randn(Float32, 5, 3) |> dev
    m = ModelA(5 => 5)
    ps, st = Lux.setup(Random.default_rng(), m)
    ps = ComponentArray(ps)
    ps = ps |> dev
    st = st |> dev
    ##
    m_x, st_ = m(x, ps, st)
    ##
    return m_x
end
##
function test_backward(dev::Function)
    ##
    x = randn(Float32, 5, 3) |> dev
    m = ModelA(5 => 5)
    ps, st = Lux.setup(Random.default_rng(), m)
    ps = ComponentArray(ps)
    ps = ps |> dev
    st = st |> dev
    ##
    m_x, st_ = m(x, ps, st)
    y = randn(eltype(m_x), size(m_x)) |> dev
    ##
    (l, st_), back = Zygote.pullback(p -> Loss_st(x, y, m, p, st), ps)
    gs = back((one(l), nothing))[1]
    ##
    return l, gs
end
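For reference, both helpers run cleanly on the CPU, and the forward pass also appears fine on the GPU; only the GPU pullback fails. A quick check (cpu_device and gpu_device are the standard Lux device helpers):

## CPU: forward and backward both work
test_forward(cpu_device())
test_backward(cpu_device())
## GPU: forward works; the pullback below is where it breaks
test_forward(gpu_device())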

Now the failure point is

test_backward(gpu_device())

with the following stacktrace:


Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] errorscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:155
  [3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:128
  [4] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:116
  [5] getindex
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/indexing.jl:48 [inlined]
  [6] getindex
    @ ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:156 [inlined]
  [7] _unsafe_getindex_rs
    @ ./reshapedarray.jl:264 [inlined]
  [8] _unsafe_getindex
    @ ./reshapedarray.jl:261 [inlined]
  [9] getindex
    @ ./reshapedarray.jl:249 [inlined]
 [10] getindex
    @ ./subarray.jl:290 [inlined]
 [11] getindex
    @ ./reshapedarray.jl:244 [inlined]
 [12] _getindex
    @ ./abstractarray.jl:1324 [inlined]
 [13] getindex
    @ ./abstractarray.jl:1291 [inlined]
 [14] _broadcast_getindex
    @ ./broadcast.jl:675 [inlined]
 [15] _getindex
    @ ./broadcast.jl:706 [inlined]
 [16] _broadcast_getindex
    @ ./broadcast.jl:681 [inlined]
 [17] _getindex
    @ ./broadcast.jl:706 [inlined]
 [18] _broadcast_getindex
    @ ./broadcast.jl:681 [inlined]
 [19] getindex
    @ ./broadcast.jl:636 [inlined]
 [20] macro expansion
    @ ./broadcast.jl:1004 [inlined]
 [21] macro expansion
    @ ./simdloop.jl:77 [inlined]
 [22] copyto!
    @ ./broadcast.jl:1003 [inlined]
 [23] copyto!
    @ ./broadcast.jl:956 [inlined]
 [24] copy
    @ ./broadcast.jl:928 [inlined]
 [25] materialize
    @ ./broadcast.jl:903 [inlined]
 [26] __forwarddiff_jvp(f::LuxZygoteExt.var"#16#20"{LuxZygoteExt.var"#15#19"{Int64, StatefulLuxLayer{true, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}}}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Δx::Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Base.ReshapedArray{Float32, 2, Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{}}, ps::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}})
    @ LuxForwardDiffExt ~/.julia/packages/Lux/ktWbJ/ext/LuxForwardDiffExt.jl:30
 [27] #14
    @ ~/.julia/packages/Lux/ktWbJ/ext/LuxZygoteExt.jl:103 [inlined]
 [28] MappingRF
    @ ./reduce.jl:100 [inlined]
 [29] _foldl_impl(op::Base.MappingRF{LuxZygoteExt.var"#14#18"{StatefulLuxLayer{true, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}}, Base.BottomRF{typeof(Lux.__internal_add)}}, init::Base._InitialValue, itr::Base.Iterators.Enumerate{RowSlices{Base.ReshapedArray{Float32, 2, Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Base.OneTo{Int64}}, SubArray{Float32, 1, Base.ReshapedArray{Float32, 2, Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, false}}})
    @ Base ./reduce.jl:58
 [30] foldl_impl
    @ ./reduce.jl:48 [inlined]
 [31] mapfoldl_impl
    @ ./reduce.jl:44 [inlined]
 [32] mapfoldl
    @ ./reduce.jl:175 [inlined]
 [33] mapreduce
    @ ./reduce.jl:307 [inlined]
 [34] #13
    @ ~/.julia/packages/Lux/ktWbJ/ext/LuxZygoteExt.jl:101 [inlined]
 [35] ZBack
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
 [36] jacobian
    @ ~/.julia/packages/Lux/ktWbJ/ext/LuxZygoteExt.jl:81 [inlined]
 [37] ModelA
    @ ~/GitHub/G2TWIN/dev/mwe.jl:20 [inlined]
 [38] (::Zygote.Pullback{Tuple{ModelA, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, Tuple{Zygote.var"#2013#back#204"{typeof(identity)}, Zygote.ZBack{ChainRules.var"#size_pullback#917"}, Zygote.Pullback{Tuple{typeof(jacobian), StatefulLuxLayer{true, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:ps, Zygote.Context{false}, StatefulLuxLayer{true, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}}}, Zygote.ZBack{LuxZygoteExt.var"#13#17"{StatefulLuxLayer{true, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}}, Zygote.Pullback{Tuple{typeof(only), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Zygote.var"#2029#back#213"{Zygote.var"#back#211"{1, 1, Zygote.Context{false}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:st, Zygote.Context{false}, StatefulLuxLayer{true, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, @NamedTuple{}}}, Zygote.var"#2763#back#609"{Zygote.var"#603#607"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Tuple{Int64, Int64}}}}, Zygote.ZBack{ChainRules.var"#diag_pullback#2059"}, Zygote.Pullback{Tuple{Type{StatefulLuxLayer}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, Tuple{Zygote.Pullback{Tuple{Lux.var"##StatefulLuxLayer#214", Val{true}, Type{StatefulLuxLayer}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, Any}, Zygote.var"#1922#back#161"{Zygote.var"#157#160"}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:chain, Zygote.Context{false}, ModelA, Dense{true, typeof(identity), typeof(glorot_uniform), 
typeof(zeros32)}}}}})(Δ::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [39] apply
    @ ~/.julia/packages/LuxCore/8lRV2/src/LuxCore.jl:180 [inlined]
 [40] Loss_st
    @ ~/GitHub/G2TWIN/dev/mwe.jl:26 [inlined]
 [41] (::Zygote.Pullback{Tuple{typeof(Loss_st), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ModelA, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, Tuple{Zygote.var"#back#246"{Zygote.var"#2029#back#213"{Zygote.var"#back#211"{2, 2, Zygote.Context{false}, @NamedTuple{}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(LuxCore.apply), ModelA, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, Tuple{Zygote.Pullback{Tuple{ModelA, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, Tuple{Zygote.var"#2013#back#204"{typeof(identity)}, Zygote.ZBack{ChainRules.var"#size_pullback#917"}, Zygote.Pullback{Tuple{typeof(jacobian), StatefulLuxLayer{true, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:ps, Zygote.Context{false}, StatefulLuxLayer{true, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}}}, Zygote.ZBack{LuxZygoteExt.var"#13#17"{StatefulLuxLayer{true, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}}, Zygote.Pullback{Tuple{typeof(only), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Tuple{Zygote.var"#2029#back#213"{Zygote.var"#back#211"{1, 1, Zygote.Context{false}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:st, Zygote.Context{false}, StatefulLuxLayer{true, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, @NamedTuple{}}}, Zygote.var"#2763#back#609"{Zygote.var"#603#607"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Tuple{Int64, Int64}}}}, Zygote.ZBack{ChainRules.var"#diag_pullback#2059"}, 
Zygote.Pullback{Tuple{Type{StatefulLuxLayer}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, Tuple{Zygote.Pullback{Tuple{Lux.var"##StatefulLuxLayer#214", Val{true}, Type{StatefulLuxLayer}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}, @NamedTuple{}}, Any}, Zygote.var"#1922#back#161"{Zygote.var"#157#160"}}}, Zygote.var"#2180#back#303"{Zygote.var"#back#302"{:chain, Zygote.Context{false}, ModelA, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}}}}}, Zygote.var"#3972#back#1289"{Zygote.var"#1285#1288"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#2013#back#204"{typeof(identity)}, Zygote.var"#3764#back#1191"{Zygote.var"#1187#1190"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#2029#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, @NamedTuple{}}}, Zygote.ZBack{ChainRules.var"#mean_pullback#1826"{Int64, ChainRules.var"#sum_pullback#1638"{Colon, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ChainRulesCore.ProjectTo{AbstractArray, @NamedTuple{element::ChainRulesCore.ProjectTo{Float32, @NamedTuple{}}, axes::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.var"#2029#back#213"{Zygote.var"#back#211"{2, 2, Zygote.Context{false}, Int64}}, Zygote.var"#2029#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, Zygote.var"#back#245"{Zygote.var"#2029#back#213"{Zygote.var"#back#211"{2, 1, Zygote.Context{false}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}})(Δ::Tuple{Float32, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [42] #3
    @ ~/GitHub/G2TWIN/dev/mwe.jl:57 [inlined]
 [43] (::Zygote.Pullback{Tuple{var"#3#4"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ModelA, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}}, Any})(Δ::Tuple{Float32, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [44] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#3#4"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ModelA, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))}}}}, Any}})(Δ::Tuple{Float32, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [45] test_backward(dev::Function)
    @ Main ~/GitHub/G2TWIN/dev/mwe.jl:58
 [46] top-level scope
    @ REPL[13]:1
 [47] top-level scope
    @ ~/.julia/packages/CUDA/Qbxch/src/initialization.jl:206
avik-pal commented 2 months ago

I see, the problem comes from $\Delta x$ becoming a

Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Base.ReshapedArray{Float32, 2, Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{}}

which makes broadcasting over it a bit nasty for GPUArrays. Let me see what can be done here.
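A rough sketch of the failure mode, assuming GPUArrays' usual handling of wrapped arrays: broadcast dispatch only sees through a limited set of wrappers around a CuArray, and a doubly nested wrapper like the $\Delta x$ above falls through to Base's generic path, which iterates with scalar getindex:

using CUDA, LinearAlgebra
CUDA.allowscalar(false)

d = Diagonal(CUDA.rand(Float32, 4))  # Diagonal wrapping a CuArray
r = reshape(d, 16)                   # Base.ReshapedArray around the Diagonal: two wrappers deep
r .+ 1.0f0                           # broadcast falls back to the generic CPU path -> scalar indexing error

The fix presumably needs to materialize or unwrap $\Delta x$ before it reaches broadcast.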

avik-pal commented 2 months ago

Okay, that PR should do it. I will add a few tests and merge it.

avik-pal commented 2 months ago

One pointer: you can write your model like

model = @compact(; potential=Dense(5 => 5, gelu)) do x
    return reshape(diag(only(Zygote.jacobian(potential, x))), size(x))
end

ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 5, 3)
model(x, ps, st)

That way Lux takes care of wrapping the layer in a StatefulLuxLayer, so your code is less verbose.
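For completeness, an end-to-end sketch of taking the parameter pullback through the @compact version (the MSE loss here is just illustrative, not part of the suggestion above):

using Lux, Zygote, Random, LinearAlgebra, Statistics

model = @compact(; potential=Dense(5 => 5, gelu)) do x
    return reshape(diag(only(Zygote.jacobian(potential, x))), size(x))
end

ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 5, 3)
y = randn(Float32, 5, 3)

## Pullback w.r.t. the parameters, through the inner Jacobian
l, back = Zygote.pullback(p -> mean(abs2, first(model(x, p, st)) .- y), ps)
gs = only(back(one(l)))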

aksuhton commented 2 months ago

> Okay, that PR should do it. I will add a few tests and merge it.

It is fixed on my end too, incredible! Thank you so much!