JuliaMath / Interpolations.jl

Fast, continuous interpolation of discrete datasets in Julia
http://juliamath.github.io/Interpolations.jl/

Splat arguments in gradient `rrule` #465

Closed: mcabbott closed this 2 years ago

mcabbott commented 2 years ago

I think the return type of the `rrule` defined here is incorrect. Its pullback ought to return one number per input, not a single vector:

```julia
julia> itp(x,y)
0.7176519f0

julia> Interpolations.gradient(itp, x, y)
2-element StaticArrays.SVector{2, Float32} with indices SOneTo(2):
 -0.70559686
 -1.2178432

julia> Zygote.pullback(itp, x, y)[2](1f0)  # should be a tuple of 2 scalars
(Float32[-0.70559686, -1.2178432],)

julia> Zygote.gradient(itp, x, y)
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}})(::StaticArrays.SVector{2, Float32})
```
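The shape problem can be reproduced without Interpolations or Zygote. In this dependency-free sketch, `toy_itp` and `toy_grad` are hypothetical stand-ins for an interpolant and `Interpolations.gradient`, and `nothing` stands in for the tangent of the callable itself. The only point is what the pullback returns: packing the gradient into one slot gives two outputs for three arguments, while splatting it gives one scalar per argument.

```julia
# Hypothetical stand-in for a 2-D interpolant evaluated at (x, y)
toy_itp(x, y) = x^2 * y

# Hypothetical stand-in for Interpolations.gradient; returns a Tuple,
# which (like an SVector) has its length encoded in its type
toy_grad(x, y) = (2x * y, x^2)

# Shape reported above: the whole gradient packed into a single slot
pullback_packed(x, y) = Δ -> (nothing, collect(Δ .* toy_grad(x, y)))

# Shape the rrule convention expects: the first slot is the tangent of the
# callable, followed by one tangent per positional argument
pullback_splatted(x, y) = Δ -> (nothing, (Δ .* toy_grad(x, y))...)

packed = pullback_packed(3.0, 2.0)(1.0)
splatted = pullback_splatted(3.0, 2.0)(1.0)
println(packed)    # (nothing, [12.0, 9.0])
println(splatted)  # (nothing, 12.0, 9.0)
```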

Does `Interpolations.gradient` always return an `SVector`? If so, the splat here should be free, since an `SVector`'s length is part of its type. If not, and we might get a `Vector`, then the splat should probably be replaced.
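Why the splat is cheap for an `SVector` but not for a `Vector` can be checked by asking Julia's type inference what splatting each produces. This sketch uses `NTuple{2,Float64}` as a stand-in for `SVector{2,Float64}` (both carry their length in the type) so that StaticArrays is not needed; `prepend_nothing` is a hypothetical helper mimicking a pullback that prepends one tangent before the splatted gradient.

```julia
# Hypothetical helper: prepend a tangent for the callable, then splat the gradient
prepend_nothing(g) = (nothing, g...)

# Length known from the type: inference yields Tuple{Nothing, Float64, Float64}
T_tuple = Base.return_types(prepend_nothing, (NTuple{2,Float64},))[1]

# Length known only at run time: inference yields a non-concrete tuple type
T_vec = Base.return_types(prepend_nothing, (Vector{Float64},))[1]

println(isconcretetype(T_tuple))  # true
println(isconcretetype(T_vec))    # false
```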

Unrelated, but `Interpolations.gradient` doesn't compute the gradient with respect to the image etc., only with respect to the evaluation point. I think the right way to encode this is a `NotImplemented` object, which should give an error if you try to use it. Returning `NoTangent()` instead means that you will silently get zero. But I am not 100% confident this won't cause surprises. Cc @oxinabox who knows more?
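A minimal sketch of both suggestions, assuming ChainRulesCore is installed. `ToyInterp` and `toy_gradient` are hypothetical stand-ins for an interpolation object and `Interpolations.gradient`; this is not the code from the PR. The pullback marks the tangent for the callable (and so for its stored data) with `ChainRulesCore.@not_implemented` instead of `NoTangent()`, and splats the gradient into one tangent per coordinate.

```julia
using ChainRulesCore

# Hypothetical callable standing in for an interpolation object: it stores
# data `a` (playing the role of the image) and is evaluated at (x, y).
struct ToyInterp
    a::Float64
end
(f::ToyInterp)(x, y) = f.a * x^2 * y

# Hypothetical stand-in for Interpolations.gradient: derivatives with respect
# to the evaluation point only, returned as a Tuple
toy_gradient(f::ToyInterp, x, y) = (2 * f.a * x * y, f.a * x^2)

function ChainRulesCore.rrule(f::ToyInterp, x...)
    val = f(x...)
    function toy_pullback(Δ)
        # The gradient w.r.t. the stored data is not computed: say so with
        # NotImplemented rather than NoTangent, which would mean "zero".
        ∂f = ChainRulesCore.@not_implemented("gradient w.r.t. the stored data")
        # One tangent per coordinate, not one vector for all of them
        return (∂f, (Δ .* toy_gradient(f, x...))...)
    end
    return val, toy_pullback
end

val, back = rrule(ToyInterp(1.0), 3.0, 2.0)
grads = back(1.0)  # three entries: a NotImplemented, then 12.0 and 9.0
```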

Needs tests. The existing tests at https://github.com/JuliaMath/Interpolations.jl/blob/master/test/chainrules.jl only check that the result matches `Interpolations.gradient` (and will therefore fail with this change). One possibility is to use ChainRulesTestUtils, which runs elaborate finite-difference tests. But I see there are also more extensive tests at https://github.com/JuliaMath/Interpolations.jl/blob/master/test/gradient.jl, so perhaps that's not needed?
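As a rough illustration of the kind of finite-difference check ChainRulesTestUtils automates, this dependency-free sketch compares a splatted pullback against central differences, one argument at a time. `toy_itp`, `toy_grad`, `toy_pullback`, and `fd_grad` are hypothetical; a real test would exercise an actual interpolant.

```julia
# Hypothetical interpolant stand-in and its hand-written gradient
toy_itp(x, y) = x^2 * y
toy_grad(x, y) = (2x * y, x^2)

# Pullback with one tangent per argument (first slot: the callable itself)
toy_pullback(x, y) = Δ -> (nothing, (Δ .* toy_grad(x, y))...)

# Central finite differences along each coordinate of the evaluation point
function fd_grad(f, x, y; h=1e-6)
    dx = (f(x + h, y) - f(x - h, y)) / (2h)
    dy = (f(x, y + h) - f(x, y - h)) / (2h)
    return (dx, dy)
end

x, y = 0.3, 0.7
out = toy_pullback(x, y)(1.0)
_, ∂x, ∂y = out
numeric = fd_grad(toy_itp, x, y)
ok = length(out) == 3 &&
     isapprox(∂x, numeric[1]; rtol=1e-6) &&
     isapprox(∂y, numeric[2]; rtol=1e-6)
println(ok)  # true
```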

oxinabox commented 2 years ago

> Interpolations.gradient doesn't compute the gradient with respect to the image etc, only the evaluation point. I think the right way to encode this is a NotImplemented object, which should give an error if you try to use it. Returning NoTangent() instead means that you will silently get zero. But I am not 100% confident this won't cause surprises. Cc @oxinabox who knows more?

You are correct. There shouldn't be any surprises. `NotImplemented` acts so much like an `AbstractZero` that I sometimes think we should have given it that supertype. It acts enough like one that anything that was using `NoTangent` safely, in a way that couldn't have led to wrong answers, will still work. And anything else was wrong before.

codecov[bot] commented 2 years ago

Codecov Report

Merging #465 (8fb478c) into master (d11ddd1) will increase coverage by 0.00%. The diff coverage is 100.00%.


```diff
@@           Coverage Diff           @@
##           master     #465   +/-   ##
=======================================
  Coverage   85.10%   85.11%
=======================================
  Files          25       25
  Lines        1746     1747    +1
=======================================
+ Hits         1486     1487    +1
  Misses        260      260
```

| Impacted Files | Coverage Δ |
|---|---|
| src/chainrules/chainrules.jl | 100.00% <100.00%> (ø) |


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update d11ddd1...8fb478c.

mkitti commented 2 years ago

cc: @rick2047