Todorbsc opened 5 days ago
CC @mofeing There's an error when trying to compile a function named `f`:
```julia
julia> using Enzyme

julia> using Reactant
AssertionError("Could not find registered platform with name: \"cuda\". Available platform names are: ")

julia> using Adapt

julia> N = 10
10

julia> params = rand(ComplexF64, N);

julia> expected = rand(ComplexF64, N);

julia> params′ = adapt(ConcreteRArray, params);

julia> expected′ = adapt(ConcreteRArray, expected);

julia> function f(params, expected)
           return sum(abs.(expected - params))
       end
f (generic function with 1 method)

julia> function ∇f(params, expected)
           foo = Enzyme.gradient(ReverseWithPrimal, f, params, Enzyme.Const(expected))
           return foo.val, foo.derivs[1]
       end
∇f (generic function with 1 method)

julia> params_real = rand(N);

julia> expected_real = rand(N);

julia> params_real′ = adapt(ConcreteRArray, params_real);

julia> expected_real′ = adapt(ConcreteRArray, expected_real);

julia> ∇fR = Reactant.@compile ∇f(params_real′, expected_real′)
ERROR: invalid redefinition of constant Main.f
Stacktrace:
 [1] top-level scope
   @ ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:488
```
However, it works when the function has a different name, such as `f1`:
```julia
julia> params_real = rand(N);

julia> expected_real = rand(N);

julia> params_real′ = adapt(ConcreteRArray, params_real);

julia> expected_real′ = adapt(ConcreteRArray, expected_real);

julia> function f1(params, expected)
           return sum(abs.(expected - params))
       end
f1 (generic function with 1 method)

julia> function ∇f(params, expected)
           foo = Enzyme.gradient(ReverseWithPrimal, f1, params, Enzyme.Const(expected))
           return foo.val, foo.derivs[1]
       end
∇f (generic function with 1 method)

julia> ∇fR = Reactant.@compile ∇f(params_real′, expected_real′)
Reactant.Compiler.Thunk{Symbol("##∇f_reactant#225")}()

julia> ∇fR(params_real′, expected_real′)
(ConcreteRNumber{Float64}(2.9273890179275863), ConcreteRArray{Float64, 1}([-1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0]))

julia> ∇f(params_real, expected_real)
(2.9273890179275868, [-1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0])
```
We need to use `gensym` in the compile macro 😓
yep 😅
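For readers unfamiliar with the suggested fix: `gensym` returns a fresh symbol that cannot collide with any user-defined name. The sketch below uses a hypothetical macro, `@call_with_tmp` (not part of Reactant), that stores the called function in a temporary binding before calling it. It assumes, without having checked Reactant's source, that `@compile` does something similar with a fixed name like `f`, which would be consistent with the `invalid redefinition of constant Main.f` error above.

```julia
# Hypothetical macro (NOT Reactant's API) illustrating the gensym fix.
# It stores the called function in a temporary global before calling it.
# A hard-coded temporary name such as `f` could collide with a user
# binding of the same name; gensym produces a fresh, unique symbol instead.
macro call_with_tmp(call)
    call isa Expr && call.head === :call || error("expected a function call")
    fname = call.args[1]        # the function being called, e.g. :f
    args = call.args[2:end]     # its positional arguments
    tmp = gensym(:f)            # e.g. Symbol("##f#123"); never equals a user name
    return esc(quote
        $tmp = $fname
        $tmp($(args...))
    end)
end

# A user function deliberately named `f`, as in the issue.
f(x) = x + 1

println(@call_with_tmp f(2))  # prints 3
```

Because a new temporary name is generated on every expansion, the user's `f` binding is never reassigned.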