EnzymeAD / Reactant.jl

MIT License

Define `similar` for Reactant.ConcreteRArray #4

Closed: avik-pal closed this issue 1 month ago

avik-pal commented 1 month ago

The in-place Lux functions currently fail because `similar` on a `Reactant.ConcreteRArray` constructs a plain `Matrix` instead of a `ConcreteRArray`:

using Reactant, Lux, Random
using Test

# Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
noisy = rand(Float32, 2, 1000)                                        # 2×1000 Matrix{Float32}
truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)]   # 1000-element Vector{Bool}

# Define our model, a multi-layer perceptron with one hidden layer of size 3:
model = Chain(Dense(2 => 3, tanh),   # activation function inside layer
    BatchNorm(3), Dense(3 => 2), softmax)
ps, st = Lux.setup(Xoshiro(123), model)

using BenchmarkTools

origout, _ = model(noisy, ps, st)
@show origout[3]
@btime model($noisy, $ps, $st)

cmodel = Reactant.make_tracer(IdDict(), model, (), Reactant.ArrayToConcrete, nothing)
cps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete, nothing)
cst = Reactant.make_tracer(IdDict(), st, (), Reactant.ArrayToConcrete, nothing)
cnoisy = Reactant.ConcreteRArray(noisy)

f = Reactant.compile((a, b, c, d) -> a(b, c, d), (cmodel, cnoisy, cps, cst))

Minimal version

using Reactant

noisy = rand(Float32, 2, 1000)
cnoisy = Reactant.ConcreteRArray(noisy)
similar(cnoisy)  # returns a plain Matrix, not a ConcreteRArray
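For context, the snippet below is a minimal sketch of the behavior the issue asks for. It uses a hypothetical wrapper type `WrappedArray` as a stand-in. It is not Reactant's actual `ConcreteRArray` and not the method added in the linked commit. Julia's `AbstractArray` fallback for `similar` allocates a plain `Array`, which is why `similar(cnoisy)` comes back as a `Matrix`. Overloading the three-argument `similar(A, ::Type{T}, dims::Dims)` method from the `AbstractArray` interface is enough for the one- and two-argument forms to return the wrapper type as well.

```julia
# Hypothetical wrapper type standing in for Reactant.ConcreteRArray (illustration only).
struct WrappedArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end

# Minimal AbstractArray interface so the wrapper behaves like an array.
Base.size(a::WrappedArray) = size(a.data)
Base.getindex(a::WrappedArray, i::Int...) = a.data[i...]
Base.setindex!(a::WrappedArray, v, i::Int...) = (a.data[i...] = v; a)

# Without this method, the AbstractArray fallback returns a plain Array.
# similar(a), similar(a, T) and similar(a, dims...) all funnel into this signature.
Base.similar(a::WrappedArray, ::Type{T}, dims::Dims) where {T} =
    WrappedArray(Array{T}(undef, dims))

w = WrappedArray(rand(Float32, 2, 1000))
s = similar(w)  # a WrappedArray{Float32, 2} instead of a Matrix{Float32}
```

With such a method defined, code that allocates its outputs via `similar` keeps the wrapper type rather than silently falling back to a CPU `Array`.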
wsmoses commented 1 month ago

Done https://github.com/EnzymeAD/Reactant.jl/commit/afc44a0642eec49007fb6853ae942af3b69da864