JuliaGaussianProcesses / Stheno.jl

Probabilistic Programming with Gaussian processes in Julia

Tests currently broken #242

Closed martincornejo closed 1 year ago

martincornejo commented 1 year ago

Testing the current master branch, with Julia 1.9.0, results in the following:

Pass   Error   Total   Time
1067   10      1077    5m48.0s

It seems some of the broken tests are related to https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/issues/356, while others are caused by Zygote.

versioninfo()
Julia Version 1.9.0
Commit 8e63055292 (2023-05-07 11:25 UTC)
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 8 × Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
  Threads: 4 on 8 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 4

martincornejo commented 1 year ago

This is the error message from some of the broken tests:

GP: Error During Test at C:\Users\Cornejo\Documents\GitHub\Stheno.jl\test\gp\atomic_gp.jl:14
  Test threw exception
  Expression: mean(f, x) == AbstractGPs._map_meanfunction(m, x)
  UndefVarError: `_map_meanfunction` not defined
martincornejo commented 1 year ago

Minimal example to reproduce the Zygote errors:

using Stheno
import Zygote

f = @gppp let
    f1 = GP(SEKernel())
    f2 = GP(Matern52Kernel())
    f3 = f1 + f2
end

x = GPPPInput(:f3, randn(5))
y = randn(5)
Zygote.pullback(mean, f, x)

Edit: simplified the example

willtebbutt commented 1 year ago

Ugh -- this second case looks like an example of someone having written an rrule which makes assumptions that are too strong. Debugging now.

martincornejo commented 1 year ago

The stacktrace leads here:

https://github.com/JuliaGaussianProcesses/Stheno.jl/blob/bd156542e88b22887c595c58d53c06cf03e8622a/src/gp/derived_gp.jl#L19-L19

Zygote somehow overlooks the following method definition for mean: https://github.com/JuliaGaussianProcesses/Stheno.jl/blob/bd156542e88b22887c595c58d53c06cf03e8622a/src/affine_transformations/addition.jl#L20-L22

martincornejo commented 1 year ago

This is potentially the commit that introduced this behavior in ChainRules: https://github.com/JuliaDiff/ChainRules.jl/commit/8424476fda14585decc476c163afea8e6666b8a7

It was committed in December 2022 (the last commit in Stheno is from June 2022).

The rrule for mean is currently defined as follows:

function rrule(
    config::RuleConfig{>:HasReverseMode},
    ::typeof(mean),
    f::F,
    x::AbstractArray{T};
    dims=:,
) where {F, T<:Union{Real,Complex,AbstractArray}}
    y_sum, sum_pullback = rrule(config, sum, f, x; dims)
    n = _denom(x, dims)
    function mean_pullback_f(ȳ)
        return sum_pullback(unthunk(ȳ) / n)
    end
    return y_sum / n, mean_pullback_f
end

Specifying where {F<:Function, ...} would probably fix it. I will try it out and open a PR if that works.

Edit: Adding that type constraint is probably not a solution, see https://github.com/JuliaDiff/ChainRules.jl/issues/522. The ChainRules tests also check that non-function callables work: https://github.com/JuliaDiff/ChainRules.jl/blob/11c230cdf0f37a4f42de909d6c1f8500d1a80d69/test/rulesets/Statistics/statistics.jl#L18
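
To illustrate the trade-off, here is a minimal sketch that uses only the Statistics standard library. It is not the ChainRules or Stheno code. ToyGP, Scale, toy_rule and strict_rule are hypothetical names invented for this example: toy_rule mimics the unconstrained signature of the rrule above, and strict_rule mimics the F<:Function variant.

```julia
using Statistics

# Hypothetical stand-in for a GP type: not callable, but with its own
# two-argument `mean` method, similar in spirit to mean(f, x) for a GP.
struct ToyGP
    offset::Float64
end
Statistics.mean(f::ToyGP, x::AbstractVector) = fill(f.offset, length(x))

# Hypothetical callable struct that is not a subtype of `Function`.
struct Scale
    a::Float64
end
(s::Scale)(v) = s.a * v

# Mimics the unconstrained signature: matches any first argument.
toy_rule(::typeof(mean), f::F, x::AbstractArray) where {F} = :generic_rule

# Mimics the constrained signature, with a fallback meaning "no rule applies".
strict_rule(::typeof(mean), f::F, x::AbstractArray) where {F<:Function} = :generic_rule
strict_rule(::typeof(mean), f, x::AbstractArray) = :no_rule

x = [1.0, 2.0, 3.0]

# The unconstrained signature also matches the GP-like type ...
toy_rule(mean, ToyGP(0.0), x)      # :generic_rule
# ... although its primal `mean` method does something entirely different.
mean(ToyGP(0.5), x)                # [0.5, 0.5, 0.5]

# Constraining to F<:Function skips the GP-like type ...
strict_rule(mean, ToyGP(0.0), x)   # :no_rule
strict_rule(mean, abs2, x)         # :generic_rule
# ... but also skips callable structs, for which mean(f, x) is valid.
strict_rule(mean, Scale(2.0), x)   # :no_rule
mean(Scale(2.0), x)                # 4.0
```

Under these assumptions, the unconstrained signature captures any type that defines its own mean(f, x) method, while the F<:Function constraint would drop callable structs such as Scale, which is what the linked ChainRules test exercises.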

willtebbutt commented 1 year ago

Resolved by #244