theogf opened this issue 4 years ago
The Zygote problems with MaternKernel are caused by the fact that the partial derivative of besselk with respect to the first argument is defined as NaN in https://github.com/JuliaDiff/ChainRules.jl/blob/98c54587257b86cce6eb45f7870a75f897058d21/src/rulesets/packages/SpecialFunctions.jl#L46-L47 (and I assume the same problem exists for the other AD backends, since I get NaN for all of them when I run the commented-out AD tests). I guess one would have to implement https://dlmf.nist.gov/10.38 to fix it.
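Until something like the DLMF 10.38 series is implemented analytically, one crude stopgap is to approximate the order derivative numerically. A minimal sketch (the helper name dorder and the finite-difference approach are my own illustration, not anything in ChainRules; pass f = SpecialFunctions.besselk to approximate the derivative of K_ν(x) with respect to ν):

```julia
# Hypothetical helper (not part of ChainRules): approximate the partial
# derivative of f(ν, x) with respect to the order ν by central differences.
# Accuracy is O(h^2) away from the points where f itself is ill-conditioned.
function dorder(f, ν, x; h=1e-6)
    return (f(ν + h, x) - f(ν - h, x)) / (2h)
end

# Sanity check with a smooth stand-in whose order derivative is known:
g(ν, x) = exp(-ν^2) * x            # ∂g/∂ν = -2ν * exp(-ν^2) * x
dorder(g, 1.0, 2.0)                # ≈ -4/e ≈ -1.4715
```

This is of course slower and less accurate than a closed-form rule, but it would at least avoid the hard NaN.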
Haha, writing these derivatives sounds like one should write a whole package for Bessel functions.
@sharanry Can you prioritise these AD issues? It would be great if these issues can be addressed during the summer.
BTW, I found a publication from 2016 with closed-form expressions for the derivatives of the Bessel functions with respect to the order. I opened an issue at https://github.com/JuliaDiff/ChainRules.jl/issues/208 to discuss how one would deal with the additional dependencies needed for their implementation (they contain hypergeometric functions).
We might want to refactor KernelSum and KernelProd (making them concretely typed and allowing both tuples and vectors of kernels similar to TensorProduct, and probably removing the weights in KernelSum) before fixing any AD issues there.
Agreed!
There is also a general AD issue when using Transform, where the pullbacks on ColVecs and RowVecs return a vector of vectors; fixing this would tick off a good portion of the issues.
> @sharanry Can you prioritise these AD issues? It would be great if these issues can be addressed during the summer.
Sorry for the late reply. I somehow didn't get a notification for this comment and only randomly found it while browsing the issues. I am looking into it.
The probable reason Zygote fails for FunctionTransform is the usage of Base.mapslices in https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/blob/0df9e8352d7d9f034f75de4d3639c0bb3b96c714/src/transform/functiontransform.jl#L19-L20. Base.mapslices is mutating the array, and I'm not sure why; https://github.com/JuliaLang/julia/pull/17266 should have fixed this.
julia> Zygote._pullback(x-> mapslices(x->sin.(x), x, dims=1), rand(3,3))[2](ones(3,3))
ERROR: Mutating arrays is not supported
mapslices still mutates a temporary array. The linked PR just ensures that mapslices never modifies the input array: it allocates temporary storage and copies each slice into it before calling the user function.
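A quick way to convince oneself of this (an illustrative snippet, not code from the package): the input array survives untouched, so the mutation Zygote complains about must be on an internal buffer.

```julia
# `mapslices` leaves the input untouched; the mutation Zygote rejects
# happens on an internal temporary buffer, not on X itself.
X = rand(3, 3)
Xcopy = copy(X)
Y = mapslices(x -> sin.(x), X; dims=1)
@assert X == Xcopy       # input unchanged
@assert Y == sin.(X)     # columnwise result matches the plain broadcast
```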
> mapslices still mutates a temporary array. The linked PR just ensures that mapslices never modifies the input array.
Oh makes sense. Do you see any other efficient way to apply a function transform for a matrix/ColVecs/RowVecs?
I see the following possibilities here:
- Implement _map(::FunctionTransform, ...) directly, e.g.

  function _map(t::FunctionTransform, x::ColVecs)
      vals = map(axes(x.X, 2)) do i
          t.f(view(x.X, :, i))
      end
      return ColVecs(reduce(hcat, vals))
  end

  (Zygote should support this automatically.)
- Define an adjoint for mapslices in ChainRules (see https://github.com/FluxML/Zygote.jl/issues/92).
- Use SliceMap (see https://github.com/FluxML/Zygote.jl/issues/92).
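For intuition, the view-based column map produces the same result as mapslices for a vector-to-vector function. A generic sketch with a plain matrix (not the actual KernelFunctions types):

```julia
# Map a function over the columns of X via views, then rebuild a matrix.
# This is the pattern that avoids `mapslices` and its internal mutation.
f = x -> sin.(x)
X = rand(3, 4)
cols = map(i -> f(view(X, :, i)), axes(X, 2))
Y = reduce(hcat, cols)
@assert Y == mapslices(f, X; dims=1)
```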
I just ran a quick benchmark for one other possibility, which would require us to define an adjoint for a generator. The methods you mentioned are probably better.
julia> @btime hcat(map(x->sin.(x), (eachslice(rand(1000,1000); dims=1)))...)
16.586 ms (2015 allocations: 23.09 MiB)
julia> @btime mapslices(x->sin.(x), rand(1000,1000); dims=1)
12.189 ms (7505 allocations: 23.18 MiB)
A bit off topic, but splatting probably impacts performance quite a bit, so it would probably be better to use mapreduce(x -> sin.(x), hcat, ...). For benchmarks you also want to use $(rand(1000, 1000)) (that way the timings are unaffected by the calls of rand).
> A bit off topic, but splatting probably impacts performance quite a bit, so it would probably be better to use mapreduce(x -> sin.(x), hcat, ...). For benchmarks you also want to use $(rand(1000, 1000)) (that way the timings are unaffected by the calls of rand).
Thanks! Wasn't aware of this. This however gave unexpected results for mapreduce:
julia> @btime mapslices(x->sin.(x), $(rand(1000,1000)); dims=1);
10.581 ms (7503 allocations: 15.55 MiB)
julia> @btime mapreduce(x->sin.(x), hcat, eachslice($(rand(1000,1000)); dims=1));
914.970 ms (5002 allocations: 3.74 GiB)
Shouldn't you use eachslice(...; dims=2) or eachcol?
I don't think it makes much difference performance-wise, at least.
julia> @btime mapreduce(x->sin.(x), hcat, eachslice($(rand(1000,1000)); dims=2));
952.564 ms (5002 allocations: 3.74 GiB)
Can you check if the function is type stable? I suspect it might not be, which would explain the number of allocations. The problem might be that it returns a different type if eachslice(...) is empty. Specifying an init kwarg might be helpful.
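Another likely culprit for the slow mapreduce benchmark (this is my reading, not something established in the thread): pairwise hcat copies the accumulated result at every reduction step, which is quadratic in memory traffic. Base specializes reduce(hcat, ...) for a vector of arrays (there is no such specialization for mapreduce), so materializing the mapped slices first should be far cheaper:

```julia
# Pairwise hcat reallocates the growing result at every step (O(n^2) memory
# traffic); the specialized reduce(hcat, ...) allocates the output once.
X = rand(100, 100)
cols = map(x -> sin.(x), eachcol(X))
Y = reduce(hcat, cols)
@assert Y == mapslices(x -> sin.(x), X; dims=1)
```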
Just for the record -- we should be using ChainRulesCore to define pullbacks for Zygote, and ChainRulesTestUtils to test those implementations -- see e.g. here for example usage. Plans are in the works to transfer both Tracker and ReverseDiff over to use ChainRules at some point (we know how we're going to do it, just waiting for code to get written), so this will future-proof AD in the package.
I can't figure out how to define a mapslices adjoint using only ChainRulesCore. If f is an anonymous function, how can we get its pullback (rrule) from ChainRulesCore, given that in Zygote it can be done using gradient?
KernelFunctions doesn't use mapslices anymore, so for this project custom adjoints for mapslices are not required anymore. Nevertheless, an implementation of an adjoint of mapslices in ChainRules requires a solution to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/68 AFAICT.
I see. I checked the source of the last release to backport TransformedKernel to Stheno, as Stheno doesn't support KernelFunctions, and found that mapslices code. But after thinking about how to implement it in ChainRules or Zygote, I just disabled a check in Stheno to re-enable gradients of f.(ColVecs(X)), since I feel ChainRules or Zygote would not make much of a difference.
The latest releases don't use mapslices; it was replaced in https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/pull/152. I guess you can use something similar instead of mapslices in Stheno as well.
Regarding FBMKernel not working with ForwardDiff: it seems to be producing NaN values incorrectly. According to the ForwardDiff documentation, the fix for this is to "enable ForwardDiff's NaN-safe mode by setting the NANSAFE_MODE_ENABLED constant to true in ForwardDiff's source". They currently do not allow users to enable it dynamically [Issue].
You can use https://github.com/JuliaDiff/ForwardDiff.jl/pull/451 if you do not want to edit the source code.
Hi -- it just occurred to me to share this here, but I recently finished a project for computing derivatives of besselk with respect to the order parameter, precisely for the purpose of fitting Matern covariances (example Matern kernel implementation here). The strategy that worked ended up being a re-implementation of besselk in Julia that admits fast and accurate AD derivatives with ForwardDiff.jl. The re-implemented besselk itself is not quite as accurate as the AMOS one linked in SpecialFunctions, but the derivatives are pretty accurate -- not quite to machine double precision, but reasonably close. And very fast.

I'm not sure how helpful this is because the derivatives are at present pretty ForwardDiff-specific. I would guess that it would be possible to reach compatibility with other AD tools, perhaps at a slight cost in performance by eliminating some special branches in the current implementation, but I honestly don't understand how Zygote works at all, so I can't promise it.

Anyways, just writing here in case it is helpful.
I came across https://www.tandfonline.com/doi/pdf/10.1080/10652469.2016.1164156 a while ago; it contains closed-form expressions for the derivatives using e.g. hypergeometric functions. In principle these could be used with other AD backends as well, but I don't know if there are any numerical problems, how slow or fast the evaluation with HypergeometricFunctions would be, and whether SpecialFunctions would take a dependency on HypergeometricFunctions (I assume not, since it would introduce a circular dependency).
I also saw that paper and was interested in just using it before undertaking a more from-scratch approach. But there are a few challenges with using the representations in Santander. For one, as you point out, evaluating the generalized hypergeometric functions like 3F4 and 2F3 is a task of comparable difficulty. I love HypergeometricFunctions.jl, but that's a lot of pressure to put on that package, which at the very least in my experience is very slow when the besselk argument is small (which is unfortunately where the accurate derivatives matter the most). More importantly, though, the representation in Santander is hard in a bunch of edge cases. When $\nu$ is an integer or near-integer, there are several problems, both with cancellations and with trig functions blowing up. If nu = 1 + 1e-8 or something, that ostensibly exact equation might give literally zero digits of accuracy. The exact derivatives when nu + 1/2 is a whole integer are particularly gnarly, and I've never seen them for any case besides nu = 1/2.
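The trig blowup near integer orders is easy to see numerically. Reflection-type identities such as K_ν(x) = (π/2)(I_{-ν}(x) − I_ν(x))/sin(πν) carry a 1/sin(πν) factor, and near an integer that factor massively amplifies rounding error in whatever it multiplies (illustrative snippet, not code from the paper):

```julia
# Near an integer order, the 1/sin(πν) factor in reflection-style formulas
# explodes, amplifying any rounding error in the terms it multiplies.
ν = 1 + 1e-8
amplification = abs(1 / sinpi(ν))   # ≈ 3.2e7
```

So even a mathematically exact closed form can lose all its digits there without dedicated near-integer branches.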
Our project was enough of a hassle that we ended up writing a paper about it, and almost all the work was in handling the problems of $\nu$ being exactly or nearly an integer or half-integer. I don't think there's any way around a gnarly branching function to handle the derivatives in those cases. And if you look at our timings (table one of the paper), it will probably be hard to come anywhere near those speeds at even comparable accuracy around those edge cases.
I've actually thought about asking the SpecialFunctions package folks if they'd be interested in some of our code being added to that package, but considering that we are a bit cavalier about giving up the last couple of digits of accuracy, I'm a bit concerned that it's not a great fit.
In any case, just posting here for your consideration. If somebody manages to implement the exact expressions in a way that is tolerably fast and handles those edge cases, I'll be the first person to celebrate. In the meantime, though, I wouldn't be shocked if Zygote compatibility was possible. I just really don't know enough to conjecture about how much of a project it would be.
Here is a list of the failures I observed in the tests from #114 with the different AD backends: ForwardDiff.jl, Zygote.jl, and ReverseDiff.jl. This is a good starting point for trying to find solutions.