JuliaDiff / ChainRulesCore.jl

AD-backend agnostic system defining custom forward and reverse mode rules. This is the lightweight core to allow you to define rules for your functions in your packages, without depending on any particular AD system.

Projection for `x::AbstractRange` #437

Open mcabbott opened 3 years ago

mcabbott commented 3 years ago

If we want something like https://github.com/JuliaArrays/FillArrays.jl/pull/153 to project the gradient of a Fill onto a one-dimensional subspace, then I think we probably want something similar for the gradient of a range, but projecting onto a two-dimensional space, parameterised by the endpoints. Before I lose the bit of scrap paper I wrote this on, I think this would look as follows:

using Statistics: mean  # `mean` is used below

ProjectTo(x::AbstractRange) = ProjectTo{AbstractRange}()

function (project::ProjectTo{AbstractRange})(dx::AbstractVector)
    L = length(dx)
    μ = mean(dx)
    # δ = -sum(diff(dx))/2
    δ = sum(Base.splat(-), zip(dx, @view dx[2:end]))/2
    return LinRange(μ + δ, μ - δ, L)
end

(project::ProjectTo{AbstractRange})(dx::AbstractRange) = dx

Using LinRange allows for zero slope (e.g. for constant dx) and skips the high-precision machinery which StepRangeLen uses to hit endpoints exactly, as I don't think we're concerned about the last digit here. This isn't yet careful about element types etc.

mcabbott commented 3 years ago

The above formula is wrong. Correct versions are here: https://github.com/mcabbott/OddArrays.jl/blob/6c3ef3ab5ebf05c8aa6aa030590456200715be0f/src/OddArrays.jl#L814-L838
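
One way to see the problem: the sum of `dx[i] - dx[i+1]` telescopes to `dx[1] - dx[end]`, so the slope in the first snippet depends only on the two endpoint values of `dx`. A minimal sketch of an alternative follows, assuming the intended operation is the orthogonal (least-squares) projection onto vectors that are linear in the index. The name `project_onto_line` is hypothetical, and this has not been checked against the linked OddArrays.jl code.

```julia
using Statistics: mean

# Hypothetical helper, not part of ChainRulesCore: least-squares projection of
# `dx` onto vectors of the form a .+ b .* t. Assumes length(dx) >= 2.
function project_onto_line(dx::AbstractVector{<:Real})
    L = length(dx)
    μ = mean(dx)
    t = (1:L) .- (L + 1) / 2          # centred positions, so sum(t) == 0
    β = sum(t .* dx) / sum(abs2, t)   # least-squares slope per index step
    δ = β * (L - 1) / 2               # offset of each endpoint from the mean
    return LinRange(μ - δ, μ + δ, L)
end

dx = [1.0, 0.0, 0.0, 0.0, 0.0]
p = project_onto_line(dx)             # endpoints ≈ 0.6 and -0.2
```

For the same `dx`, the first snippet gives endpoints 0.7 and -0.3, which is not the least-squares fit.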

The motivation is cases like these:

julia> gradient(x -> (2 .* x)[1], 0:0.2:1)  # natural
([2.0, 0.0, 0.0, 0.0, 0.0, 0.0],)

julia> gradient(first, LinRange(0,1,5))  # structural
((start = 1.0, stop = nothing, len = nothing, lendiv = nothing),)

julia> gradient(x -> first(2 .* x), LinRange(0,1,5))
ERROR: DimensionMismatch("x and y are of different lengths!")
Stacktrace:
  [1] dot(x::Tangent{Any, NamedTuple{(:start, :stop, :len, :lendiv), Tuple{Float64, ZeroTangent, ZeroTangent, ZeroTangent}}}, y::LinRange{Float64, Int64})

julia> gradient(x -> first(LinRange(x,1,5)), 0)
(1.0,)

julia> gradient(x -> (2 .* LinRange(x,1,5))[1], 0)
ERROR: Need an adjoint for constructor LinRange{Float64, Int64}. Gradient is of type Vector{Float64}
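
The last error asks for a rule that turns a `Vector` gradient back into gradients for the constructor arguments. As a sketch only, assume the elements are the usual linear interpolation between the endpoints, with L ≥ 2. The symbols a (start), b (stop) and g (the vector gradient) are notation introduced here, not taken from the thread. The chain rule then gives:

```latex
x_i = a + \frac{i-1}{L-1}\,(b - a), \qquad i = 1, \dots, L

\bar{a} = \sum_{i=1}^{L} g_i \left(1 - \frac{i-1}{L-1}\right), \qquad
\bar{b} = \sum_{i=1}^{L} g_i \, \frac{i-1}{L-1}
```

For g = (2, 0, 0, 0, 0), as in the first example, this gives ā = 2 and b̄ = 0, which is the derivative of twice the first element with respect to the start.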