JuliaLang / julia

The Julia Programming Language
https://julialang.org/
MIT License

type inference fails when wrapping data construction into a function #47694

Open schlichtanders opened 1 year ago

schlichtanders commented 1 year ago

Hi there,

I am surprised by the poor type inference that appears when I introduce an additional function layer. I suspect this somehow triggers a special recursion case that is handled poorly.

I am the author of ExtensibleEffects.jl and am trying to track down performance problems. I have now managed to produce a minimal example (as minimal as I could get it so far) that is self-contained and does not depend on ExtensibleEffects.

Example

Here are the basics needed to run the tests: a minimal implementation of algebraic effects.

struct NoEffect{T}
    value::T
end

struct Eff{Effectful, Fs}
    effectful::Effectful
    funcs::Fs
end
Eff(effectful) = Eff(effectful, ())

# No continuations left: just store the pure value
function Eff(effectful::E, funcs::Tuple{}) where E<:NoEffect
    Eff{E, Tuple{}}(effectful, funcs)
end
# The wrapped value has no effect and there is a continuation:
# run the first continuation immediately and keep the rest
function Eff(effectful::NoEffect, funcs::Fs) where Fs
    func = Base.first(funcs)
    rest = Base.tail(funcs)
    eff = func(effectful.value)
    Eff(eff.effectful, (eff.funcs..., rest...))
end

noeffect(value) = Eff(NoEffect(value))
noeffect(eff::Eff) = eff
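To make the eager semantics concrete: with these definitions, a NoEffect value followed by continuations is evaluated at construction time, so a pure Eff never keeps pending functions. Below is a minimal sketch that repeats the definitions so it runs on its own; the two continuations (add one, then double) are made-up examples and not part of ExtensibleEffects.jl.

```julia
struct NoEffect{T}
    value::T
end

struct Eff{Effectful, Fs}
    effectful::Effectful
    funcs::Fs
end
Eff(effectful) = Eff(effectful, ())

# No continuations left: just store the pure value.
function Eff(effectful::E, funcs::Tuple{}) where E<:NoEffect
    Eff{E, Tuple{}}(effectful, funcs)
end
# Pure value with pending continuations: run the first one right away.
function Eff(effectful::NoEffect, funcs::Fs) where Fs
    func = Base.first(funcs)
    rest = Base.tail(funcs)
    eff = func(effectful.value)
    Eff(eff.effectful, (eff.funcs..., rest...))
end

noeffect(value) = Eff(NoEffect(value))
noeffect(eff::Eff) = eff

# Hypothetical example continuations: add one, then double.
chained = Eff(NoEffect(1), (x -> noeffect(x + 1), x -> noeffect(2x)))

println(chained.effectful.value)  # 4, i.e. (1 + 1) * 2
println(chained.funcs)            # (), nothing is left pending
println(typeof(chained))          # Eff{NoEffect{Int64}, Tuple{}} on 64-bit
```

Both continuations are consumed inside the constructor, which is why the resulting type has Tuple{} as its second parameter.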

If you use the Eff constructor directly, everything works nicely:

using Test

e1 = noeffect(1)
e2 = noeffect(2)
curried_combine(v1) = v2 -> v1 + v2

function inference_works(e1::E1, e2::E2) where {E1, E2}
    e1_f = Eff(e1.effectful, (e1.funcs..., v1 -> noeffect(curried_combine(v1))))
    f_flatmap(f) = Eff(e2.effectful, (e2.funcs..., v2 -> noeffect(f(v2))))
    Eff(e1_f.effectful, (e1_f.funcs..., f_flatmap))
end

inference_works(e1, e2).effectful.value  # 3
@inferred inference_works(e1, e2)

But if you wrap the same Eff construction in its own separate function, inference fails:

function myeff_flatmap(f_flatmap::F, e1_f::E) where {F, E}
    Eff(e1_f.effectful, (e1_f.funcs..., f_flatmap))
end

function inference_fails(e1::E1, e2::E2) where {E1, E2}
    e1_f = myeff_flatmap(v1 -> noeffect(curried_combine(v1)), e1)
    f_flatmap(f) = myeff_flatmap(v2 -> noeffect(f(v2)), e2)
    myeff_flatmap(f_flatmap, e1_f)
end

inference_fails(e1, e2).effectful.value  # 3
@inferred inference_fails(e1, e2) 
# ERROR: return type Eff{NoEffect{Int64}, Tuple{}} does not match inferred return type Eff
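To compare the two versions side by side without @inferred throwing, one can ask the compiler for the inferred return type directly. The sketch below uses Base.return_types, a reflection function whose results can differ between Julia versions. Going by the @inferred results above, on Julia 1.8.1 one would expect a concrete type for inference_works and the abstract Eff for inference_fails, but the exact output on other versions is not verified here. The loop and print format are only for illustration.

```julia
struct NoEffect{T}
    value::T
end

struct Eff{Effectful, Fs}
    effectful::Effectful
    funcs::Fs
end
Eff(effectful) = Eff(effectful, ())

function Eff(effectful::E, funcs::Tuple{}) where E<:NoEffect
    Eff{E, Tuple{}}(effectful, funcs)
end
function Eff(effectful::NoEffect, funcs::Fs) where Fs
    func = Base.first(funcs)
    rest = Base.tail(funcs)
    eff = func(effectful.value)
    Eff(eff.effectful, (eff.funcs..., rest...))
end

noeffect(value) = Eff(NoEffect(value))
noeffect(eff::Eff) = eff
curried_combine(v1) = v2 -> v1 + v2

# Eff construction written inline.
function inference_works(e1::E1, e2::E2) where {E1, E2}
    e1_f = Eff(e1.effectful, (e1.funcs..., v1 -> noeffect(curried_combine(v1))))
    f_flatmap(f) = Eff(e2.effectful, (e2.funcs..., v2 -> noeffect(f(v2))))
    Eff(e1_f.effectful, (e1_f.funcs..., f_flatmap))
end

# Same construction, moved into a helper function.
function myeff_flatmap(f_flatmap::F, e1_f::E) where {F, E}
    Eff(e1_f.effectful, (e1_f.funcs..., f_flatmap))
end

function inference_fails(e1::E1, e2::E2) where {E1, E2}
    e1_f = myeff_flatmap(v1 -> noeffect(curried_combine(v1)), e1)
    f_flatmap(f) = myeff_flatmap(v2 -> noeffect(f(v2)), e2)
    myeff_flatmap(f_flatmap, e1_f)
end

e1 = noeffect(1)
e2 = noeffect(2)
argtypes = (typeof(e1), typeof(e2))

for f in (inference_works, inference_fails)
    # Exactly one method matches, so take the single inferred type.
    inferred = only(Base.return_types(f, argtypes))
    actual = typeof(f(e1, e2))
    println(f, ": inferred ", inferred, ", actual ", actual,
            ", concrete: ", isconcretetype(inferred))
end
```

Since inference is sound, the actual type is always a subtype of the inferred one; the question in this issue is only whether the inferred type is precise.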

I was not able to find a fix. I heard about the trick

which(myeff_flatmap, Tuple{Any, Any}).recursion_relation = (@nospecialize(_...)) -> true

but I could not make it work here, if it can work at all.

My best guess is that this is related to myeff_flatmap being called recursively: Eff runs f_flatmap as a continuation, and f_flatmap itself calls myeff_flatmap again.

Workaround

My current workaround is to write the wrapping function as a generated function instead, so that its code is inlined. This works as expected and does not require much boilerplate.
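The generated version itself is not shown in the report. A minimal sketch of what such a workaround could look like, assuming the generator simply returns the same Eff expression as myeff_flatmap; the names myeff_flatmap_gen and inference_generated are hypothetical. The sketch only checks the computed value; whether this exact version infers a concrete type was not verified.

```julia
struct NoEffect{T}
    value::T
end

struct Eff{Effectful, Fs}
    effectful::Effectful
    funcs::Fs
end
Eff(effectful) = Eff(effectful, ())

function Eff(effectful::E, funcs::Tuple{}) where E<:NoEffect
    Eff{E, Tuple{}}(effectful, funcs)
end
function Eff(effectful::NoEffect, funcs::Fs) where Fs
    func = Base.first(funcs)
    rest = Base.tail(funcs)
    eff = func(effectful.value)
    Eff(eff.effectful, (eff.funcs..., rest...))
end

noeffect(value) = Eff(NoEffect(value))
noeffect(eff::Eff) = eff
curried_combine(v1) = v2 -> v1 + v2

# Hypothetical generated counterpart of myeff_flatmap. The generator body
# only sees the argument types F and E; the quoted expression it returns
# becomes the method body, where the argument names refer to the values.
@generated function myeff_flatmap_gen(f_flatmap::F, e1_f::E) where {F, E}
    return :(Eff(e1_f.effectful, (e1_f.funcs..., f_flatmap)))
end

# Same as inference_fails, but calling the generated helper.
function inference_generated(e1::E1, e2::E2) where {E1, E2}
    e1_f = myeff_flatmap_gen(v1 -> noeffect(curried_combine(v1)), e1)
    f_flatmap(f) = myeff_flatmap_gen(v2 -> noeffect(f(v2)), e2)
    myeff_flatmap_gen(f_flatmap, e1_f)
end

e1 = noeffect(1)
e2 = noeffect(2)
println(inference_generated(e1, e2).effectful.value)  # 3
```

The returned expression contains no closures, which keeps it within the restrictions Julia places on generated function bodies.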

Julia version

julia> versioninfo()
Julia Version 1.8.1
Commit afb6c60d69a (2022-09-06 15:09 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 8 × Intel(R) Core(TM) i7-1065G7 CPU @ 1.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, icelake-client)
  Threads: 1 on 8 virtual cores
schlichtanders commented 1 year ago

@N5N3 do you know why

which(myeff_flatmap, Tuple{Any, Any}).recursion_relation = (@nospecialize(_...)) -> true

does not work here?

N5N3 commented 1 year ago

I didn't test it locally. It looks like myeff_flatmap would cause a call chain like Eff -> myeff_flatmap -> Eff -> myeff_flatmap -> ... In this case, you have to loosen the recursion check of Eff and myeff_flatmap at the same time to make sure inference succeeds.
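A sketch of this suggestion, relaxing the recursion check on both links of the chain rather than on myeff_flatmap alone. Assumptions: the relevant Eff method is the one taking a NoEffect plus a non-empty continuation tuple; Method's recursion_relation field is an internal compiler hook that may change between Julia versions; and it has to be set before the functions are first inferred. The helper name allow_any_recursion is hypothetical, and whether this makes @inferred pass was not tested.

```julia
struct NoEffect{T}
    value::T
end

struct Eff{Effectful, Fs}
    effectful::Effectful
    funcs::Fs
end
Eff(effectful) = Eff(effectful, ())

function Eff(effectful::E, funcs::Tuple{}) where E<:NoEffect
    Eff{E, Tuple{}}(effectful, funcs)
end
function Eff(effectful::NoEffect, funcs::Fs) where Fs
    func = Base.first(funcs)
    rest = Base.tail(funcs)
    eff = func(effectful.value)
    Eff(eff.effectful, (eff.funcs..., rest...))
end

noeffect(value) = Eff(NoEffect(value))
noeffect(eff::Eff) = eff
curried_combine(v1) = v2 -> v1 + v2

function myeff_flatmap(f_flatmap::F, e1_f::E) where {F, E}
    Eff(e1_f.effectful, (e1_f.funcs..., f_flatmap))
end

function inference_fails(e1::E1, e2::E2) where {E1, E2}
    e1_f = myeff_flatmap(v1 -> noeffect(curried_combine(v1)), e1)
    f_flatmap(f) = myeff_flatmap(v2 -> noeffect(f(v2)), e2)
    myeff_flatmap(f_flatmap, e1_f)
end

# Hypothetical helper: a recursion relation that always allows recursion.
# The compiler passes internal arguments, so accept anything.
allow_any_recursion(@nospecialize(args...)) = true

# Relax both links of the Eff -> myeff_flatmap -> Eff chain. This must run
# before the first call, i.e. before the functions are first inferred.
which(Eff, Tuple{NoEffect, Any}).recursion_relation = allow_any_recursion
which(myeff_flatmap, Tuple{Any, Any}).recursion_relation = allow_any_recursion

e1 = noeffect(1)
e2 = noeffect(2)
println(inference_fails(e1, e2).effectful.value)  # 3
println(Base.return_types(inference_fails, (typeof(e1), typeof(e2))))
```

The which calls pick the same methods that dispatch would choose for those argument types, so only the two methods in the recursive chain are affected.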

You can use Cthulhu.jl to check whether inference is blocked in Eff. If so, I think this is a duplicate of #45759. (inference_works is a good example showing that the compiler eagerly infers direct recursion: Eff -> Eff -> Eff -> ...)