JuliaLabs / Cassette.jl

Overdub Your Julia Code

Open discussion - support for dynamic pass creation #175

Open femtomc opened 4 years ago

femtomc commented 4 years ago

For a probabilistic programming library I’ve been developing, I’ve identified a set of optimizations that require something akin to the ability to dynamically create and register passes inside overdub.

In essence, I would like the pass to do different things depending on context metadata. For example, depending on the runtime dependency graph for randomness (which could change as I run the program), I would like to configure a pass that removes the parts of the program which are irrelevant (i.e. incremental computation). This pass would be configured by runtime data.

The alternative to doing this in Cassette is to do this with IRTools: run a pass, dynamically compile a new method body to an anonymous function, then recurse into an invokelatest call. I’m a little concerned about this alternative, because it appears sub-optimal from a performance perspective.

The other rock in the shoe is the caching of pass results - i.e. if I see the same runtime data, I would like to re-use the results from a previously computed pass.

Is something like this possible with the current compiler infrastructure? Is there an optimal way to do this?

vchuravy commented 4 years ago

This is very hard to do performantly. The pass operates as part of a generated function, which only receives the types of the arguments. In theory you can hoist runtime values into the type system, but that will create a dynamic dispatch site and poison type information.

julia> using Cassette
[ Info: Precompiling Cassette [7057c7e9-c182-5462-911a-8362d720325c]

julia> Cassette.@context Ctx
Cassette.Context{nametype(Ctx),M,T,P,B,H} where H<:Union{Cassette.DisableHooks, Nothing} where B<:Union{Nothing, IdDict{Module,Dict{Symbol,Cassette.BindingMeta}}} where P<:Cassette.AbstractPass where T<:Union{Nothing, Cassette.Tag} where M

julia> ctx = Ctx(metadata=Val(1))
Cassette.Context{nametype(Ctx),Val{1},Nothing,Cassette.var"##PassType#255",Nothing,Nothing}(nametype(Ctx)(), Val{1}(), nothing, Cassette.var"##PassType#255"(), nothing, nothing)

julia> typeof(ctx)
Cassette.Context{nametype(Ctx),Val{1},Nothing,Cassette.var"##PassType#255",Nothing,Nothing}

Now you can write a pass:

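function transform(::Type{<:Cassette.Context{N, Val{V}}}, reflection::Cassette.Reflection) where {N, V}
    # V is the runtime value that was hoisted into the context's metadata type;
    # branch on it here to decide how to rewrite the method body.
    return reflection.code_info
end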

The real question is when you change the metadata and how you decide to do so. You are effectively creating an intentional compiler barrier and then re-entering the compiler at that callsite. Julia will cache that correctly, but you can also shoot yourself in the foot rather easily. I do not recommend doing this.
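To make the barrier concrete, here is a minimal sketch that does not use Cassette at all. The names `configured`, `hoist`, and `run_with` are hypothetical stand-ins: `configured` plays the role of a pass that only sees types, and `hoist` lifts a runtime `Symbol` into the type domain through `Val`. It is an illustration of the mechanism under those assumptions, not a recommendation.

```julia
# Hypothetical stand-in for a pass: the generator only sees types,
# so the lifted value V is the only configuration it can branch on.
@generated function configured(::Val{V}, x) where {V}
    return V == :double ? :(2 * x) : :(x)
end

# Lift a runtime value into the type domain. The result type depends on
# the value of `mode`, so inference cannot give it a concrete type.
hoist(mode::Symbol) = Val(mode)

# The call to `configured` is the dynamic dispatch site (the barrier):
# the method specialization is looked up at run time for each `mode`.
run_with(mode::Symbol, x) = configured(hoist(mode), x)

# Inference result for the hoisting step; expected to be non-concrete.
hoisted_type = first(Base.return_types(hoist, (Symbol,)))
```

Each distinct `mode` gets its own specialization of `configured`, which Julia compiles once and reuses, but everything downstream of `run_with` loses type information unless it sits behind its own function barrier.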

femtomc commented 4 years ago

Okay, so that's a big red don't do it.

The alternative is something like this:

function Cassette.overdub(ctx::DefaultCtx, fn::typeof(some_foo), args...)
    t_args = map(typeof, args)
    ir = lower_to_ir(fn, t_args...)
    ir = some_transform(ir, ctx.metadata)
    ret = Base.invokelatest(IRTools.func(Main, ir), nothing, args...)
    return ret
end

which appears to work, and would allow me to implement the functionality I want. It also looks like I could get the caching working.

However, are there performance gotchas with Base.invokelatest? It looks like it is ill-advised to use this too frequently, but I also don't know how to get around generated world age problems without it (especially here, where you hit a world age problem directly when you compile the ir into an anonymous function).
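The world age problem and the caching idea can be sketched without Cassette or IRTools. In the sketch below, `build`, `call_direct`, `call_cached`, `BUILDS`, and `CACHE` are all hypothetical names, and `build` uses a plain `Core.eval` as a stand-in for the "transform the IR and compile it to an anonymous function" step. It only aims to show why the direct call fails and how a cache keyed on the runtime data avoids recompiling.

```julia
# Counts how many times the (stand-in) compilation step runs.
const BUILDS = Ref(0)

# Hypothetical stand-in for "transform IR and compile a new function":
# evaluates a fresh anonymous function whose body depends on `metadata`.
function build(metadata::Int)
    BUILDS[] += 1
    return Core.eval(@__MODULE__, :(x -> x + $metadata))
end

# Calling the new function directly from the function that created it
# fails: its method is newer than the world age `call_direct` runs in.
function call_direct(metadata, x)
    f = build(metadata)
    return f(x)
end

# Cache compiled functions by the runtime data that configured them.
const CACHE = Dict{Int,Any}()

# Build at most once per `metadata`, then call through invokelatest so
# the call runs in the latest world and can see the new method.
function call_cached(metadata, x)
    f = get!(() -> build(metadata), CACHE, metadata)
    return Base.invokelatest(f, x)
end
```

With this shape, repeated calls with the same `metadata` pay only for the dictionary lookup and the dynamic dispatch inside `Base.invokelatest`, not for recompilation. The call is still opaque to inference, so its result is not inferred at the call site.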

This solution is not Cassette specific. The question I have is whether or not this sort of dynamic munging capability is even a good idea for Julia - there are highly elaborate workarounds currently e.g. Gen has constructed a special IR + generated functions to produce method bodies which incrementally compute/re-compute what needs to be re-computed.

It seems like incremental computation is a good use case for why you might want this capability, but I'm unsure if there's an idiomatic way of expressing this without reflection and munging.

vchuravy commented 4 years ago

That is better written as:

function Cassette.overdub(ctx::DefaultCtx, fn::typeof(some_foo), args...)
    new_ctx = DefaultCtx() # change metadata, change pass, change context type
    return Cassette.recurse(new_ctx, fn, args...)
end