JuliaDiff / FiniteDifferences.jl

High accuracy derivatives, estimated via numerical finite differences (formerly FDM.jl)
MIT License

j′vp and jacobian are not type-stable #199

Closed: sethaxen closed this issue 2 years ago.

sethaxen commented 2 years ago

`jvp`, by contrast, is type-stable. Example:

julia> using FiniteDifferences, Test

julia> x, ẋ = randn(3), randn(3);

julia> fdm = central_fdm(5, 1);

julia> @inferred fdm(sin, x[1]); # fine

julia> @inferred jvp(fdm, sum, (x, ẋ)); # fine

julia> @inferred j′vp(fdm, sum, 1.0, x); # uh-oh
ERROR: return type Tuple{Vector{Float64}} does not match inferred return type Tuple{Any}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] top-level scope
   @ REPL[39]:1

julia> @inferred jacobian(fdm, identity, x); # uh-oh
ERROR: return type Tuple{Matrix{Float64}} does not match inferred return type Tuple{Union{Matrix, Vector{Any}}}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:33
 [2] top-level scope
   @ REPL[45]:1
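For context, `Test.@inferred` runs the call, then compares the concrete type of the returned value with the return type the compiler inferred for that call from the argument types alone, and throws an `ErrorException` when they differ. The minimal self-contained sketch below has nothing to do with FiniteDifferences; `relu_unstable` and `relu_stable` are hypothetical functions chosen only to trigger the same kind of error as above.

```julia
using Test

# Hypothetical example: returns a Float64 for positive inputs but the Int
# literal 0 otherwise, so for a Float64 argument the compiler can only infer
# Union{Float64, Int64}.
relu_unstable(x::Float64) = x > 0 ? x : 0

# Using zero(x) keeps both branches Float64, so the inferred type is concrete.
relu_stable(x::Float64) = x > 0 ? x : zero(x)

# The runtime type (Float64) differs from the inferred Union, so @inferred throws.
@test_throws ErrorException @inferred relu_unstable(1.0)

# Passes and simply returns the value.
@inferred relu_stable(1.0)

# The inferred types can also be inspected directly.
println(Base.return_types(relu_unstable, (Float64,)))
println(Base.return_types(relu_stable, (Float64,)))
```

The same tool is what flags `j′vp` and `jacobian`: their results are concrete at runtime, but the compiler can only infer `Tuple{Any}` or a `Union` for them.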
devmotion commented 2 years ago

It seems the `jacobian` case can be fixed by changing https://github.com/JuliaDiff/FiniteDifferences.jl/blob/178cfb59ca77aa0cf8aa760bb053274c67de8399/src/grad.jl#L24 to `(reduce(hcat, ys),)`, which might also be more performant, since it takes the fast path in Base and avoids splatting. I can prepare a PR.
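A self-contained sketch of the two spellings, assuming the current line splats a vector of columns into `hcat` (the exact code at the linked line is not reproduced here; `ys`, `splat_cols` and `reduce_cols` are hypothetical names). A plausible source of the `Vector{Any}` in the inferred `Union` is the zero-argument case: `hcat()` with no arguments returns an empty `Vector{Any}`, and when a vector of unknown length is splatted the compiler cannot rule that call out.

```julia
# Hypothetical columns, e.g. one derivative vector per input coordinate.
ys = [randn(3) for _ in 1:4]

# Splatting: the number of arguments depends on length(ys), which is not
# known at compile time, so the empty call hcat() must also be considered.
splat_cols(ys) = (hcat(ys...),)

# Base defines a specialised method of reduce for hcat over a vector of
# arrays, so no splatting is involved.
reduce_cols(ys) = (reduce(hcat, ys),)

# Print what the compiler infers for each version on a Vector of Vectors.
println(Base.return_types(splat_cols, (Vector{Vector{Float64}},)))
println(Base.return_types(reduce_cols, (Vector{Vector{Float64}},)))

# Both produce the same 3×4 matrix at runtime.
println(reduce_cols(ys)[1] == splat_cols(ys)[1])
```

The runtime results are identical; only what the compiler can prove about them differs, which is what `@inferred` checks.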

sethaxen commented 2 years ago

Similarly, it looks like the `j′vp` issue can be fixed by changing https://github.com/JuliaDiff/FiniteDifferences.jl/blob/56e1d6338bbe10bc3de6e4daa17c7d93d6a10e15/src/grad.jl#L73 to

    return (vec_to_x(_j′vp(fdm, x -> first(to_vec(f(vec_to_x(x)))), ȳ_vec, x_vec)), )
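The suggested change writes the pipeline passed to `_j′vp` as a single anonymous function, presumably in place of a chain of `∘` compositions (the comment below points at composed functions as the culprit). The sketch below shows that rewrite in isolation, with hypothetical stand-ins `to_vec_demo`, `vec_to_x_demo` and `f_demo` for the package's `to_vec`, `vec_to_x` and the user's function. This trivial setting does not reproduce the instability; it only checks that both forms compute the same result and that the closure's return type is inferred concretely.

```julia
using Test

# Hypothetical stand-ins: to_vec_demo mimics to_vec by returning
# (flat vector, rebuild function); here values are already flat Vectors.
to_vec_demo(y::Vector{Float64}) = (y, identity)
vec_to_x_demo(v::Vector{Float64}) = v
f_demo(x) = 2 .* x

# Composition with ∘ builds nested ComposedFunction objects.
g_composed = first ∘ to_vec_demo ∘ f_demo ∘ vec_to_x_demo

# The same calls, in the same order, written as one anonymous function.
g_closure = v -> first(to_vec_demo(f_demo(vec_to_x_demo(v))))

x = randn(3)
println(g_composed(x) == g_closure(x))

# Passes: the closure's return type is inferred as Vector{Float64}.
@inferred g_closure(x)
```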

It seems that broadcasting a `ComposedFunction` over `SVector` inputs is not type-stable (at least in this case) once more than two functions are composed:

julia> @inferred FiniteDifferences._eval_function(fdm.bound_estimator, identity, 1.0, 1.0);

julia> @inferred FiniteDifferences._eval_function(fdm.bound_estimator, identity ∘ identity, 1.0, 1.0);

julia> @inferred FiniteDifferences._eval_function(fdm.bound_estimator, identity ∘ identity ∘ identity, 1.0, 1.0);
ERROR: return type StaticArrays.SVector{7, Float64} does not match inferred return type StaticArrays.SVector{7}