aksuhton closed this issue 2 months ago
I see, the problem comes from $\Delta x$ becoming a
```julia
Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Base.ReshapedArray{Float32, 2, Diagonal{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, false}, Tuple{}}
```
which makes broadcasting over it a bit nasty for GPUArrays. Let me see what can be done here.
Okay that PR should do it, I will add a few tests and merge it
One pointer: you can write your model like
```julia
using Lux, Zygote, Random, LinearAlgebra  # LinearAlgebra provides `diag`

model = @compact(; potential=Dense(5 => 5, gelu)) do x
    return reshape(diag(only(Zygote.jacobian(potential, x))), size(x))
end
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 5, 3)
model(x, ps, st)
```
That way Lux takes care of wrapping the layer in a StatefulLuxLayer, so your code is less verbose.
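For context, the more verbose equivalent without `@compact` would wrap the layer by hand; roughly like the sketch below (the exact `StatefulLuxLayer` constructor signature may differ between Lux versions, so treat this as illustrative):

```julia
using Lux, Zygote, Random, LinearAlgebra

# A plain Dense layer with its own parameters and state
layer = Dense(5 => 5, gelu)
ps, st = Lux.setup(Random.default_rng(), layer)

# Manually wrap the layer so it can be called as a plain function x -> y
smodel = StatefulLuxLayer(layer, ps, st)

x = randn(Float32, 5, 3)

# Jacobian of the layer output w.r.t. its input, then keep only the diagonal
J = only(Zygote.jacobian(smodel, x))
y = reshape(diag(J), size(x))
```

With `@compact`, Lux performs this wrapping for you, which is why the earlier version is shorter.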
It is fixed on my end too, incredible! tysm
It's huge for me that Lux v0.5.38 allows one to do pullbacks (with respect to parameters) over jacobians (with respect to model inputs) on the CPU. With CUDA, though, there is a scalar indexing error. I'll add an MWE below and also link the previous Zygote issue here: https://github.com/FluxML/Zygote.jl/issues/1505
Thank you for taking the time to look this over!
Now the failure point is
with stacktrace: