LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License
446 stars · 50 forks

Patch a compact bug #648

Status: Closed (avik-pal closed this issue 1 month ago)

avik-pal commented 1 month ago:
using Lux, Random
using Zygote
using InteractiveUtils: @code_warntype

"""
    MyDense(; d_in=5, d_out=7, act=relu)

Build a dense-like Lux layer with `@compact`, using three captured values:

- weight `W`, a `d_out × d_in` matrix drawn with `randn`;
- bias `b`, equal to `zeros(d_out)`;
- an extra integer `incr`, initially `1`.

When called on `x`, the body computes `y = W * x`. It then multiplies `incr` by 10 and
returns `act.(y .+ b) .+ incr`.

NOTE(review): this is a reproduction for a `@compact` bug (issue #648). Reassigning the
captured `incr` inside the body appears to be the pattern under test. Whether the updated
`incr` persists between calls depends on `@compact` semantics; confirm.
"""
function MyDense(; d_in=5, d_out=7, act=relu)
    @compact(W=randn(d_out, d_in), b=zeros(d_out), incr=1) do x
        y = W * x
        # Reassigns a value captured by `@compact` (see NOTE(review) in the docstring).
        incr *= 10
        # `@return` is `@compact`'s way of producing the layer output; TODO confirm exact semantics.
        @return act.(y .+ b) .+ incr
    end
end

# Reproduction script for the `@compact` bug. The steps are:
# 1. Build the layer and set up parameters and state with a seeded RNG.
# 2. Run one forward pass.
# 3. Inspect type stability of the forward pass and of its Zygote gradient.
model = MyDense()

ps, st = Lux.setup(Xoshiro(0), model)

model(ones(5, 10), ps, st)

x = rand(Float32, (5, 10))

@code_warntype model(x, ps, st)

# Scalar loss for the gradient check. The original script used `lfn` without defining it,
# which raised UndefVarError.
# Assumes the layer call returns `(y, st)` (Lux layer convention), so `first` selects the
# output `y`.
lfn(m, x, ps, st) = sum(first(m(x, ps, st)))

@code_warntype Zygote.gradient(lfn, model, x, ps, st)