EnzymeAD / Reactant.jl

How to use Reactant on Conv layers #214

Open yolhan83 opened 2 hours ago

yolhan83 commented 2 hours ago

Hello, I wonder whether Reactant works with Conv layers. It seems to work in the forward pass but not in the gradient pass, neither on CPU nor on GPU.

version :

Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 20 × 12th Gen Intel(R) Core(TM) i7-12700H
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 20 virtual cores)

code :

import Enzyme
using Lux, Reactant, Statistics, Random

const rng = Random.Xoshiro(123)  # explicitly seeded RNG; Random.default_rng(123) does not seed with 123
const dev = xla_device()
model_conv = Chain(
    Conv((3,3),1=>8,pad=SamePad(),relu), 
    Lux.FlattenLayer(),
    Dense(32*32*8,10),
    softmax
)
model_dense = Chain(
    Lux.FlattenLayer(),
    Dense(32*32*1,10),
    softmax
)

ps_conv,st_conv = Lux.setup(rng, model_conv) |> dev
ps_dense,st_dense = Lux.setup(rng, model_dense) |> dev

loss(model,x,ps,st,y) = Lux.MSELoss()(first(model(x,ps,st)),y)
x = randn(rng, Float32, 32,32,1,100) |> dev
y = randn(rng, Float32, 10,100) |> dev

function get_grad(loss,model,x,ps,st,y)
    dps = Enzyme.make_zero(ps)
    Enzyme.autodiff(Enzyme.Reverse,loss,Enzyme.Const(model),Enzyme.Const(x),Enzyme.Duplicated(ps,dps),Enzyme.Const(st),Enzyme.Const(y))
    return dps
end

loss_compile_conv = @compile loss(model_conv,x,ps_conv,st_conv,y) # works
loss_compile_dense = @compile loss(model_dense,x,ps_dense,st_dense,y) # works

grad_compile_conv = @compile get_grad(loss,model_conv,x,ps_conv,st_conv,y) # doesn't work
grad_compile_dense = @compile get_grad(loss,model_dense,x,ps_dense,st_dense,y) # works
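
For reference, the thunks returned by @compile are called with the same argument list as the original functions; a minimal sketch for the dense pipeline that does compile, reusing the names above:

l_dense   = loss_compile_dense(model_dense, x, ps_dense, st_dense, y)        # compiled forward loss
dps_dense = grad_compile_dense(loss, model_dense, x, ps_dense, st_dense, y)  # compiled gradient w.r.t. ps_dense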

error :

error: expects input feature dimension (8) / feature_group_count = kernel input feature dimension (1). Got feature_group_count = 1.
ERROR: "failed to run pass manager on module"
Stacktrace:
  [1] run!
    @ ~/.julia/packages/Reactant/rRa4g/src/mlir/IR/Pass.jl:70 [inlined]
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String)
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:241
  [3] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:272
  [4] compile_mlir!
    @ ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:256 [inlined]
  [5] (::Reactant.Compiler.var"#30#32"{typeof(get_grad), Tuple{…}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:584
  [6] context!(f::Reactant.Compiler.var"#30#32"{typeof(get_grad), Tuple{…}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/rRa4g/src/mlir/IR/Context.jl:71
  [7] compile_xla(f::Function, args::Tuple{…}; client::Nothing)
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:581
  [8] compile_xla
    @ ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:575 [inlined]
  [9] compile(f::Function, args::Tuple{…}; client::Nothing)
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:608
 [10] compile(f::Function, args::Tuple{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:607
 [11] top-level scope
    @ ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:368
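
The check behind this message is the grouped-convolution constraint of the StableHLO/XLA convolution op: input feature dimension / feature_group_count must equal the kernel input feature dimension. On the Lux side that constraint corresponds to the groups keyword of Conv; a minimal sketch of the shape relation, for illustration only and not part of the reproduction:

using Lux, Random

# 8 input channels split into 8 groups => the kernel's input-feature dimension is 8 ÷ 8 = 1
grouped = Conv((3, 3), 8 => 8; groups=8, pad=SamePad())
ps, _ = Lux.setup(Xoshiro(0), grouped)
size(ps.weight)  # (3, 3, 1, 8): kernel input-feature dim 1 == 8 / feature_group_count

# The failing reverse pass above emits a convolution whose input feature dimension is 8
# and kernel input feature dimension is 1, but with feature_group_count = 1,
# so 8 / 1 != 1 and the verifier rejects the module.
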
avik-pal commented 2 hours ago

the feature_group_count is probably missing on the Enzyme-JAX end?

cc @Pangoraw

Pangoraw commented 2 hours ago

The reverse is not implemented for convolution. It should be fixed in Enzyme-JAX.
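
In the meantime, a possible stop-gap is to take the Conv-model gradient outside Reactant (plain Lux arrays, no xla_device / @compile), e.g. with Zygote; a rough sketch reusing the model and loss defined above, assuming Zygote is available:

using Zygote

ps_c, st_c = Lux.setup(Xoshiro(123), model_conv)   # host-side parameters/state
x_c = randn(Xoshiro(1), Float32, 32, 32, 1, 100)   # host-side data
y_c = randn(Xoshiro(2), Float32, 10, 100)

dps_c = Zygote.gradient(p -> loss(model_conv, x_c, p, st_c, y_c), ps_c)[1]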

yolhan83 commented 2 hours ago

Oh ok, I will wait before running my MNIST benchmark then, have a nice day.