sisl / GridInterpolations.jl

Multidimensional grid interpolation in arbitrary dimensions

SVector interpolants #44

Closed zsunberg closed 8 months ago

zsunberg commented 8 months ago

This is an alternative to #43

It actually makes a significant change. Previously, the memory for the interpolants was cached in the grid object. This PR instead creates new static arrays on every call to interpolants.

Conceptually, this is a big improvement: previously, if you called interpolants, stored the results without copying, and then called interpolants again, the second call would overwrite the original interpolants! Yikes!
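To illustrate the aliasing hazard described above, here is a minimal sketch in plain Julia. `CachedGrid`, `cached_weights!`, and `fresh_weights` are hypothetical toy names, not part of GridInterpolations; the tuple stands in for the static arrays the PR returns.

```julia
# Toy stand-in for the old design: the result buffer lives inside the object.
struct CachedGrid
    weights::Vector{Float64}   # reused on every call
end

function cached_weights!(g::CachedGrid, x::Float64)
    g.weights[1] = 1 - x
    g.weights[2] = x
    return g.weights           # a reference to the cache, not a copy
end

# Toy stand-in for the new design: a fresh immutable value per call.
fresh_weights(x::Float64) = (1 - x, x)

g = CachedGrid(zeros(2))
w1 = cached_weights!(g, 0.25)
w2 = cached_weights!(g, 0.75)
println(w1 === w2)   # both names point at the same cached memory
println(w1)          # the first result has been overwritten by the second call

t1 = fresh_weights(0.25)
t2 = fresh_weights(0.75)
println(t1)          # unaffected by the second call
```

In the cached variant the caller has to remember to `copy` the result before calling again; in the fresh-value variant there is nothing to remember.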

Here are some new results:

using GridInterpolations
using BenchmarkTools

grid_data = [8.0, 1.0, 6.0, 3.0, 5.0, 7.0, 4.0, 9.0, 2.0]
x = [0.25, 0.75]

grid = RectangleGrid([0.0, 0.5, 1.0], [0.0, 0.5, 1.0]) # Float64
@btime interpolate($grid, $grid_data, $x)
#   33.782 ns (0 allocations: 0 bytes)

using ForwardDiff
f(x::Vector) = interpolate(grid, grid_data, x)
@btime ForwardDiff.gradient($f, $x)
#  635.357 ns (5 allocations: 320 bytes)

Unfortunately, some of the benchmarks from the tests show a small performance regression compared with the old implementation.

Old:

1000 interpolations of 6 dimensions with 15 cut points per dimension:
  Rectangle required 0.0009157518239999992 +/- 0.00023133335979010877 sec
  Simplex   required 0.0006389991239999997 +/- 0.0006475247064044843 sec
How large is the simplex grid speed up over the multilinear grid?
  limiting to 2 dimensions and therefore 316 points per dim:
    mean speed: 0.00031382153333333343, std dev: 6.616835029389875e-5
  limiting to 3 dimensions and therefore 46 points per dim:
    mean speed: 0.000284075, std dev: 6.0499168543730904e-5
  limiting to 4 dimensions and therefore 18 points per dim:
    mean speed: 0.0004026574000000001, std dev: 8.200233409226417e-5
  limiting to 5 dimensions and therefore 10 points per dim:
    mean speed: 0.0004601372666666666, std dev: 8.407939234160375e-5
100 interpolations of 4 dimensions with 10 cut points per dimension:
  Rectangle required 0.00022583938999999994 +/- 4.003089883334692e-5 sec
  Simplex   required 0.0003840863899999998 +/- 5.763477849382195e-5 sec

New:

1000 interpolations of 6 dimensions with 15 cut points per dimension:
  Rectangle required 0.0010046940880000007 +/- 0.0008021641712052207 sec
  Simplex   required 0.0005958515850000009 +/- 0.0007433027770190067 sec
How large is the simplex grid speed up over the multilinear grid?
  limiting to 2 dimensions and therefore 316 points per dim:
    mean speed: 0.0003624577666666666, std dev: 4.106958632297421e-5
  limiting to 3 dimensions and therefore 46 points per dim:
    mean speed: 0.0002571071666666667, std dev: 5.565465236576383e-5
  limiting to 4 dimensions and therefore 18 points per dim:
    mean speed: 0.0003611892, std dev: 6.304453687157774e-5
  limiting to 5 dimensions and therefore 10 points per dim:
    mean speed: 0.00041827066666666663, std dev: 5.4756525932804423e-5
100 interpolations of 4 dimensions with 10 cut points per dimension:
  Rectangle required 0.00032039601999999995 +/- 0.0009210866573612322 sec
  Simplex   required 0.00033461859999999995 +/- 4.617121162356343e-5 sec

I think the new safer and easier-to-understand code is well worth it.

mossr commented 8 months ago

I'm not a fan of the $ interpolation as a requirement. Could you explain that choice?

zsunberg commented 8 months ago

> I'm not a fan of the $ interpolation as a requirement. Could you explain that choice?

That's just for benchmarking: https://juliaci.github.io/BenchmarkTools.jl/stable/manual/#Interpolating-values-into-benchmark-expressions
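As a small sketch of what the linked manual section describes (assuming BenchmarkTools is installed; the vector `v` is just an illustrative global), the `$` only changes how the benchmark expression is built, not how the function is called in normal code:

```julia
using BenchmarkTools

v = rand(1000)   # a non-constant global, like `grid` and `x` in the PR description

# Without `$`, `v` is looked up as an untyped global inside the benchmark,
# so the measurement can include dynamic-dispatch overhead.
@btime sum(v)

# With `$`, the value of `v` is spliced into the benchmark expression,
# so the timing reflects `sum` applied to a concretely typed argument.
@btime sum($v)

# Ordinary calls never need the `$`:
total = sum(v)
```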

mossr commented 8 months ago

Ah that's great. So I'm guessing it will infer the type based on x here? (Instead of my implementation with an explicit type input)

zsunberg commented 8 months ago

> Ah that's great. So I'm guessing it will infer the type based on x here? (Instead of my implementation with an explicit type input)

Exactly.
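A rough sketch of the idea, using a hypothetical one-dimensional helper `toy_weights` rather than the package's actual code: when the weights are computed from `x` with generic arithmetic, their type follows the type of `x`, so no explicit type argument is needed.

```julia
# Hypothetical 1-D helper: linear weights for x between cut points lo and hi.
function toy_weights(lo::Float64, hi::Float64, x::Real)
    t = (x - lo) / (hi - lo)     # the type of t is determined by the type of x
    return (one(t) - t, t)
end

w64  = toy_weights(0.0, 0.5, 0.25)        # Float64 query point
wbig = toy_weights(0.0, 0.5, big"0.25")   # BigFloat query point

println(typeof(w64))
println(typeof(wbig))
```

The same mechanism is what lets number types such as ForwardDiff's dual numbers flow through a generically written function.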

One downside is that this would stop working for RectangleGrid interpolations in more than 16 dimensions:

julia> @MVector(zeros(Int, 2^17))
Internal error: encountered unexpected error during compilation of zeros:
StackOverflowError()

Probably people should be using simplex interpolation by that point though...
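For a sense of scale, the sketch below counts the interpolants per query point. The `2^d` figure for the rectangle grid matches the `2^17` in the example above; the `d + 1` figure for the simplex grid is the standard vertex count of a d-dimensional simplex and is stated here as an assumption about the package, not something taken from its source.

```julia
# Number of grid vertices that contribute to one interpolated value.
rectangle_vertices(d::Integer) = 2^d      # corners of a d-dimensional hyperrectangle
simplex_vertices(d::Integer)   = d + 1    # vertices of a d-dimensional simplex

for d in (2, 6, 16, 17)
    println("d = ", d,
            ": rectangle ", rectangle_vertices(d),
            ", simplex ", simplex_vertices(d))
end
```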

mossr commented 8 months ago

I think that's reasonable and should just be mentioned in the README to warn users (and offer a suggestion to use the SimplexGrid instead).

zsunberg commented 8 months ago

@mossr, should I assume you are good with this? It sounds like you are from the comments, but you didn't submit an official review.

mossr commented 8 months ago

I pulled and tested this locally and I get the following errors when trying Complex or Real typed x vectors:

grid = RectangleGrid([0.0, 0.5, 1.0], [0.0, 0.5, 1.0])
grid_data = [8.0, 1.0, 6.0, 3.0, 5.0, 7.0, 4.0, 9.0, 2.0]
x = Complex[0.25, 0.75]
interpolate(grid, grid_data, x)

ERROR: setindex!() with non-isbitstype eltype is not supported by StaticArrays. Consider using SizedArray.
Stacktrace:
 [1] error(s::String)
   @ Base .\error.jl:35
 [2] setindex!
   @ C:\Users\RobertMoss\.julia\packages\StaticArrays\cZ1ET\src\MArray.jl:39 [inlined]
 [3] interpolants(grid::RectangleGrid{2}, x::Vector{Complex})
   @ GridInterpolations C:\Users\RobertMoss\.julia\packages\GridInterpolations\CSHS7\src\GridInterpolations.jl:168
 [4] interpolate(grid::RectangleGrid{2}, data::Vector{Float64}, x::Vector{Complex})
   @ GridInterpolations C:\Users\RobertMoss\.julia\packages\GridInterpolations\CSHS7\src\GridInterpolations.jl:145
 [5] top-level scope
   @ REPL[17]:1

Do you see this as well?

zsunberg commented 8 months ago

Yeah, I realized that I had a typo above which meant that I had not actually tested the Complex version.

One problem with what you tried is that Complex, without a type parameter, is not a concrete type. If you try Complex{Float64} you will get past that error. But then you will run into the error that you cannot compare a complex number with a floating-point scalar. In the end, it just doesn't make sense to interpolate complex numbers onto a floating-point grid.
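Both points can be checked in plain Julia without the package. This is only a sketch of the underlying language behavior, not of the package's code path:

```julia
# `Complex[...]` builds a vector whose element type is the unparameterized `Complex`.
xs = Complex[0.25, 0.75]
println(eltype(xs))                      # Complex
println(isbitstype(Complex))             # false: this is what the StaticArrays error complains about
println(isbitstype(Complex{Float64}))    # true

# Ordering comparisons between complex and real numbers are not defined.
comparison_failed = try
    (0.25 + 0.0im) < 0.5
    false
catch err
    err isa MethodError
end
println(comparison_failed)               # true
```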

Really, this PR just makes things differentiable and (imo just as importantly) stops returning references to memory that will be silently overwritten later. It doesn't add the capability to accept more exotic arguments (e.g. complex numbers). That would require more work.

(I just edited and removed the Complex example from the PR description since it is not meaningful)

mossr commented 8 months ago

I'm good with that. The purpose of my PR #43 was solely to get autodiff working, and I tested the ForwardDiff.gradient(f, x) example and it works.

Looks good to me!