EnzymeAD / Reactant.jl

MIT License
Name clash when calling `Reactant.@compile` over a function that is named `f` #237

Open Todorbsc opened 5 days ago

Todorbsc commented 5 days ago

CC @mofeing. `Reactant.@compile` throws an error when the code being compiled involves a function named `f` (here, `∇f` calls `f`):

julia> using Enzyme

julia> using Reactant
AssertionError("Could not find registered platform with name: \"cuda\". Available platform names are: ")

julia> using Adapt

julia> N = 10
10

julia> params = rand(ComplexF64, N);

julia> expected = rand(ComplexF64, N);

julia> params′ = adapt(ConcreteRArray, params);

julia> expected′ = adapt(ConcreteRArray, expected);

julia> function f(params, expected)
           return sum(abs.(expected - params))
       end
f (generic function with 1 method)

julia> function ∇f(params, expected)
           foo = Enzyme.gradient(ReverseWithPrimal, f, params, Enzyme.Const(expected))
           return foo.val, foo.derivs[1]
       end
∇f (generic function with 1 method)

julia> params_real = rand(N);

julia> expected_real = rand(N);

julia> params_real′ = adapt(ConcreteRArray, params_real);

julia> expected_real′ = adapt(ConcreteRArray, expected_real);

julia> ∇fR = Reactant.@compile ∇f(params_real′, expected_real′)
ERROR: invalid redefinition of constant Main.f
Stacktrace:
 [1] top-level scope
   @ ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:488

However, it works when the function has a name other than `f`, such as `f1`:

julia> params_real = rand(N);

julia> expected_real = rand(N);

julia> params_real′ = adapt(ConcreteRArray, params_real);

julia> expected_real′ = adapt(ConcreteRArray, expected_real);

julia> function f1(params, expected)
           return sum(abs.(expected - params))
       end
f1 (generic function with 1 method)

julia> function ∇f(params, expected)
           foo = Enzyme.gradient(ReverseWithPrimal, f1, params, Enzyme.Const(expected))
           return foo.val, foo.derivs[1]
       end
∇f (generic function with 1 method)

julia> ∇fR = Reactant.@compile ∇f(params_real′, expected_real′)
Reactant.Compiler.Thunk{Symbol("##∇f_reactant#225")}()

julia> ∇fR(params_real′, expected_real′)
(ConcreteRNumber{Float64}(2.9273890179275863), ConcreteRArray{Float64, 1}([-1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0]))

julia> ∇f(params_real, expected_real)
(2.9273890179275868, [-1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0])

avik-pal commented 5 days ago

We need to use gensym in the compile macro 😓
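
To illustrate the idea, here is a minimal sketch of how a macro that binds a fixed, escaped name can clash with a user-defined `f`, and how `gensym` avoids it. The macros `@naive_compile` and `@hygienic_compile` are hypothetical stand-ins written for this example. They are not Reactant's actual implementation, and the real fix may look different.

```julia
# Hypothetical stand-in: stores its result under the fixed name `f`
# in the caller's module, so it collides with any existing `f` there.
macro naive_compile(call)
    quote
        $(esc(:f)) = $(esc(call.args[1]))
        $(esc(:f))
    end
end

# Same idea, but the temporary binding gets a unique name from `gensym`,
# so it cannot collide with names the user has already defined.
macro hygienic_compile(call)
    tmp = gensym(:compiled)
    quote
        $(esc(tmp)) = $(esc(call.args[1]))
        $(esc(tmp))
    end
end

f(x) = 2x          # user function that happens to be called `f`
g(x) = f(x) + 1    # function we want to "compile"

# `@eval` runs the expansion at global scope, as it would at the REPL.
clash = try
    @eval @naive_compile g(1)
    false
catch
    true   # assigning to `f` fails because `f` is already a constant (a function)
end

h = @hygienic_compile g(1)   # works: the temporary has a generated name
```

With the fixed name, the expansion tries to assign to the constant `f` at top level, which is the same kind of failure as the `invalid redefinition of constant Main.f` error above. With `gensym`, the generated binding has a unique name (similar in form to the `##∇f_reactant#225` symbol in the working output), so the user's `f` is left untouched.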

mofeing commented 5 days ago

yep 😅