JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

Known AD Failures #116

Open · theogf opened 4 years ago

theogf commented 4 years ago

Here is a list of the failures I observed in the tests from #114 with the different AD backends (ForwardDiff.jl, Zygote.jl, and ReverseDiff.jl):

This is a good starting point for trying to find solutions.

devmotion commented 4 years ago

The Zygote problems with MaternKernel are caused by the fact that the partial derivative of besselk with respect to the first argument is defined as NaN in https://github.com/JuliaDiff/ChainRules.jl/blob/98c54587257b86cce6eb45f7870a75f897058d21/src/rulesets/packages/SpecialFunctions.jl#L46-L47 (and I assume the same problem exists for the other AD backends, since I get NaN for all of them when I try to run the commented-out AD tests). I guess one would have to implement https://dlmf.nist.gov/10.38 to fix it.
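A minimal reproduction (a sketch; the NaN comes from the derivative rule for the order, not from besselk itself):

julia> using SpecialFunctions, Zygote

julia> Zygote.gradient(nu -> besselk(nu, 1.0), 1.5)  # partial w.r.t. the order nu
(NaN,)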

theogf commented 4 years ago

Haha, writing these derivatives sounds like one should write a whole package about Bessel functions.

yebai commented 4 years ago

@sharanry Can you prioritise these AD issues? It would be great if they could be addressed during the summer.

devmotion commented 4 years ago

BTW I found a publication from 2016 with closed-form expressions for the derivatives of the Bessel functions with respect to the order. I opened an issue at https://github.com/JuliaDiff/ChainRules.jl/issues/208 to discuss how to deal with the additional dependencies needed for their implementation (they contain hypergeometric functions).

devmotion commented 4 years ago

We might want to refactor KernelSum and KernelProd (making them concretely typed and allowing both tuples and vectors of kernels similar to TensorProduct, and probably removing the weights in KernelSum) before fixing any AD issues there.
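As a rough sketch of the refactoring (the name SumKernel and all details are illustrative, not a settled design):

using KernelFunctions: Kernel

# Storing the component kernels in a tuple makes the struct concretely
# typed, mirroring the TensorProduct design; no weights field.
struct SumKernel{K<:Tuple} <: Kernel
    kernels::K
end

# Evaluate as the plain sum of the component kernels.
(k::SumKernel)(x, y) = sum(kernel(x, y) for kernel in k.kernels)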

theogf commented 4 years ago

Agreed! There is also a general AD issue when using Transform: the pullback on ColVecs and RowVecs returns a vector of vectors. Fixing this would tick off a good portion of the issues.

sharanry commented 4 years ago

> @sharanry Can you prioritise these AD issues? It would be great if they could be addressed during the summer.

Sorry for the late reply. I somehow didn't get a notification for this comment and only found it randomly while browsing the issues. I am looking into it.

sharanry commented 4 years ago

The probable reason Zygote fails for FunctionTransform is the usage of Base.mapslices in https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/0df9e8352d7d9f034f75de4d3639c0bb3b96c714/src/transform/functiontransform.jl#L19-L20.

Base.mapslices is mutating the array, though I'm not sure why; https://github.com/JuliaLang/julia/pull/17266 should have fixed this.

julia> Zygote._pullback(x-> mapslices(x->sin.(x), x, dims=1), rand(3,3))[2](ones(3,3))
ERROR: Mutating arrays is not supported

devmotion commented 4 years ago

mapslices still mutates a temporary array. The linked PR just ensures that

> mapslices never modifies the input array. It allocates temporary storage and copies each slice into it before calling the user-function.

sharanry commented 4 years ago

> mapslices still mutates a temporary array. The linked PR just ensures that

Oh, that makes sense. Do you see any other efficient way to apply a function transform to a matrix/ColVecs/RowVecs?

devmotion commented 4 years ago

I see the following possibilities here:

sharanry commented 4 years ago

I just ran a quick benchmark for one other possibility, which would require us to define an adjoint for a generator. The methods you mentioned are probably better.

julia> @btime hcat(map(x->sin.(x), (eachslice(rand(1000,1000); dims=1)))...)
  16.586 ms (2015 allocations: 23.09 MiB)
julia> @btime mapslices(x->sin.(x), rand(1000,1000); dims=1)
  12.189 ms (7505 allocations: 23.18 MiB)

devmotion commented 4 years ago

A bit off topic, but splatting probably impacts performance quite a bit, so it would probably be better to use mapreduce(x -> sin.(x), hcat, ...). For benchmarks you also want to use $(rand(1000, 1000)) (that way the timings are unaffected by the calls to rand).

sharanry commented 4 years ago

> A bit off topic, but splatting probably impacts performance quite a bit, so it would probably be better to use mapreduce(x -> sin.(x), hcat, ...). For benchmarks you also want to use $(rand(1000, 1000)) (that way the timings are unaffected by the calls to rand).

Thanks! I wasn't aware of this. It did, however, give unexpected results for mapreduce.

julia> @btime mapslices(x->sin.(x), $(rand(1000,1000)); dims=1);
  10.581 ms (7503 allocations: 15.55 MiB)

julia> @btime mapreduce(x->sin.(x), hcat, eachslice($(rand(1000,1000)); dims=1));
  914.970 ms (5002 allocations: 3.74 GiB)

devmotion commented 4 years ago

Shouldn't you use eachslice(...; dims=2) or eachcol?

sharanry commented 4 years ago

I don't think it makes much of a difference performance-wise, at least.

julia> @btime mapreduce(x->sin.(x), hcat, eachslice($(rand(1000,1000)); dims=2));
  952.564 ms (5002 allocations: 3.74 GiB)

devmotion commented 4 years ago

Can you check if the function is type-stable? I suspect it might not be, which would explain the number of allocations. The problem might be that it returns a different type if eachslice(...) is empty. Specifying an init keyword argument might help.
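For illustration, the check and the init workaround could look like this (toy function, hypothetical values):

julia> f(X) = mapreduce(x -> sin.(x), hcat, eachcol(X));

julia> @code_warntype f(rand(4, 4))  # Any/Union entries in the output indicate instability

julia> mapreduce(x -> sin.(x), hcat, eachcol(rand(4, 4)); init=zeros(4, 0));  # result type known even for empty input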

willtebbutt commented 4 years ago

Just for the record -- we should be using ChainRulesCore to define pullbacks for Zygote, and ChainRulesTestUtils to test those implementations -- see e.g. here for example usage.

Plans are in the works to transfer both Tracker and ReverseDiff over to use ChainRules at some point (we know how we're going to do it, just waiting for code to get written), so this will future-proof AD in the package.
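As a minimal sketch of that pattern (mysquare is a toy function; the tangent types follow the current ChainRulesCore API):

using ChainRulesCore

mysquare(x::Real) = x^2

# Define the pullback once, AD-agnostically, via an rrule.
function ChainRulesCore.rrule(::typeof(mysquare), x::Real)
    mysquare_pullback(Δ) = (NoTangent(), 2x * Δ)
    return mysquare(x), mysquare_pullback
end

# And check it against finite differences in the test suite:
using ChainRulesTestUtils
test_rrule(mysquare, 2.5)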

yiyuezhuo commented 4 years ago

I can't figure out how to define a mapslices adjoint using only ChainRulesCore. If f is an anonymous function, how can we get its pullback (rrule) from within ChainRulesCore, when in Zygote this can be done using gradient?

devmotion commented 4 years ago

KernelFunctions doesn't use mapslices anymore, so custom adjoints for mapslices are no longer required for this project. Nevertheless, implementing an adjoint for mapslices in ChainRules requires a solution to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/68 AFAICT.
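For reference, the kind of solution discussed in that issue is what ChainRulesCore now provides with config-based rules and rrule_via_ad, which let a rule call back into the surrounding AD for a user-supplied closure. A hedged sketch of the shape, written for map over a vector rather than mapslices, and dropping the closure's own tangent for simplicity:

using ChainRulesCore

function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(map), f, x::AbstractVector
)
    # Differentiate the user-supplied f via the calling AD (e.g. Zygote).
    ys_and_pullbacks = map(xi -> rrule_via_ad(config, f, xi), x)
    y = map(first, ys_and_pullbacks)
    function map_pullback(ȳ)
        # Each pullback returns (f̄, x̄ᵢ); we keep only x̄ᵢ here.
        x̄ = map((yp, ȳi) -> last(yp)(ȳi)[2], ys_and_pullbacks, ȳ)
        return NoTangent(), NoTangent(), x̄
    end
    return y, map_pullback
end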

yiyuezhuo commented 4 years ago

I see. I checked the source of the last release in order to backport TransformedKernel to Stheno (which doesn't support KernelFunctions) and found that mapslices code. But after thinking about how to implement it in ChainRules or Zygote, I just disabled a check in Stheno to re-enable the gradient of f.(ColVecs(X)), since I feel ChainRules and Zygote won't make much of a difference here.

devmotion commented 4 years ago

The latest releases don't use mapslices; it was replaced in https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/pull/152. I guess you can use something similar instead of mapslices in Stheno as well.
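Something like the following column-wise pattern avoids mapslices (a sketch, not necessarily the exact code from #152):

# Apply f to each column and re-assemble; reduce(hcat, ...) has an
# efficient specialization for vectors of vectors and is Zygote-friendly.
apply_colwise(f, X::AbstractMatrix) = reduce(hcat, map(f, eachcol(X)))

apply_colwise(x -> sin.(x), rand(3, 3))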

sharanry commented 4 years ago

Regarding FBMKernel not working with ForwardDiff: it seems to be producing NaN values incorrectly. According to the ForwardDiff documentation, the fix for this is to "enable ForwardDiff's NaN-safe mode by setting the NANSAFE_MODE_ENABLED constant to true in ForwardDiff's source". Users are currently not allowed to enable it dynamically [Issue].

devmotion commented 4 years ago

You can use https://github.com/JuliaDiff/ForwardDiff.jl/pull/451 if you do not want to edit the source code.
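With that PR, enabling NaN-safe mode via Preferences would look something like this (the preference key is taken from the PR and should be treated as an assumption; it only takes effect once ForwardDiff is re(pre)compiled):

using Preferences, ForwardDiff

# Persists to LocalPreferences.toml in the active project.
set_preferences!(ForwardDiff, "nansafe_mode" => true)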

cgeoga commented 2 years ago

Hi--it just occurs to me to share this here, but I recently finished a project for computing derivatives of besselk with respect to the order parameter precisely for the purpose of fitting Matern covariances (example Matern kernel implementation here). The strategy that worked ended up being a re-implementation of besselk in Julia that admits fast and accurate AD derivatives with ForwardDiff.jl. The re-implemented besselk itself is not quite as accurate as the AMOS one linked in SpecialFunctions, but the derivatives are pretty accurate. Not quite to machine double precision, but reasonably close. And very fast.
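Roughly, the usage looks like this (a sketch: the function name adbesselk follows the package, and the Matern form below is the standard one with unit variance and range):

using ForwardDiff, SpecialFunctions
using BesselK  # assumed to export adbesselk

# Matern covariance as a function of the smoothness nu at distance t > 0.
function matern_cov(nu, t)
    arg = sqrt(2 * nu) * t
    return 2^(1 - nu) / gamma(nu) * arg^nu * adbesselk(nu, arg)
end

# Derivative with respect to the smoothness parameter:
ForwardDiff.derivative(nu -> matern_cov(nu, 0.5), 1.3)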

I'm not sure how helpful this is because the derivatives are at present pretty ForwardDiff-specific. I would guess that it would be possible to reach compatibility with other AD tools, perhaps at a slight cost of performance by eliminating some special branches in the current implementation, but I honestly don't understand how Zygote works at all so I can't promise it.

Anyways, just writing here in case it is helpful.

devmotion commented 2 years ago

I came across https://www.tandfonline.com/doi/pdf/10.1080/10652469.2016.1164156 a while ago; it contains closed-form expressions of the derivatives using e.g. hypergeometric functions. In principle these could be used with other AD backends as well, but I don't know if there are any numerical problems, how slow/fast the evaluation with HypergeometricFunctions would be, or whether SpecialFunctions would take a dependency on HypergeometricFunctions (I assume not, since it would introduce a circular dependency).

cgeoga commented 2 years ago

I also saw that paper and was interested in just using it before undertaking a more from-scratch approach. But there are a few challenges with using the representations in Santander. For one, as you point out, evaluating generalized hypergeometric functions like ${}_3F_4$ and ${}_2F_3$ is a task of comparable difficulty. I love HypergeometricFunctions.jl, but that's a lot of pressure to put on that package, which, at the very least in my experience, is very slow when the besselk argument is small (which is unfortunately where accurate derivatives matter the most). More importantly, though, the representation in Santander is hard to use in a bunch of edge cases. When $\nu$ is an integer or near-integer, there are several problems, both with cancellations and with trig functions blowing up. If nu = 1 + 1e-8 or something, that ostensibly exact equation might give literally zero digits of accuracy. The exact derivatives when $\nu + 1/2$ is a whole integer are particularly gnarly, and I've never seen them for any case besides $\nu = 1/2$.
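For reference, the near-integer difficulty is already visible in the classical connection formula

$$K_\nu(z) = \frac{\pi}{2} \, \frac{I_{-\nu}(z) - I_\nu(z)}{\sin(\nu \pi)}.$$

Differentiating with respect to $\nu$ gives

$$\frac{\partial K_\nu(z)}{\partial \nu} = \frac{\pi}{2 \sin(\nu \pi)} \, \frac{\partial}{\partial \nu}\bigl(I_{-\nu}(z) - I_\nu(z)\bigr) - \pi \cot(\nu \pi) \, K_\nu(z),$$

and as $\nu$ approaches an integer both $1/\sin(\nu \pi)$ and $\cot(\nu \pi)$ blow up while the two terms cancel almost exactly, which is numerically catastrophic for $\nu$ near (but not at) an integer.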

Our project was enough of a hassle that we ended up writing a paper about it, and almost all the work was in handling the problems of $\nu$ being exactly or nearly an integer or half-integer. I don't think there's any way around a gnarly branching function to handle the derivatives in those cases. And if you look at our timings (Table 1 of the paper), it will probably be hard to come anywhere near those speeds at even comparable accuracy around those edge cases.

I've actually thought about asking the SpecialFunctions folks whether they'd be interested in adding some of our code to that package, but considering that we are a bit cavalier about giving up the last couple of digits of accuracy, I'm a bit concerned that it's not a great fit.

In any case, just posting here for your consideration. If somebody manages to implement the exact expressions in a way that is tolerably fast and handles those edge cases, I'll be the first person to celebrate. In the meantime, though, I wouldn't be shocked if Zygote compatibility were possible. I just really don't know enough to conjecture about how much of a project it would be.