compintell / Tapir.jl

https://compintell.github.io/Tapir.jl/
MIT License

Diffractor + Tapir for computing hessian #157

Open yebai opened 1 month ago

yebai commented 1 month ago

This example currently fails:


julia> import ForwardDiff, Tapir

julia> using DifferentiationInterface

julia> b = SecondOrder(AutoForwardDiff(), AutoTapir());

julia> hessian(sum, b, [1.0, 2.0])
[ Info: Compiling rule for Tuple{typeof(sum), Vector{ForwardDiff.Dual{ForwardDiff.Tag{DifferentiationInterface.var"#inner_gradient_closure#27"{typeof(sum), SecondOrder{AutoForwardDiff{nothing, Nothing}, AutoTapir}}, Float64}, Float64, 1}}} in safe mode. Disable for best performance.
ERROR: MethodError: Cannot `convert` an object of type ForwardDiff.Dual{ForwardDiff.Tag{DifferentiationInterface.var"#inner_gradient_closure#27"{typeof(sum), SecondOrder{AutoForwardDiff{nothing, Nothing}, AutoTapir}}, Float64}, Float64, 1} to an object of type Tapir.Tangent{@NamedTuple{value::Float64, partials::Tapir.Tangent{@NamedTuple{values::Tuple{Float64}}}}}

Closest candidates are:
  convert(::Type{T}, ::T) where T
   @ Base Base.jl:84
  (::Type{Tapir.Tangent{Tfields}} where Tfields<:NamedTuple)(::Any)
   @ Tapir ~/.julia/packages/Tapir/O4V78/src/tangents.jl:53

It's a pity, since Tapir would provide a great tool for computing second-order derivatives in conjunction with ForwardDiff. Could this be improved?

Package environment:

(@v1.10) pkg> st Tapir DifferentiationInterface ForwardDiff
Status `~/.julia/environments/v1.10/Project.toml`
  [a0c0ee7d] DifferentiationInterface v0.4.0
  [f6369f11] ForwardDiff v0.10.36
  [07d77754] Tapir v0.2.12

willtebbutt commented 1 month ago

Could you please provide the entire stack trace, or is this it?

In general, I would say that the chances are slim that we'll be able to do ForwardDiff-over-Tapir without a really substantial effort -- to be honest I'm not sure that it is feasible at all because Tapir does a lot of concrete typing, meaning that Duals cannot propagate. It might be possible to make something like Diffractor work, but it would require a large time investment to figure out how to make this all work given that we're using OpaqueClosures for everything (the standard code lookup that Diffractor has to do almost certainly won't work out-of-the-box).

Tapir-over-ForwardDiff (which I think is what you're doing here) might be able to work, as Tapir ought to be able to differentiate things that ForwardDiff does. I'm not sure that this is really the interesting way round to do things though (I think you normally want to do forward-over-reverse). In any case, I'd have to see the stack trace to know more.
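
As an illustration of the typing concern above (a minimal sketch, not code from Tapir or ForwardDiff): forward-over-reverse needs dual numbers to flow through the gradient code. The names `ToyDual`, `grad_generic`, `grad_concrete`, and `toy_hessian` are hypothetical, invented for this example, and the hand-written gradient merely stands in for a reverse-mode result.

```julia
# Toy dual number: a value plus one directional derivative.
# Hypothetical type for illustration; this is NOT ForwardDiff.Dual.
struct ToyDual
    val::Float64
    eps::Float64
end
Base.:+(a::ToyDual, b::ToyDual) = ToyDual(a.val + b.val, a.eps + b.eps)
Base.:*(a::ToyDual, b::ToyDual) = ToyDual(a.val * b.val, a.val * b.eps + a.eps * b.val)

# Hand-written gradient of f(x) = x[1] * x[1] * x[2], standing in for reverse mode.
# It is generic, so it accepts any element type that supports + and *.
grad_generic(x) = [x[1] * x[2] + x[1] * x[2], x[1] * x[1]]

# The same gradient behind a concrete signature: only Vector{Float64} is accepted.
grad_concrete(x::Vector{Float64}) = grad_generic(x)

# Forward-over-reverse: seed one input direction at a time and read the
# derivative part of each gradient entry to fill one Hessian column.
function toy_hessian(grad, x::Vector{Float64})
    n = length(x)
    H = zeros(n, n)
    for j in 1:n
        xd = [ToyDual(x[i], i == j ? 1.0 : 0.0) for i in 1:n]
        g = grad(xd)
        for i in 1:n
            H[i, j] = g[i].eps
        end
    end
    return H
end

H = toy_hessian(grad_generic, [3.0, 2.0])
```

With the generic gradient the toy duals propagate and `H` holds the Hessian of `f`. Calling `toy_hessian(grad_concrete, [3.0, 2.0])` instead throws a `MethodError`, because a vector of `ToyDual` does not match the concrete signature. This is only loosely analogous to the failure reported above, and far simpler.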

yebai commented 1 month ago

> given that we're using OpaqueClosures for everything

Out of curiosity, what mechanism is Diffractor using, and what are the differences between Diffractor and Tapir?

> Tapir-over-ForwardDiff (which I think is what you're doing here) might be able to work

It is ForwardDiff-over-Tapir IIUC, see here for more details.

willtebbutt commented 1 month ago

> Out of curiosity, what mechanism is Diffractor using, and what are the differences between Diffractor and Tapir?

Ah, sorry, what I mean is that it seems more likely to me that Diffractor-over-Tapir can be made to work than ForwardDiff-over-Tapir. Diffractor also makes use of OpaqueClosures, but it doesn't make use of Dual numbers.

To be clear, I think it would be a substantial piece of work to make Diffractor-over-Tapir work, because you'd have to figure out how to get Diffractor to differentiate through OpaqueClosures (we'd have to find a way to give the IR used to generate the OpaqueClosure to Diffractor). Maybe @oxinabox has thoughts on this? In principle it ought to work nicely though, because both frameworks (if I've understood what Diffractor is doing properly) place few restrictions on the Julia IR that they can work with.

> It is ForwardDiff-over-Tapir IIUC, see here for more details.

Cool, thanks -- I'll take a look at this at some point.

yebai commented 1 month ago

> To be clear, I think it would be a substantial piece of work to make Diffractor-over-Tapir work,

It sounds like a good opportunity for collaboration and a use case for Diffractor. Although Diffractor might implement its own reverse mode in the future, in its current form, making Diffractor interoperable with Tapir would benefit both packages.

oxinabox commented 1 month ago

In general I have spent a fair amount of time in the last year making sure that Diffractor-over-Diffractor and Diffractor-over-ForwardDiff work and are fast. As such we should be able to get Diffractor-over-Tapir to work (or the reverse). Diffractor is good at compiling itself out of existence, so hopefully in many cases you can't even tell the code was run through Diffractor.

It should be just a matter of stashing the IR somewhere and teaching the other package what to do with it. I suggest something like replacing all OpaqueClosures with a struct:

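using Core: OpaqueClosure        # experimental compiler internals
using Core.Compiler: IRCode

# Pairs an OpaqueClosure with the IR it was built from, so AD tools can retrieve the IR.
struct MistyClosure{OC} <: Function
    oc::OC
    ir::IRCode
end
MistyClosure(ir::IRCode) = MistyClosure(OpaqueClosure(ir), ir)
# Calling a MistyClosure just forwards to the wrapped OpaqueClosure.
(this::MistyClosure)(args...; kwargs...) = this.oc(args...; kwargs...)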

Then, with something more or less like an frule or rrule!!, you instruct it that what to do to AD a call to (this::MistyClosure)(args...; kwargs...) is:

  1. go get the IR from this.ir
  2. run the IR transform based AD pass
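
The two steps above can be sketched with a toy analogue (not Diffractor, Tapir, or MistyClosures.jl code). Here a plain `Expr` plays the role of `IRCode`, an `eval`-ed anonymous function plays the role of the `OpaqueClosure`, and a tiny symbolic differentiator plays the role of the IR-transform AD pass. `ToyMistyClosure`, `d`, and `differentiate` are hypothetical names made up for this illustration.

```julia
# Toy analogue of the proposed struct: keep the callable and the "IR" side by side.
struct ToyMistyClosure{F} <: Function
    f::F       # the compiled callable (stand-in for the OpaqueClosure)
    ir::Expr   # the expression it was built from (stand-in for IRCode)
end
ToyMistyClosure(ir::Expr) = ToyMistyClosure(eval(ir), ir)
# invokelatest avoids world-age errors when the closure was created via eval.
(this::ToyMistyClosure)(args...) = Base.invokelatest(this.f, args...)

# Stand-in for the AD pass: symbolic derivative of a tiny expression language
# (numbers, the variable x, `+`, and binary `*`).
function d(ex, x::Symbol)
    ex isa Number && return 0
    ex isa Symbol && return ex == x ? 1 : 0
    if ex isa Expr && ex.head == :call
        op, args = ex.args[1], ex.args[2:end]
        if op == :+
            return Expr(:call, :+, (d(a, x) for a in args)...)
        elseif op == :* && length(args) == 2
            a, b = args
            return :($(d(a, x)) * $b + $a * $(d(b, x)))
        end
    end
    error("unsupported expression: $ex")
end

# The "rule" for differentiating a call to a ToyMistyClosure:
function differentiate(mc::ToyMistyClosure)
    x, body = mc.ir.args   # step 1: go get the IR (`x -> body`) from mc.ir
    body = body isa Expr && body.head == :block ? last(body.args) : body
    return ToyMistyClosure(:($x -> $(d(body, x))))   # step 2: run the transform
end

mc = ToyMistyClosure(:(x -> x * x + x))
dmc = differentiate(mc)
```

Here `mc(3.0)` evaluates the original function and `dmc(3.0)` evaluates its derivative. The point of the design is that the derivative is built from the stored `ir` field, which an opaque callable on its own would not expose.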

willtebbutt commented 1 month ago

Ahh excellent -- I like this idea.

Maybe the way forward would be to:

  1. create a package for misty closures,
  2. replace OpaqueClosures in Diffractor and Tapir with MistyClosures,
  3. ensure that Diffractor and Tapir know how to differentiate MistyClosures, which I agree ought to be straightforward, and
  4. apply Diffractor over Tapir, and get second order stuff.

Is this something you would be interested in collaborating on to make happen?

oxinabox commented 1 month ago

Diffractor's public API doesn't use OpaqueClosures, IIRC. That's only for our forward&demand stuff, though that is what I would be using to AD Tapir if we used MistyClosures. So I think as long as Tapir emits MistyClosures then for now we are good.

willtebbutt commented 1 month ago

> So I think as long as Tapir emits MistyClosures then for now we are good.

Cool. I propose we do the following sequence of things:

  1. I create a package called MistyClosures.jl (you can do this if you like @oxinabox , but I'm happy to take care of it)
  2. @oxinabox you add a small test to Diffractor.jl to make sure that it can differentiate e.g. a MistyClosure which is equivalent to the identity function
  3. I'll change over the OpaqueClosures in this package to use MistyClosures
  4. We work through whatever issues remain in getting Diffractor to work over Tapir (🤞 they are only minor)

Do you think this plan makes sense @oxinabox ?

oxinabox commented 1 month ago

I think this makes sense.

gdalle commented 1 month ago

What about forward Enzyme over Tapir?

willtebbutt commented 1 month ago

Conceptually the same idea, but I don't know whether Enzyme is able to deal with OpaqueClosures / whether we can use the same MistyClosure trick, because Enzyme needs the LLVM code that Julia code gets lowered into.

gdalle commented 1 month ago

Well, if you ever need second order and are okay with approximations, something that will surely work is FiniteDifferences over Tapir.
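
A rough sketch of that idea in plain Julia, without using the FiniteDifferences.jl or Tapir APIs: central differences are applied to a gradient function to approximate the Hessian. `exact_grad` and `fd_hessian` are hypothetical names, and the hand-written gradient only stands in for what a reverse-mode tool would return.

```julia
# Exact gradient of f(x) = x[1]^2 * x[2], standing in for a reverse-mode AD result.
exact_grad(x) = [2 * x[1] * x[2], x[1]^2]

# Central differences over the gradient: column j of the Hessian is approximately
# (grad(x + h*e_j) - grad(x - h*e_j)) / (2h).
function fd_hessian(grad, x::Vector{Float64}; h::Float64 = 1e-5)
    n = length(x)
    H = zeros(n, n)
    for j in 1:n
        e = zeros(n)
        e[j] = h
        H[:, j] = (grad(x + e) - grad(x - e)) / (2h)
    end
    return H
end

H_approx = fd_hessian(exact_grad, [3.0, 2.0])
```

The result is only approximate, and the step size `h` trades truncation error against floating-point round-off.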

willtebbutt commented 1 month ago

That's true. I do think we'll be able to make Diffractor-over-Tapir work nicely though.

gdalle commented 1 month ago

Beware that if you want to test with DifferentiationInterface, you'll be stuck with an old version of Diffractor until the following issue is resolved:

willtebbutt commented 1 month ago

Update: MistyClosures.jl now exists.