JuliaAstro / AstroLib.jl

Bundle of small astronomical and astrophysical routines.
http://juliaastro.github.io/AstroLib.jl/stable/

Type of sunpos is too restrictive for use with ForwardDiff #73

Closed NOTtheMessiah closed 2 years ago

NOTtheMessiah commented 2 years ago

sunpos is intended to accept "either a scalar or a vector", for consistency with the IDL Astronomy User's Library. However, its implementation does not take advantage of multiple dispatch. Instead, it relies on a single-dispatch wrapper that loops over the values of a vector input, does explicit type conversions, and passes the result on to a hidden method, _sunpos. This style of implementation, seemingly ported directly from IDL, is not Julian: it is overly restrictive and gets in the way of metaprogramming and of interoperability with other libraries such as ForwardDiff and IntervalRootFinding, which are useful for differentiating and solving equations. ForwardDiff generalizes over Dual numbers and throws this MethodError when it encounters a method defined only over floats:

MethodError: no method matching _sunpos(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 1}, ::Bool)
Closest candidates are:
  _sunpos(!Matched::AbstractFloat, ::Bool) at ~/.julia/packages/AstroLib/Yq5lS/src/sunpos.jl:4

A more Julian implementation would define two methods, one for Real scalars and one for vectors of Reals, without forcing a conversion to floats.
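A minimal sketch of that two-method layout, using a hypothetical function `halfsin` as a stand-in for `sunpos` (its body is purely illustrative, not AstroLib's code). The scalar method accepts any `Real` and lets Julia's own promotion pick the float type; the vector method simply maps the scalar method over its input.

```julia
# Hypothetical stand-in for a scalar routine such as sunpos; the body is illustrative only.
# No float() conversion: sin promotes integers and rationals on its own,
# Float32 inputs stay Float32, and a custom Real type only needs its own sin method.
halfsin(x::Real) = sin(x) / 2

# Vector method: apply the scalar method to every element.
halfsin(xs::AbstractVector{<:Real}) = map(halfsin, xs)

halfsin(1)          # Int input, Float64 result
halfsin(0.5f0)      # Float32 input, Float32 result
halfsin([0, 1//2])  # vector of Rationals, Vector{Float64} result
```

Idiomatic Julia code often skips the vector method entirely and lets callers broadcast the scalar one, as in `halfsin.(xs)`, which also works for other array shapes.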

Aman-Pandey-afk commented 2 years ago

I want to help with this issue, but I feel I'm missing the crux of the problem. We can use multiple dispatch so that sunpos accepts either a Real argument or an AbstractArray of Reals, and both give the same output (neither needs a float cast). Both methods use _sunpos to handle the work neatly, but that is still multiple dispatch. Or is a Vector{AbstractFloat} argument strictly needed, since it can store any type of float? Also, could you explain how to reproduce the ForwardDiff error? The message above shows a ForwardDiff output being passed as an argument to sunpos, and I don't understand why we would compute sunpos on derivatives.

giordano commented 2 years ago

The problem is that almost everywhere in AstroLib I used the pattern:

f(x::Real) = _f(float(x))
_f(x::AbstractFloat) = ...

to ensure type stability, but this breaks with number types for which float(x) is not an AbstractFloat, such as ForwardDiff's Dual numbers. Admittedly, this isn't a great pattern, but hey, this was the first Julia code I wrote (I don't think I used it elsewhere) 🙂
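To reproduce the failure without installing ForwardDiff, here is a self-contained sketch with a minimal hypothetical `ToyDual` number type (not ForwardDiff's `Dual`). Its `float` returns another `ToyDual`, mirroring what the `_sunpos` error message shows for `Dual`. The hypothetical `f_old`/`_f_old` pair follows the pattern above, while `f_new` drops both the conversion and the `AbstractFloat` restriction.

```julia
# Minimal, hypothetical dual number for illustration (not ForwardDiff's Dual):
# it carries a value and a derivative, and is a Real but not an AbstractFloat.
struct ToyDual{T<:Real} <: Real
    val::T
    der::T
end

# float() keeps the dual wrapper, so the result is still not an AbstractFloat.
Base.float(d::ToyDual) = ToyDual(float(d.val), float(d.der))
# Chain rule for sin, and multiplication by a plain real constant.
Base.sin(d::ToyDual) = ToyDual(sin(d.val), cos(d.val) * d.der)
Base.:*(a::Real, d::ToyDual) = ToyDual(a * d.val, a * d.der)

# The old pattern: convert with float(), then dispatch on AbstractFloat.
f_old(x::Real) = _f_old(float(x))
_f_old(x::AbstractFloat) = 2 * sin(x)

# The generic pattern: no conversion, no AbstractFloat restriction.
f_new(x::Real) = 2 * sin(x)

x = ToyDual(1.0, 1.0)  # value 1.0, derivative seeded with 1.0
f_new(x).der           # derivative of 2sin(x) at 1.0, i.e. 2cos(1.0)
# f_old(x) throws a MethodError: float(x) is still a ToyDual, so no _f_old method matches
```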

The solution is to use only the pattern

f(x::Real) = ...

and make sure that the body of f is type-stable.
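A small sketch of what "type-stable" means for such a generic body, using a hypothetical angle-wrapping helper (not an AstroLib function). A hard-coded `Float64` literal makes the return type depend on which branch runs, whereas building the constant with `oftype` keeps both branches in the input's type. `@inferred` from the standard-library `Test` module checks that the inferred and actual return types agree.

```julia
using Test  # standard library; provides @inferred

# Hypothetical helper mapping an angle in [-360, 360) degrees into [0, 360).
# Not type-stable: for a Float32 input, x + 360.0 is Float64 while x is Float32,
# so inference can only give a Union of Float32 and Float64.
wrap_bad(x::Real) = x < 0 ? x + 360.0 : x

# Type-stable: oftype builds the constant 360 in the same type as x, so both
# branches return typeof(x) for Float32, Float64, Int and Rational{Int} inputs.
wrap_good(x::Real) = x < 0 ? x + oftype(x, 360) : x

wrap_good(-30.0f0)            # 330.0f0, still Float32
@inferred wrap_good(-30.0f0)  # passes: inferred and actual types agree
# @inferred wrap_bad(-30.0f0) # would throw: actual Float64, inferred is a Union
```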

giordano commented 2 years ago

PR #76 fixed the signature of sunpos, but it doesn't really solve the problem with ForwardDiff:

julia> ForwardDiff.derivative(sunpos, 1.0)
ERROR: MethodError: no method matching extract_derivative(::Type{ForwardDiff.Tag{typeof(sunpos), Float64}}, ::NTuple{4, ForwardDiff.Dual{ForwardDiff.Tag{typeof(sunpos), Float64}, Float64, 1}})
Closest candidates are:
  extract_derivative(::Type{T}, ::ForwardDiff.Dual) where T at ~/.julia/packages/ForwardDiff/PBzup/src/derivative.jl:81
  extract_derivative(::Type{T}, ::Real) where T at ~/.julia/packages/ForwardDiff/PBzup/src/derivative.jl:82
  extract_derivative(::Type{T}, ::AbstractArray) where T at ~/.julia/packages/ForwardDiff/PBzup/src/derivative.jl:83
Stacktrace:
 [1] derivative(f::typeof(sunpos), x::Float64)
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PBzup/src/derivative.jl:14
 [2] top-level scope
   @ REPL[4]:1

This is because of a limitation of ForwardDiff.jl:

The types of array inputs must be subtypes of AbstractArray. Non-AbstractArray array-like types are not officially supported.

Honestly, I'm not going to change the return type of the function (from a non-allocating tuple to an allocating array) to accommodate limitations in another package, so I'm going to close this issue.
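The error above arises because ForwardDiff's `extract_derivative` has methods for `Dual`, `Real` and `AbstractArray` results but none for tuples; conceptually, a tuple-valued function is no harder to differentiate. As a self-contained sketch with a hypothetical `ToyDual` type (not ForwardDiff's `Dual`) and a hypothetical tuple-valued `sincos_pair` standing in for `sunpos`, the derivatives can be read off by evaluating once on a dual input and mapping over the returned tuple, which keeps the result a tuple rather than an array.

```julia
# Minimal, hypothetical dual number (not ForwardDiff's Dual): value plus derivative.
struct ToyDual{T<:Real} <: Real
    val::T
    der::T
end
# Chain rules for sin and cos.
Base.sin(d::ToyDual) = ToyDual(sin(d.val), cos(d.val) * d.der)
Base.cos(d::ToyDual) = ToyDual(cos(d.val), -sin(d.val) * d.der)

# Hypothetical tuple-valued function, standing in for sunpos's 4-tuple.
sincos_pair(x::Real) = (sin(x), cos(x))

# Evaluate f once on a dual seeded with derivative 1, then extract the derivative
# of every tuple element. map over a tuple returns a tuple, not an array.
tuple_derivative(f, x::Real) = map(d -> d.der, f(ToyDual(float(x), one(float(x)))))

tuple_derivative(sincos_pair, 1.0)  # (cos(1.0), -sin(1.0))
```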

giordano commented 2 years ago

If you really need to, you can now turn the tuple into an array manually:

julia> ForwardDiff.derivative(x -> [sunpos(x)...], 1.0)
4-element Vector{Float64}:
  1.052789959609594
 -0.18784711921287695
  0.9948661788260018
 -1.0107962399285806e-6
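Splatting a tuple into a vector literal and calling `collect` on it produce the same `Vector`, so `x -> collect(sunpos(x))` should be an equivalent spelling of the workaround above (an assumption, not checked here against AstroLib). The sketch uses an arbitrary 4-tuple in place of a real `sunpos` result.

```julia
# Arbitrary stand-in for the 4-tuple returned by sunpos (values are not real output).
t = (1.0, -0.5, 0.25, 0.0)

v_splat = [t...]        # the spelling used above: splat into a vector literal
v_collect = collect(t)  # the same Vector{Float64}, without the splat

v_splat == v_collect    # true
```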