# GitHub issue status: closed by avik-pal 1 month ago.
using Lux, Random
"""
    MyDense(; d_in=5, d_out=7, act=relu)

Build a dense-style layer with Lux's `@compact` macro.

The layer is created with `W = randn(d_out, d_in)`, `b = zeros(d_out)` and `incr = 1`.
Its body computes `y = W * x`, multiplies `incr` by 10, and produces
`act.(y .+ b) .+ incr`.

# Keywords
- `d_in`: number of columns of `W` (size of the input's first dimension).
- `d_out`: number of rows of `W` and length of `b`.
- `act`: function broadcast over `y .+ b`.

NOTE(review): `randn(d_out, d_in)` is called without an RNG argument, so `W` presumably
comes from the global RNG rather than the RNG passed to `Lux.setup`. Confirm this is intended.
NOTE(review): `randn`/`zeros` default to Float64, so Float32 inputs are promoted.
NOTE(review): whether `@compact` treats `incr` as a parameter, state or constant is not
visible here. Whether `incr *= 10` persists across calls is also unconfirmed; check against
the Lux `@compact` docs.
"""
function MyDense(; d_in=5, d_out=7, act=relu)
    # The keyword arguments to `@compact` name the values (`W`, `b`, `incr`) used inside
    # the do-block; the do-block receives the layer input as `x`.
    @compact(W=randn(d_out, d_in), b=zeros(d_out), incr=1) do x
        y = W * x
        # Rebinds `incr` inside the body before it is added to the output.
        incr *= 10
        # The body's result is given via Lux's `@return` macro.
        @return act.(y .+ b) .+ incr
    end
end
# `Zygote` is used for the gradient below but was never loaded. `@code_warntype` comes from
# InteractiveUtils, which is only loaded automatically in the REPL, not in scripts.
using Zygote
using InteractiveUtils: @code_warntype

model = MyDense()
ps, st = Lux.setup(Xoshiro(0), model)

# Forward pass with a Float64 batch of 10 columns (first dimension must equal d_in = 5).
model(ones(5, 10), ps, st)

# Float32 batch. The parameters built by `MyDense` use `randn`/`zeros` defaults (Float64),
# so this input is promoted inside the layer.
x = rand(Float32, (5, 10))

# Scalar loss for differentiation. A Lux layer call `model(x, ps, st)` returns
# `(output, state)`; summing the output gives the scalar that `Zygote.gradient` requires.
# This function was referenced below but previously undefined.
lfn(model, x, ps, st) = sum(first(model(x, ps, st)))

@code_warntype model(x, ps, st)
@code_warntype Zygote.gradient(lfn, model, x, ps, st)