sisl / GridInterpolations.jl

Multidimensional grid interpolation in arbitrary dimensions

Customize weight type to allow for autodiff packages to work #43

Closed: mossr closed this 8 months ago

mossr commented 8 months ago

In support of the work-in-progress textbook Algorithms for Validation, and to address issue #38, this PR lets the user specify the element type of the interpolation weights so that automatic differentiation (autodiff) packages such as ForwardDiff work properly.

By default, the weights are still stored in vectors of Float64, exactly as before, so there is no loss of efficiency. Optionally, the user can pass the weight type as an extra argument to the RectangleGrid and SimplexGrid constructors, e.g., RectangleGrid(Real, [0.0, 0.5, 1.0], [0.0, 0.5, 1.0]).
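
To make the idea concrete, here is a minimal one-dimensional sketch of a grid whose reusable weight buffer has a caller-chosen element type W. The names LinearGrid1D and interp1d are hypothetical and this is not the GridInterpolations.jl source; it only mirrors the design described above: Float64 by default, any other type on request.

```julia
# Hypothetical sketch, not GridInterpolations.jl's actual code: a 1-D grid
# whose reusable weight buffer has a caller-chosen element type W.
struct LinearGrid1D{W}
    cutpoints::Vector{Float64}
    weights::Vector{W}  # weights for the two neighbouring cutpoints
end

# Default: Float64 weights, matching the unchanged behaviour.
LinearGrid1D(cutpoints::Vector{Float64}) = LinearGrid1D(Float64, cutpoints)

# Opt-in: pass the weight type first, e.g. LinearGrid1D(Real, cutpoints).
LinearGrid1D(::Type{W}, cutpoints::Vector{Float64}) where {W} =
    LinearGrid1D{W}(cutpoints, Vector{W}(undef, 2))

function interp1d(grid::LinearGrid1D, data::Vector{Float64}, x::Real)
    c = grid.cutpoints
    # Left cutpoint of the cell containing x, clamped so i + 1 stays in range.
    i = clamp(searchsortedlast(c, x), 1, length(c) - 1)
    t = (x - c[i]) / (c[i+1] - c[i])
    # Storing into grid.weights converts t to W; with W = Float64 this fails
    # for number types that have no conversion to Float64.
    grid.weights[1] = 1 - t
    grid.weights[2] = t
    return grid.weights[1] * data[i] + grid.weights[2] * data[i+1]
end

g = LinearGrid1D([0.0, 0.5, 1.0])          # Vector{Float64} weights
gr = LinearGrid1D(Real, [0.0, 0.5, 1.0])   # Vector{Real} weights
println(interp1d(g, [0.0, 1.0, 4.0], 0.75))   # 2.5
println(interp1d(gr, [0.0, 1.0, 4.0], 0.75))  # 2.5
```

Keeping Float64 as the default means existing callers see no change, while the opt-in type lets the buffer hold values (such as dual numbers) that cannot be converted to Float64.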

Here are some benchmarks:

using GridInterpolations
using BenchmarkTools

grid_data = [8.0, 1.0, 6.0, 3.0, 5.0, 7.0, 4.0, 9.0, 2.0]
x = [0.25, 0.75]

# default, unchanged behavior
grid = RectangleGrid([0.0, 0.5, 1.0], [0.0, 0.5, 1.0]) # Float64
@btime interpolate(grid, grid_data, x)
# 164.948 ns (2 allocations: 112 bytes)

grid = RectangleGrid(Real, [0.0, 0.5, 1.0], [0.0, 0.5, 1.0]) # weight type Real
@btime interpolate(grid, grid_data, x)
# 689.262 ns (33 allocations: 608 bytes)

grid = RectangleGrid(Number, [0.0, 0.5, 1.0], [0.0, 0.5, 1.0]) # weight type Number
@btime interpolate(grid, grid_data, x)
# 705.556 ns (33 allocations: 608 bytes)
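
The extra time and allocations with Real or Number (33 allocations vs. 2 for Float64) are most likely a consequence of abstract element types: a Vector{Real} stores each element as a boxed reference, and every arithmetic operation on its elements has to be dispatched at run time. The sketch below illustrates this with a hypothetical weighted_sum helper standing in for the final step of an interpolation; it is not the package's code.

```julia
# Hypothetical helper mimicking the final step of an interpolation:
# combine weights w with the neighbouring grid values v.
weighted_sum(w::AbstractVector, v::AbstractVector) =
    sum(w[i] * v[i] for i in eachindex(w, v))

v = [8.0, 1.0, 6.0, 3.0]
w_concrete = [0.25, 0.25, 0.25, 0.25]      # Vector{Float64}: values stored inline
w_abstract = Real[0.25, 0.25, 0.25, 0.25]  # Vector{Real}: each element is boxed

# With a concrete eltype the compiler knows every w[i] is a Float64;
# with Real (or Number) each w[i] * v[i] is dispatched at run time.
println(isconcretetype(eltype(w_concrete)))  # true
println(isconcretetype(eltype(w_abstract)))  # false
println(weighted_sum(w_concrete, v) == weighted_sum(w_abstract, v))  # true (both 4.5)
```

Both vectors give the same numerical result; only the cost differs, which is why Float64 remains the default.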

With the weight type set to Real or Number, the following now works:

using ForwardDiff
f(x::Vector) = interpolate(grid, grid_data, x)
ForwardDiff.gradient(f, x)
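
As a sanity check on what the gradient call computes, here is the expected result worked by hand. This assumes RectangleGrid performs multilinear interpolation and that grid_data lists values with the first coordinate varying fastest, so the corners of the cell [0, 0.5] × [0.5, 1] containing x = [0.25, 0.75] hold 3, 5, 4, and 9. The local coordinates are t_1 = t_2 = 0.5 and the cell width is h = 0.5:

$$
\begin{aligned}
f(x) &= (1-t_1)(1-t_2)\,3 + t_1(1-t_2)\,5 + (1-t_1)\,t_2\,4 + t_1 t_2\,9 = 5.25,\\
\frac{\partial f}{\partial x_1} &= \frac{(1-t_2)(5-3) + t_2(9-4)}{h} = \frac{0.5\cdot 2 + 0.5\cdot 5}{0.5} = 7,\\
\frac{\partial f}{\partial x_2} &= \frac{(1-t_1)(4-3) + t_1(9-5)}{h} = \frac{0.5\cdot 1 + 0.5\cdot 4}{0.5} = 5.
\end{aligned}
$$

Under these assumptions, ForwardDiff.gradient(f, x) should return [7.0, 5.0].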

FYI, with the default Float64 weight type, ForwardDiff fails: interpolants tries to store a ForwardDiff.Dual (a subtype of Real) in a Vector{Float64} of weights, and there is no method to convert a Dual to a Float64:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50
  ...

Stacktrace:
  [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2})
    @ Base .\number.jl:7
  [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}, i1::Int64)
    @ Base .\array.jl:969
  [3] interpolants(grid::RectangleGrid{2, Float64}, x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}})
    @ GridInterpolations C:\Users\RobertMoss\.julia\dev\GridInterpolations\src\GridInterpolations.jl:220
  [4] interpolate(grid::RectangleGrid{2, Float64}, data::Vector{Float64}, x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}})
    @ GridInterpolations C:\Users\RobertMoss\.julia\dev\GridInterpolations\src\GridInterpolations.jl:163
  [5] f(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}})       
    @ Main c:\Users\RobertMoss\.julia\dev\GridInterpolations\test\benchmark.jl:21
  [6] vector_mode_dual_eval!
    @ C:\Users\RobertMoss\.julia\packages\ForwardDiff\PcZ48\src\apiutils.jl:24 [inlined]    
  [7] vector_mode_gradient(f::typeof(f), x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}}})
    @ ForwardDiff C:\Users\RobertMoss\.julia\packages\ForwardDiff\PcZ48\src\gradient.jl:89  
  [8] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}}}, ::Val{true})
    @ ForwardDiff C:\Users\RobertMoss\.julia\packages\ForwardDiff\PcZ48\src\gradient.jl:19  
  [9] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}}})
    @ ForwardDiff C:\Users\RobertMoss\.julia\packages\ForwardDiff\PcZ48\src\gradient.jl:17  
 [10] gradient(f::Function, x::Vector{Float64})
    @ ForwardDiff C:\Users\RobertMoss\.julia\packages\ForwardDiff\PcZ48\src\gradient.jl:17  
 [11] top-level scope
    @ c:\Users\RobertMoss\.julia\dev\GridInterpolations\test\benchmark.jl:22
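
To see why the conversion fails without installing ForwardDiff, here is a self-contained sketch that uses a hypothetical MiniDual type as a stand-in for ForwardDiff.Dual (it is not Dual's real definition). Like Dual, it subtypes Real but defines no conversion to Float64, so writing it into a Vector{Float64} raises a MethodError, whereas a Vector{Real} buffer stores it and the derivative propagates through the interpolation weights.

```julia
# MiniDual is a hypothetical, stripped-down stand-in for ForwardDiff.Dual:
# val holds a value and der its derivative with respect to the input.
struct MiniDual <: Real
    val::Float64
    der::Float64
end

# Only the operations this example needs, each with its derivative rule.
Base.:-(d::MiniDual, a::Float64) = MiniDual(d.val - a, d.der)
Base.:-(a::Float64, d::MiniDual) = MiniDual(a - d.val, -d.der)
Base.:/(d::MiniDual, a::Float64) = MiniDual(d.val / a, d.der / a)
Base.:*(d::MiniDual, a::Float64) = MiniDual(d.val * a, d.der * a)
Base.:+(d::MiniDual, e::MiniDual) = MiniDual(d.val + e.val, d.der + e.der)

x = MiniDual(0.25, 1.0)   # input 0.25, seeded with dx/dx = 1
t = (x - 0.0) / 0.5       # linear weight inside the cell [0.0, 0.5]

# Default behaviour: a Vector{Float64} weight buffer.
w64 = Vector{Float64}(undef, 2)
err = try
    w64[2] = t            # setindex! calls convert(Float64, t): no such method
    nothing
catch e
    e
end
println(err isa MethodError)  # true, analogous to the ForwardDiff error above

# Opt-in behaviour: an abstractly typed buffer accepts the dual number.
w = Vector{Real}(undef, 2)
w[1] = 1.0 - t
w[2] = t
y = w[1] * 0.0 + w[2] * 1.0   # interpolate the values 0.0 and 1.0 at x
println((y.val, y.der))        # (0.5, 2.0): value and slope 1.0 / 0.5
```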
zsunberg commented 8 months ago

I just submitted an alternative: #44

zsunberg commented 8 months ago

#44 was merged instead of this PR.