EnzymeAD / Reactant.jl

MIT License

passing the same argument twice leads to a bounds error while compiling #226

Open avik-pal opened 1 week ago

avik-pal commented 1 week ago
julia> using Reactant

julia> foo(x, y) = x
foo (generic function with 1 method)

julia> x_ra = Reactant.to_rarray(rand(2, 2))
2×2 ConcreteRArray{Float64, 2}:
 0.806331  0.636137
 0.349127  0.0385477

julia> @compile foo(x_ra, x_ra)
ERROR: BoundsError: attempt to access Tuple{} at index [2]
Stacktrace:
 [1] getindex(t::Tuple, i::Int64)
   @ Base ./tuple.jl:31
 [2] codegen_flatten!(linear_args::Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, result_stores::Dict{Tuple, Symbol})
   @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:537
 [3] compile(f::Function, args::Tuple{ConcreteRArray{Float64, 2}, ConcreteRArray{Float64, 2}}; client::Nothing, optimize::Bool, sync::Bool)
   @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:733
 [4] top-level scope
   @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:457
 [5] top-level scope
   @ none:1

julia> x_ra2 = Reactant.to_rarray(rand(2, 2))
2×2 ConcreteRArray{Float64, 2}:
 0.806331  0.636137
 0.349127  0.0385477

julia> @compile foo(x_ra, x_ra2)
Reactant.Compiler.Thunk{Symbol("##foo_reactant#16369")}()

wsmoses commented 1 week ago

Ah, so I think the flatten code (`codegen_flatten!` in the stack trace) doesn't do the same argument deduplication that we do elsewhere. We should fix it so that it also keeps the deduplication.
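The kind of argument deduplication mentioned above can be illustrated with a small standalone sketch. This is not Reactant's actual implementation. The function name `dedup_by_identity` is hypothetical, and the sketch assumes that two arguments count as duplicates when they are the same object (`===`), as in `foo(x_ra, x_ra)` but not `foo(x_ra, x_ra2)`. It keeps one entry per distinct object and records which unique slot each original argument maps to. A flattening step could then emit each buffer once while still letting both parameters refer to it.

```julia
# Hypothetical sketch, NOT Reactant's code: deduplicate call arguments by
# object identity before flattening them into a linear argument list.
# Returns the distinct objects plus, for each original argument, the index
# of the distinct object it refers to.
function dedup_by_identity(args::Tuple)
    seen = IdDict{Any,Int}()   # object => index of its first occurrence
    unique_args = Any[]        # each distinct object, in first-seen order
    slot = Int[]               # slot[i] = unique index used by args[i]
    for a in args
        # IdDict compares keys with ===, so two references to the same
        # mutable array collapse to one entry. Note that equal immutable
        # values (e.g. two 1.0 literals) are also === and would be merged.
        idx = get!(seen, a) do
            push!(unique_args, a)
            length(unique_args)
        end
        push!(slot, idx)
    end
    return unique_args, slot
end
```

For example, with `x = rand(2, 2)` and `y = rand(2, 2)`, `dedup_by_identity((x, x, y))` returns the two distinct arrays `x` and `y` together with the slot map `[1, 1, 2]`, while `dedup_by_identity((x, y))` returns the slot map `[1, 2]`.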