EnzymeAD / Reactant.jl

MIT License
69 stars 7 forks source link

Incorrect code-generation for `@trace for ...` #301

Open avik-pal opened 1 week ago

avik-pal commented 1 week ago
# Reproducer: `dx` is only assigned inside the loop body, never before the loop.
@macroexpand @trace for i in 1:10
    dx = 1
end
quote
    #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:183 =#
    # NOTE(issue #301): dx is read here, before the loop runs. If dx has no binding
    # before the loop (it is only assigned in the body), this raises UndefVarError.
    if any(ReactantCore.is_traced, (dx,))
        #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:184 =#
        begin
            #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:160 =#
            # dx is read again as the initial loop-carried value; same problem as above.
            let args = (Reactant.promote_to(Reactant.TracedRNumber{Int}, 0), dx)
                #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:161 =#
                cond_fn = ((var"##i#957892", dx)->begin
                            #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:162 =#
                            #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:163 =#
                            # Steps for 1:10: div(10 - 1, 1, RoundDown) == 9, so the
                            # 0-based counter takes the values 0 through 9 (10 iterations).
                            local num_iters = div(10 - 1, 1, RoundDown)
                            #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:164 =#
                            local num_iters = Reactant.promote_to(Reactant.TracedRNumber{Int64}, num_iters)
                            #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:167 =#
                            var"##i#957892" < num_iters + 1
                        end)
                #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:169 =#
                body_fn = ((var"##i#957892", dx)->begin
                            #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:170 =#
                            #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:171 =#
                            local step_ = 1
                            #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:172 =#
                            local start_ = 1
                            #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:173 =#
                            # Map the 0-based counter back to the user's loop variable (1, 2, ..., 10).
                            local i = start_ + var"##i#957892" * step_
                            #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:174 =#
                            begin
                                #= REPL[54]:2 =#
                                dx = 1
                                #= REPL[54]:3 =#
                            end
                            #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:175 =#
                            # Next loop-carried state: (counter + 1, dx).
                            (var"##i#957892" + 1, dx)
                        end)
                #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:178 =#
                (ReactantCore).traced_while(cond_fn, body_fn, args)
            end
        end
    else
        #= /mnt/software/lux/Reactant.jl/lib/ReactantCore/src/ReactantCore.jl:186 =#
        # Untraced fallback: the user's original loop, emitted unchanged.
        for i = 1:10
            #= REPL[54]:2 =#
            dx = 1
            #= REPL[54]:3 =#
        end
    end
end
# Initial loop-carried state in the expansion: dx is read even if it has no binding before the loop.
args = (Reactant.promote_to(Reactant.TracedRNumber{Int}, 0), dx)

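`dx` is not necessarily defined outside the `for` loop. The expansion reads it before the loop runs, both here and in the `any(ReactantCore.is_traced, (dx,))` guard, so it throws an `UndefVarError` whenever `dx` is first assigned inside the loop body.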

jumerckx commented 6 days ago

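Could this work? The idea is to start each loop-carried variable as a sentinel and read the outer binding only if it already exists: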

"""
    LoopInitializer

Sentinel for a loop-carried variable that had no binding before the loop.
In the sketch below, the type of the carried value alone tells `body_fn`
whether to copy the loop body's value back out.
"""
struct LoopInitializer end

let
    # Start from the sentinel and replace it only if `dx` already exists, so the
    # loop setup never reads an undefined `dx`.
    my_dx = LoopInitializer()
    if @isdefined dx
        my_dx = dx
    end
    # Initial loop-carried state: (iteration counter, dx or sentinel).
    args = (0, my_dx)

    # Hand-written stand-in for the generated condition; for 1:10 the expansion
    # derives 9 as div(10 - 1, 1, RoundDown).
    cond_fn = (i, my_dx)->begin
        local num_iters = 9
        i < num_iters + 1
    end
    body_fn = (i, my_dx)->begin
        # step_/start_ mirror the generated expansion but are unused here: this
        # sketch omits the `start_ + i * step_` mapping to the user's loop variable.
        local step_ = 1
        local start_ = 1
        # The user's loop body.
        begin
            dx = 1
        end
        # Propagate the body's value only when dx existed before the loop. The
        # condition depends only on the type of `my_dx`, so the compiler can
        # presumably resolve it statically (TODO confirm with @code_typed).
        if !(my_dx isa LoopInitializer)
            my_dx = dx
        end
        # Next loop-carried state.
        (i + 1, my_dx)
    end
    # Smoke check: call each closure once on the initial state (no loop is run here).
    cond_fn(args...), body_fn(args...)
end

I believe the check at the end of `body_fn` should always be optimized away, because `my_dx isa LoopInitializer` depends only on the argument's type. If there are no obvious problems with this approach, I can try to add it to the macro, unless someone else would like to.