JuliaDiff / ChainRulesCore.jl

AD-backend-agnostic system for defining custom forward- and reverse-mode rules. This is the lightweight core that lets you define rules for functions in your packages without depending on any particular AD system.

Typed thunks? #373

Open oschulz opened 3 years ago

oschulz commented 3 years ago

Following the discussion on FluxML/Zygote.jl#966: thunks have really evolved into monads now, right? Especially now that we're adding methods for them to lots of linear algebra and other functions. Shouldn't AbstractThunk have a type parameter, then? Something like:

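abstract type AbstractThunk{T} end

# T is the type that unthunking is expected to produce; F is the closure type
struct Thunk{T,F<:Base.Callable} <: AbstractThunk{T}
    body::F
end

# The type assertion guarantees the result type to the compiler
unthunk(t::Thunk{T}) where T = t.body()::T

Base.eltype(::Type{<:Thunk{T}}) where T = T

macro thunk(body)
    # Wrap the expression in a zero-argument closure, keeping the source location
    func = Expr(:->, Expr(:tuple), Expr(:block, __source__, body))
    return quote
        f = $(esc(func))
        # Internal compiler query: the second argument must be a tuple *type*
        Thunk{Base._return_type(f, Tuple{}), typeof(f)}(f)
    end
end

thnk = let A = rand(5,5), B = rand(5,5)
    @thunk A * B
end

eltype(thnk) == typeof(unthunk(thnk))  # true: both are Matrix{Float64}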

I've seen type inference fail with thunks quite a few times; maybe typed thunks would help with that, too?
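The "monad" framing in the opening question can be made a bit more concrete with a functor-style map over typed thunks. This is a sketch under stated assumptions, not ChainRulesCore API: it replaces the proposed @thunk macro with a hypothetical function helper typed_thunk, thunkmap is likewise a hypothetical name, and it relies on the internal Base._return_type, so T can degrade to Any whenever inference gives up. A full monad would also need a bind/flatten operation, which is omitted here.

```julia
abstract type AbstractThunk{T} end

# T: inferred result type; F: type of the zero-argument closure
struct Thunk{T,F} <: AbstractThunk{T}
    body::F
end

unthunk(t::Thunk{T}) where {T} = t.body()::T

# Hypothetical helper (not ChainRulesCore API): infer T from the closure
# via the internal Base._return_type.
typed_thunk(f) = Thunk{Base._return_type(f, Tuple{}), typeof(f)}(f)

# Hypothetical functor-style map: lazily apply g to the thunk's value.
# Nothing is evaluated here; the result is another typed thunk.
thunkmap(g, t::AbstractThunk) = typed_thunk(() -> g(unthunk(t)))

# Build the thunk inside `let` so the captured matrix has a known type
t1 = let A = rand(3, 3)
    typed_thunk(() -> A * A)
end
t2 = thunkmap(sum, t1)
```

Because thunkmap only wraps a new closure, composing thunks this way stays lazy while still carrying an inferred result type (here Matrix{Float64} for t1 and Float64 for t2).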

oxinabox commented 3 years ago

If type inference is failing on unthunk(::Thunk{F}), then I would expect it to also fail on the function behind the thunk (F.instance), and so Base._return_type (or Core.Compiler.return_type) will return Any.
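A minimal sketch of this point, assuming a hypothetical typed_thunk helper that computes T with the internal Base._return_type: when the thunk body reads a non-const global, inference cannot do better than Any, so the stored type parameter adds no information. The variable names are illustrative only.

```julia
abstract type AbstractThunk{T} end

struct Thunk{T,F} <: AbstractThunk{T}
    body::F
end

unthunk(t::Thunk{T}) where {T} = t.body()::T

# Hypothetical helper: T is whatever inference reports for the body
typed_thunk(f) = Thunk{Base._return_type(f, Tuple{}), typeof(f)}(f)

# A non-const global can be rebound to any type, so a body reading it
# cannot be inferred, and the type parameter degrades to Any.
scale = 2.0
t_global = typed_thunk(() -> scale * 3.0)

# The same computation on a const binding infers to Float64.
const SCALE = 2.0
t_const = typed_thunk(() -> SCALE * 3.0)
```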

oxinabox commented 3 years ago

It is pretty important that thunks infer. Possibly important enough that debug_mode should make creating a thunk that doesn't infer (according to return_type(unthunk(t))) give a warning or an error.
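One way such a debug check might look, as a sketch only: DEBUG and checked_thunk below are hypothetical stand-ins for the debug_mode mentioned above, and treating every non-concrete inferred type as "doesn't infer" is a simplification (small Unions, for instance, are often fine). Replacing @warn with a throw would turn the warning into an error.

```julia
abstract type AbstractThunk{T} end

struct Thunk{T,F} <: AbstractThunk{T}
    body::F
end

unthunk(t::Thunk{T}) where {T} = t.body()::T

# Hypothetical switch, standing in for a package-wide debug mode
const DEBUG = Ref(true)

# Hypothetical constructor: infer the result type and, in debug mode,
# warn when it is not concrete (a simplification of "doesn't infer").
function checked_thunk(f)
    T = Base._return_type(f, Tuple{})
    if DEBUG[] && !isconcretetype(T)
        @warn "Thunk body does not infer to a concrete type" inferred=T
    end
    return Thunk{T,typeof(f)}(f)
end

x = 1.0                              # non-const global: not inferable
t = checked_thunk(() -> x + 1)       # logs a warning; T is Any

u = let y = 1.0
    checked_thunk(() -> y + 1)       # no warning; T is Float64
end
```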

Thunks should always infer; they have everything they need. And if they don't, the fix is often (always?) a change to the thunk's body. If it comes to it, that may mean adding some :: type assertions.
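To illustrate the suggested fix, here is a sketch using the same hypothetical typed_thunk helper, where the body reads from a Dict{Symbol,Any}: without an assertion the thunk's type parameter is Any, and a :: assertion inside the body restores a concrete Float64. The dictionary and its contents are made up for the example.

```julia
abstract type AbstractThunk{T} end

struct Thunk{T,F} <: AbstractThunk{T}
    body::F
end

unthunk(t::Thunk{T}) where {T} = t.body()::T

# Hypothetical helper: T is whatever inference reports for the body
typed_thunk(f) = Thunk{Base._return_type(f, Tuple{}), typeof(f)}(f)

t_any, t_typed = let data = Dict{Symbol,Any}(:w => [1.0, 2.0, 3.0])
    # Values of a Dict{Symbol,Any} are Any, so this body does not infer ...
    a = typed_thunk(() -> 2 * sum(data[:w]))
    # ... while a `::` assertion inside the body restores inference
    b = typed_thunk(() -> 2 * sum(data[:w]::Vector{Float64}))
    (a, b)
end
```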

Storing a type alongside the thunk shouldn't be about helping inference, but it might be useful for dispatch, e.g. to tell a Thunk{AbstractArray} from a Thunk{Tangent}. But we seem to handle that fine by redispatching after the unthunk (albeit at the cost of doing an unthunk).
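The dispatch trade-off can be sketched as follows, with Number standing in for ChainRulesCore's Tangent to keep the example dependency-free; kind, kind_of_value, kind_after_unthunk, typed_thunk and the EVALS counter are all hypothetical names. Dispatching on the type parameter answers without running the thunk body, while unthunking and then redispatching pays for one evaluation.

```julia
abstract type AbstractThunk{T} end

struct Thunk{T,F} <: AbstractThunk{T}
    body::F
end

unthunk(t::Thunk{T}) where {T} = t.body()::T

typed_thunk(f) = Thunk{Base._return_type(f, Tuple{}), typeof(f)}(f)

const EVALS = Ref(0)  # counts how often a thunk body has run

# Dispatch on the stored type parameter: the body never runs.
kind(::AbstractThunk{<:AbstractArray}) = :array
kind(::AbstractThunk{<:Number}) = :number   # Number stands in for Tangent

# Current approach: unthunk first, then dispatch on the value.
kind_of_value(::AbstractArray) = :array
kind_of_value(::Number) = :number
kind_after_unthunk(t::AbstractThunk) = kind_of_value(unthunk(t))

t = let v = [1.0, 2.0]
    typed_thunk(() -> (EVALS[] += 1; 2 .* v))
end
```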

oschulz commented 3 years ago

> Thunks should always infer; they have everything they need. And if they don't, the fix is often (always?) a change to the thunk's body.

In ForwardDiffPullbacks, I actually had to create a custom FwdDiffPullbackThunk and FwdDiffBCPullbackThunk to help type inference along.
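The actual FwdDiffPullbackThunk and FwdDiffBCPullbackThunk definitions are not reproduced here. As a hypothetical illustration of the general pattern, a dedicated thunk type can store its inputs as typed fields, so that unthunk is an ordinary method the compiler infers easily, with no closure capture involved. MatMulThunk is an invented name, and the result type is computed with Base.promote_op.

```julia
abstract type AbstractThunk{T} end

# Hypothetical custom thunk: inputs are stored as typed fields, so unthunk
# is an ordinary method rather than a call to a captured closure.
struct MatMulThunk{T,MA<:AbstractMatrix,MB<:AbstractMatrix} <: AbstractThunk{T}
    A::MA
    B::MB
end

# Outer constructor computes the result type from the argument types.
function MatMulThunk(A::AbstractMatrix, B::AbstractMatrix)
    T = Base.promote_op(*, typeof(A), typeof(B))
    return MatMulThunk{T,typeof(A),typeof(B)}(A, B)
end

unthunk(t::MatMulThunk{T}) where {T} = (t.A * t.B)::T

t = MatMulThunk(ones(2, 2), ones(2, 2))
```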

oschulz commented 3 years ago

> Storing a type alongside the thunk shouldn't be about helping inference, but it might be useful for dispatch, e.g. to tell a Thunk{AbstractArray} from a Thunk{Tangent}. But we seem to handle that fine by redispatching after the unthunk (albeit at the cost of doing an unthunk).

Yes, that would also let us add more fine-grained thunk-aware methods to common functions (related to the recent work of @mzgubic). And avoiding some unthunks would certainly be nice.
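As a sketch of the kind of fine-grained methods meant here, under the same hypothetical typed-thunk definitions: ndims can be answered from the type parameter alone, and scaling a thunk by a number can return a new lazy thunk instead of unthunking. Whether these particular methods would be desirable in ChainRulesCore is an open assumption; the EVALS counter only exists to show when the body runs.

```julia
abstract type AbstractThunk{T} end

struct Thunk{T,F} <: AbstractThunk{T}
    body::F
end

unthunk(t::Thunk{T}) where {T} = t.body()::T

typed_thunk(f) = Thunk{Base._return_type(f, Tuple{}), typeof(f)}(f)

# Answered from the type parameter alone: no unthunk needed.
Base.ndims(::AbstractThunk{T}) where {T<:AbstractArray} = ndims(T)

# Stays lazy: scaling a thunk returns a new thunk instead of unthunking.
Base.:*(a::Number, t::AbstractThunk) = typed_thunk(() -> a * unthunk(t))

const EVALS = Ref(0)  # counts how often the original body has run

t = let v = [1.0, 2.0]
    typed_thunk(() -> (EVALS[] += 1; v))
end

n = ndims(t)   # body has not run
s = 3 * t      # still not run
```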