Open femtomc opened 4 years ago
This is very hard to do performantly. The pass operates as part of a generated function, which only receives the types of the arguments. In theory you can hoist runtime values into the type system, but that will create a dynamic dispatch site and poison type information.
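To make the dispatch cost concrete, here is a minimal sketch in plain Julia, independent of Cassette. The names `describe` and `hoist` are made up for illustration. It lifts a runtime integer into the type domain with `Val`, which is the same trick the `metadata=Val(1)` context relies on.

```julia
# One specialized method body is compiled per distinct value of N.
describe(::Val{N}) where {N} = N * 10

# `n` is only known at run time, so the concrete type of `Val(n)` cannot be
# predicted when `hoist` is compiled. The call to `describe` therefore has to
# be resolved by dynamic dispatch.
hoist(n::Int) = describe(Val(n))

hoist(3)
```

Running `@code_warntype hoist(3)` should show that `Val(n)` is not inferred to a concrete type, which is the "poisoned type information" described above.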
julia> using Cassette
[ Info: Precompiling Cassette [7057c7e9-c182-5462-911a-8362d720325c]
julia> Cassette.@context Ctx
Cassette.Context{nametype(Ctx),M,T,P,B,H} where H<:Union{Cassette.DisableHooks, Nothing} where B<:Union{Nothing, IdDict{Module,Dict{Symbol,Cassette.BindingMeta}}} where P<:Cassette.AbstractPass where T<:Union{Nothing, Cassette.Tag} where M
julia> ctx = Ctx(metadata=Val(1))
Cassette.Context{nametype(Ctx),Val{1},Nothing,Cassette.var"##PassType#255",Nothing,Nothing}(nametype(Ctx)(), Val{1}(), nothing, Cassette.var"##PassType#255"(), nothing, nothing)
julia> typeof(ctx)
Cassette.Context{nametype(Ctx),Val{1},Nothing,Cassette.var"##PassType#255",Nothing,Nothing}
Now you can write a pass:
function transform(::Type{<:Cassette.Context{N, Val{V}}}, reflection) where {N, V}
    # V is the runtime value that was lifted into the context type via Val
end
The real question will be when you change the metadata and how you decide to do so. You are effectively creating an intentional compiler barrier and then re-entering the compiler at that call site. Julia will cache that correctly, but you can also shoot yourself in the foot rather easily. I do not recommend doing this.
Okay, so that's a big red "don't do it".
The alternative is something like this:
function Cassette.overdub(ctx::DefaultCtx, fn::typeof(some_foo), args...)
    t_args = map(args) do a
        typeof(a)
    end
    ir = lower_to_ir(fn, t_args...)
    ir = some_transform(ir, ctx.metadata)
    ret = Base.invokelatest(IRTools.func(Main, ir), nothing, args...)
    return ret
end
which appears to work, and would allow me to implement the functionality I want. It also looks like I could get the caching working.
However, are there performance gotchas with Base.invokelatest? It looks like it is ill-advised to use this too frequently, but I also don't know how to get around generated world age problems without it (especially here, where you hit a world age problem directly when you compile the IR into an anonymous function).
This solution is not Cassette-specific. The question I have is whether or not this sort of dynamic munging capability is even a good idea for Julia. There are highly elaborate workarounds currently, e.g. Gen has constructed a special IR plus generated functions to produce method bodies which incrementally compute or re-compute what needs to be re-computed.
It seems like incremental computation is a good use case for why you might want this capability, but I'm unsure if there's an idiomatic way of expressing this without reflection and munging.
That is better written as:
function Cassette.overdub(ctx::DefaultCtx, fn::typeof(some_foo), args...)
    new_ctx = DefaultCtx() # change metadata, change pass, change context type
    Cassette.recurse(new_ctx, fn, args...) # continue under the new context
end
For a probabilistic programming library I’ve been developing, I’ve identified a set of optimizations which require something akin to the ability to dynamically create and register passes inside overdub.
In essence, depending on context metadata, I would like the pass to do different things e.g. depending on the runtime dependency graph for randomness (which could change as I run the program) I would like to configure a pass to remove parts of the program which are irrelevant (i.e. incremental computation). This pass would be configured by runtime data.
The alternative to doing this in Cassette is to do this with IRTools: run a pass, dynamically compile a new method body to an anonymous function, then recurse into an invokelatest call. I’m a little concerned about this alternative, because it appears sub-optimal from a performance perspective.
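The world-age issue behind that invokelatest call can be reproduced without Cassette or IRTools. The sketch below is plain Julia. `make_adder`, `call_directly`, and `call_latest` are hypothetical names, and `Core.eval` stands in for "compile a transformed method body into an anonymous function".

```julia
# Hypothetical helper: builds and compiles a fresh anonymous function at run
# time, standing in for turning transformed IR into something callable.
make_adder(k) = Core.eval(@__MODULE__, :(x -> x + $k))

function call_directly(k, x)
    f = make_adder(k)
    # The method for `f` was defined after this function started running,
    # so it is "too new" for the world this call executes in.
    return f(x)
end

function call_latest(k, x)
    f = make_adder(k)
    # invokelatest performs the call in the newest world, where `f` exists.
    return Base.invokelatest(f, x)
end

call_latest(1, 2)
```

On Julia 1.x, `call_directly(1, 2)` should fail with a `MethodError` mentioning that the method may be too new, while `call_latest` succeeds. The price is that every `invokelatest` call is a dynamic call whose return type cannot be inferred, so it acts as an optimization barrier at that call site.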
The other rock in the shoe is the caching of pass results - i.e. if I see the same runtime data, I would like to re-use the results from a previously computed pass.
Is something like this possible with the current compiler infrastructure? Is there an optimal way to do this?
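For the caching part of the question, one possible shape is a dictionary keyed on the runtime data that configures the pass. This is only a sketch under assumptions: `COMPILED`, `BUILD_COUNT`, `build_body`, and `run_specialized` are invented names, a `Bool` stands in for the real dependency-graph metadata, and `Core.eval` stands in for the actual IR transformation and compilation step.

```julia
# Hypothetical cache: runtime configuration => already-compiled function.
const COMPILED = Dict{Any,Function}()
const BUILD_COUNT = Ref(0)   # counts how often a body is actually compiled

# Stand-in for "run a pass configured by runtime data": the generated body
# either keeps or drops a term depending on `keep_term`.
build_body(keep_term::Bool) = keep_term ? :(x -> 2x + 1) : :(x -> 2x)

function run_specialized(keep_term::Bool, x)
    # Compile only on a cache miss; reuse the stored function otherwise.
    f = get!(COMPILED, keep_term) do
        BUILD_COUNT[] += 1
        Core.eval(@__MODULE__, build_body(keep_term))
    end
    # The cached function may be newer than this caller's world, so the call
    # still goes through invokelatest.
    return Base.invokelatest(f, x)
end

run_specialized(true, 3)
run_specialized(true, 4)   # cache hit, no recompilation
```

The cache avoids repeating the transformation and compilation when the same runtime data shows up again, but it does not remove the dynamic call at the `invokelatest` boundary.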