JuliaGaussianProcesses / KernelFunctions.jl

Julia package for kernel functions for machine learning
https://juliagaussianprocesses.github.io/KernelFunctions.jl/stable/
MIT License

Issue with Distances.jl #380

Open · theogf opened this issue 3 years ago

theogf commented 3 years ago

We rely heavily on Distances.jl; however, as noted in #98, not all of our pairwise operations are proper metrics. PR #194 is stalling because redefining everything for DotProduct is quite messy. @devmotion mentioned in https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/pull/378#discussion_r725001625 that we mostly don't need Distances.jl.
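To make the metric problem concrete, here is a small sketch in plain Julia (no Distances.jl needed). A premetric in the Distances.jl sense must satisfy d(x, x) = 0 and d(x, y) ≥ 0. The dot product underlying a DotProduct-style pairwise operation violates both conditions, while the squared Euclidean distance satisfies them. The vectors are arbitrary illustrative values.

```julia
using LinearAlgebra: dot

x = [1.0, 2.0]
y = [-3.0, 1.0]

# Dot product used as a pairwise "distance": fails both premetric conditions.
dot(x, x)          # 5.0, but a premetric needs d(x, x) == 0
dot(x, y)          # -1.0, but a premetric needs d(x, y) >= 0

# Squared Euclidean distance: satisfies both conditions.
sum(abs2, x - x)   # 0.0
sum(abs2, x - y)   # 17.0
```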

I'm not really sure how this should go!

willtebbutt commented 3 years ago

One option would be to replace Distances entirely with Tullio.jl. I've found it very performant in practice, and it's quite a bit more flexible 🤷. It also has GPU support, which would be another win.

theogf commented 3 years ago

Uuuh! I like it! So we would create our own implementation of "pairwise" (like binary_op) that simply defers to Distances.pairwise for a Distances.PreMetric?

willtebbutt commented 3 years ago

I think so? I'm not completely sure what the resulting implementation would look like. It might be that we literally just have our own implementation of pairwise for various things, or that we do something cleverer in a number of cases, e.g. implement kernelmatrix etc. explicitly for SEKernel, with the distance calculations fused with the kappa function, or whatever. There's probably something really interesting to do here that I haven't thought of yet. It's so strange having a package like Tullio, because it means we can start moving away from opening every implementation discussion with "how do I implement this in terms of gemm?"
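As a rough illustration of the gemm-versus-fused trade-off, the sketch below computes a squared-exponential kernel matrix, k(x, y) = exp(-‖x - y‖² / 2), in two ways for inputs stored as the columns of matrices. The first route goes through gemm via the expansion ‖x - y‖² = ‖x‖² + ‖y‖² - 2xᵀy; the second fuses the distance and the kappa function into a single loop, the kind of loop one would express with Tullio's index notation. Both function names are hypothetical, neither is KernelFunctions.jl's actual kernelmatrix, and the code is plain Julia rather than Tullio.

```julia
# Inputs: columns of X (d×n) and Y (d×m). Hypothetical helper names.

# gemm route: ‖x‖² + ‖y‖² - 2xᵀy, then κ(d²) = exp(-d²/2).
function se_kernelmatrix_gemm(X::AbstractMatrix, Y::AbstractMatrix)
    D = X' * Y                                   # BLAS gemm: n×m dot products
    D .= sum(abs2, X; dims=1)' .+ sum(abs2, Y; dims=1) .- 2 .* D
    return exp.(-0.5 .* max.(D, 0))              # clamp round-off below zero
end

# Fused route: distance and κ computed together, one output allocation.
function se_kernelmatrix_fused(X::AbstractMatrix, Y::AbstractMatrix)
    size(X, 1) == size(Y, 1) || throw(DimensionMismatch("inputs must have the same dimension"))
    T = float(promote_type(eltype(X), eltype(Y)))
    K = similar(X, T, size(X, 2), size(Y, 2))
    for j in axes(Y, 2), i in axes(X, 2)         # i innermost: column-major writes
        d2 = zero(T)
        for k in axes(X, 1)
            d2 += (X[k, i] - Y[k, j])^2
        end
        K[i, j] = exp(-d2 / 2)
    end
    return K
end
```

In this sketch the gemm route allocates the n×m product plus a second n×m matrix for the exponentials and hands the O(dnm) work to BLAS, while the fused route allocates only the output. The max clamp guards against small negative values from floating-point cancellation in the expansion.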

devmotion commented 3 years ago

But don't we already have this in place (implementing our own pairwise), just not with Tullio and unnecessarily restricted to ::PreMetric? E.g., for Delta we already use broadcast(::Delta, x, y) if x and y are general vectors in our implementation of KernelFunctions.pairwise. I think one of the problems with #194 is that it tries to copy the internal functions and dispatches of Distances for the AbstractBinaryOp (or whatever more general type we want to define).
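For reference, the broadcast-based fallback described here can be sketched in a few lines of plain Julia. The three operation types below are hypothetical stand-ins, not KernelFunctions.jl's Delta or DotProduct nor Distances.jl's SqEuclidean. The point is that broadcasting a callable over x and permutedims(y) yields the full n×m matrix for arbitrary vectors of inputs, with no metric assumptions and no Distances.jl machinery.

```julia
# Hypothetical binary operations; none of these is a KernelFunctions.jl type.
struct SqEuclideanOp end
struct DotProductOp end
struct DeltaOp end

(::SqEuclideanOp)(x, y) = sum(abs2, x .- y)
(::DotProductOp)(x, y) = sum(x .* y)
(::DeltaOp)(x, y) = x == y

# Generic fallback: op(x[i], y[j]) for every pair. permutedims turns y into a
# 1×m row without recursing into its elements, so broadcasting yields n×m.
generic_pairwise(op, x::AbstractVector, y::AbstractVector) = broadcast(op, x, permutedims(y))

x = [[0.0, 0.0], [1.0, 0.0]]
y = [[0.0, 1.0], [2.0, 0.0], [1.0, 0.0]]

generic_pairwise(SqEuclideanOp(), x, y)  # [1.0 4.0 1.0; 2.0 1.0 0.0]
generic_pairwise(DotProductOp(), x, y)   # [0.0 0.0 0.0; 0.0 2.0 1.0]
generic_pairwise(DeltaOp(), x, y)        # only entry (2, 3) is true
```

A package-level version would presumably keep this as the default method and add specialized methods (metric operations on matrix inputs, GPU arrays) on top via dispatch.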

theogf commented 3 years ago

Yeah, in #194 I definitely went in the wrong direction by trying to force the Distances.jl API. So the idea would be to keep our KernelFunctions.pairwise function, which by default relies on Distances.jl, but to also have our own generic implementation using Tullio?

theogf commented 3 years ago

Oh wow! I just tried Tullio on CPU, and even for SqEuclidean it's already a 3x speedup 😱 I will make a small benchmark report.

theogf commented 3 years ago

More updates: GPU is working with Tullio :) and is around 250x faster on my machine! However, if the input is structured as a Vector{Vector}, the GPU operations fail completely (which is somewhat expected). You can find the script and the CPU benchmarks here: https://gist.github.com/theogf/7ed2bec68917283c02ce01dd14382ef6
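The Vector{Vector} failure is expected because GPU array types such as CUDA.jl's CuArray only support element types that are stored inline (isbits), and a Vector{Float64} is a heap-allocated object, not an isbits value. A minimal workaround, sketched here in plain Julia under the assumption that all inputs have the same length, is to pack the inputs into one contiguous matrix whose columns are the original vectors before moving it to the GPU (e.g. with CuArray(X), which needs CUDA.jl and a GPU and is therefore left out).

```julia
# Inputs as a vector of vectors: each element is a separate heap allocation.
xs = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]

# Pack into one contiguous d×n matrix; column i holds xs[i].
# Base has a specialized reduce(hcat, ...) method, so this avoids splatting.
X = reduce(hcat, xs)

isbitstype(eltype(xs))  # false: Vector{Float64} cannot be stored inline in a GPU array
isbitstype(eltype(X))   # true: Float64 can
```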

willtebbutt commented 3 years ago

Nice, thanks for producing these. Is the correct interpretation that Tullio (for the sizes you picked) is roughly 50% slower than the implementations using gemm?

willtebbutt commented 3 years ago

> More updates, GPU is working with Tullio :) and is around 250x faster on my machine!

Amazing. Even if we didn't want to adopt it for CPU, this might be a great way to go for GPU.

theogf commented 3 years ago

Oh! I actually forgot something! Somehow I could not get a nice std implementation for std_pairwise(::SqEuclidean, etc.), and I have no idea why the tests are failing. But I think the cost should be correct anyway 😆 In general, Tullio is faster than Distances.jl, and against gemm, yes, it's 2x slower but also allocates 3 times less.

devmotion commented 3 years ago

> More updates, GPU is working with Tullio :) and is around 250x faster on my machine!

I guess this is not surprising, as Distances uses scalar indexing, which is terrible on GPUs (see https://github.com/JuliaStats/Distances.jl/issues/143 and https://github.com/JuliaStats/Distances.jl/issues/137).

theogf commented 3 years ago

> I guess this is not surprising

I meant that the GPU computations are 250x faster than the CPU ones (also using Tullio): something like 7.3 ms on CPU against 29 μs on GPU.