FluxML / IRTools.jl

Mike's Little Intermediate Representation
MIT License

UndefVarError when referring to struct in closure dynamo #51

Open femtomc opened 4 years ago

femtomc commented 4 years ago

Basically, I can write:

```julia
module ClosureScratch

using IRTools
using IRTools: IR, @dynamo

mutable struct Counter
    count::Int
end

@dynamo function (c::Counter)(m...)
    ir = IR(m...)
    for (v, st) in ir
        c.count += 1
    end
    return ir
end

function foo(x::Int)
    y = x + 5
    return y
end

c = Counter(1)

c() do
    foo(5)
end

println(@code_ir c foo(5))
println(c)

end #module
```

and this works: you can refer to c inside the @dynamo.

But if I define the dynamo inside my library instead, I get this error:

```
ERROR: LoadError: Error compiling @dynamo Main.TraceTransform.Jaynes.Trace on (Main.TraceTransform.var"#1#2",):
UndefVarError: tr not defined
```

where I've defined a dynamo inside my library as

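```julia
using IRTools: IR, @dynamo, recurse!

@dynamo function (tr::Trace)(m...)
    ir = IR(m...)
    println(tr)               # this reference to `tr` triggers the UndefVarError
    ir === nothing && return  # no IR available (e.g. a primitive): leave the call alone
    recurse!(ir)              # route nested calls back through the dynamo
    return ir
end
```
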
MikeInnes commented 4 years ago

The first example seems like it must be an IRTools bug. @dynamos don't have access to values (since they run at compile time), so c should refer to the type of the object you passed in, not the actual Counter instance, just as the other arguments to the dynamo are types.

I guess in @code_ir we're wrongly making the self argument available, whereas during actual compilation we ignore the self argument entirely. In both cases we should instead expose tr = typeof(tr).
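The compile-time rule above can be seen without IRTools, because a @dynamo behaves like one of Julia's @generated functions: inside the generator body, each argument name is bound to the argument's type, not its value. The sketch below reuses the Counter struct from the first example; describe is a hypothetical name used only for illustration.

```julia
mutable struct Counter
    count::Int
end

# The body of a @generated function runs at compile time and only sees types.
@generated function describe(c::Counter, x)
    # Here `c` is bound to the type `Counter` and `x` to e.g. `Float64`.
    # Writing `c.count` at this point would throw, since a DataType has no `count` field.
    msg = "generated for $(c) and $(x)"
    # The returned expression is compiled and then run with the actual values,
    # so `c.count` inside the quote reads the runtime counter.
    return :(($msg, c.count))
end

println(describe(Counter(3), 1.5))
```

Only the quoted expression ever touches the real Counter object; anything the generator computes from c is fixed per argument type, which is why a dynamo cannot read a field such as c.count or a precomputed address table stored on the object.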

We could fix both, but based on the examples you've shown I suspect that fix isn't what you actually want, so we should probably discuss the use case a bit.

femtomc commented 4 years ago

Okay, cool! It's quite likely I'm going about this the wrong way. Basically, I've been trying to use a struct to control the IR transformation. Idiomatically, that means writing something like

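```julia
using IRTools: IR, @dynamo, recurse!
using Distributions: Distribution, logpdf

@dynamo function (trace::Trace)(m...)
    ir = IR(m...)
    ir === nothing && return
    recurse!(ir)
    return ir
end

# Intercept `rand(dist)` calls reached through the dynamo.
function (tr::Trace)(call::typeof(rand), dist::T) where T <: Distribution
    result = call(dist)
    score = logpdf(dist, result)
    record!(tr, address, dist, result, score)  # address goes here: not known inside this method
    return result
end
```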

But here, dispatch on typeof(rand) and dist only has the local information from the right-hand side of that particular IR assignment, so I can't do anything clever with the address placeholder.
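Because dispatch only sees the call's arguments, one hypothetical workaround is for the transformation to pass the address in as an extra argument, so the overloaded method receives it directly. The sketch below is plain Julia with no IRTools or Distributions: ToyTrace and its record! are illustrative stand-ins, not the Jaynes types, and the step that would rewrite `y = rand(1:6)` into the traced call is assumed rather than shown.

```julia
# Hypothetical stand-in for a Trace: maps an address to the sampled value.
struct ToyTrace
    choices::Dict{Symbol,Any}
end
ToyTrace() = ToyTrace(Dict{Symbol,Any}())

# Hypothetical helper (not the Jaynes API): store a choice under `addr`.
function record!(tr::ToyTrace, addr::Symbol, value)
    tr.choices[addr] = value
    return value
end

# The address is an explicit argument, so this method sees it at dispatch time.
function (tr::ToyTrace)(addr::Symbol, call::typeof(rand), args...)
    result = call(args...)
    return record!(tr, addr, result)
end

tr = ToyTrace()
y = tr(:y, rand, 1:6)   # what a rewritten `y = rand(1:6)` could look like
println(tr.choices)
```

The design choice here is to move the address from runtime state on the trace into the call signature, which sidesteps needing to read trace fields at compile time.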

I would really like to fill that address placeholder with the symbol (assuming one exists) from the left-hand side of the original code. So I tried a few things. First, I used @code_lowered to grab slotnames from the program and paired them with IR variable names. But this doesn't work with dispatch, because dispatch only knows about the right-hand-side call. Then I started experimenting in the dynamo itself: to pass that sort of address information along (I pre-compute it and store it in the Trace object I'm using), I'm trying to access the field inside the dynamo.
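The @code_lowered approach described above can be sketched with Base Julia alone: the lowered CodeInfo keeps source variable names in its slotnames field, and assignments to those slots show which variables receive a value. This is only a sketch of the idea; assigned_slots is a hypothetical helper, it reads Core.CodeInfo internals whose layout may change between Julia versions, and on its own it still does not make the name visible at dispatch time.

```julia
function foo(x::Int)
    y = x + 5
    return y
end

# Lowered code for foo(::Int); slotnames holds the source-level variable names.
ci = only(code_lowered(foo, (Int,)))
println(ci.slotnames)

# Hypothetical helper: names of slots that are assigned somewhere in the body.
function assigned_slots(ci::Core.CodeInfo)
    names = Symbol[]
    for stmt in ci.code
        # Lowered assignments look like Expr(:(=), SlotNumber(i), rhs).
        if stmt isa Expr && stmt.head === :(=) && stmt.args[1] isa Core.SlotNumber
            push!(names, ci.slotnames[stmt.args[1].id])
        end
    end
    return names
end

println(assigned_slots(ci))
```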

I get the feeling there's a more elegant way to do this without the dynamo, but I'm curious to hear your thoughts.