jw3126 / Setfield.jl

Update deeply nested immutable structs.
Other
167 stars 17 forks source link

Type instability in `setindex` for lenses with slices #175

Closed phipsgabler closed 2 years ago

phipsgabler commented 2 years ago
julia> μ = rand(2,2)
2×2 Matrix{Float64}:
 0.943631   0.45465
 0.0845643  0.781602

julia> Setfield.set(μ, @lens(_[1, :]), [0.0, 0.0])
2×2 Matrix{Any}:
 0.0        0.0
 0.0845643  0.781602
phipsgabler commented 2 years ago

Shouldn't that be `eltype(v)` rather than `typeof(v)` in `setindex`?

julia> @code_warntype Setfield.setindex(μ, [0.0, 0.0], 1, 1:2)
MethodInstance for Setfield.setindex(::Matrix{Float64}, ::Vector{Float64}, ::Int64, ::UnitRange{Int64})
  from setindex(xs::AbstractArray, v, I...) in Setfield at /home/pgabler/.julia/dev/Setfield/src/setindex.jl:5
Arguments
  #self#::Core.Const(Setfield.setindex)
  xs::Matrix{Float64}
  v::Vector{Float64}
  I::Tuple{Int64, UnitRange{Int64}}
Locals
  ys::Matrix{Any}
  T::Type{Any}
Body::Matrix{Any}
1 ─       nothing
│   %2  = Setfield.eltype(xs)::Core.Const(Float64)
│   %3  = Setfield.typeof(v)::Core.Const(Vector{Float64})
│         (T = Setfield.promote_type(%2, %3))
│         (ys = Setfield.similar(xs, T::Core.Const(Any)))
│   %6  = Setfield.eltype(xs)::Core.Const(Float64)
│   %7  = Core.apply_type(Setfield.Union)::Core.Const(Union{})
│   %8  = (%6 !== %7)::Core.Const(true)
└──       goto #3 if not %8
2 ─       Setfield.copy!(ys, xs)
3 ┄ %11 = I::Tuple{Int64, UnitRange{Int64}}
│   %12 = Core.tuple(ys, v)::Tuple{Matrix{Any}, Vector{Float64}}
│         Core._apply_iterate(Base.iterate, Base.setindex!, %12, %11)
└──       return ys
jw3126 commented 2 years ago

Thanks, you are right — `eltype` is the better choice! With `setindex` I usually have scalar assignments in mind, so there might be more gotchas like this. Anyway, a PR would be appreciated!

phipsgabler commented 2 years ago

Thinking about it again, I think `eltype` is the wrong choice, too: it would break ordinary single-element assignment of non-numbers that currently works, like

julia> Setfield.set([Some(1)], @lens(_[begin]), Some(2))
1-element Vector{Some{Int64}}:
 2

julia> eltype(Some(2))
Any

(By contrast, `eltype(Int)` is `Int`, which is why assigning scalar numbers happens to work with `eltype`!)

What we should really do instead is follow the logic by which indices are normalized for indexing, and use it to determine whether the right-hand side is a scalar assigned to a single element or an iterable assigned to a slice.

phipsgabler commented 2 years ago

What about this:

julia> @generated function select_type(v_type::Type{Tv}, v_eltype::Type{Tve}, I::Tuple) where {Tv, Tve}
           # Choose which type the assigned value contributes to the result eltype:
           # Tv (meant to be typeof(v)) or Tve (meant to be eltype(v)).
           # Inside a generator the arguments are bound to their *types*, so
           # `I.parameters` is the list of index types (the caller passes the
           # output of `Base.to_indices`).
           # Any array-typed index means a slice is being assigned, so the value's
           # elements are what get stored; otherwise the value is stored whole.
           if any(ti <: AbstractArray for ti in I.parameters)
               return Tve
           else
               return Tv
           end
       end
select_type (generic function with 2 methods)

julia> function setindex(xs::AbstractArray, v, I...)
           # Non-mutating counterpart of `setindex!`: builds a new array `ys` from `xs`,
           # writes `v` at `I` into `ys`, and returns it; `xs` itself is never written.
           # The result eltype is widened only by what is actually stored: `typeof(v)`
           # for a single-element assignment, `eltype(v)` when `I` selects a slice.
           T = promote_type(eltype(xs), select_type(typeof(v), eltype(v), Base.to_indices(xs, I)))
           ys = similar(xs, T)
           # Copy skipped when eltype(xs) is Union{} -- presumably because such an
           # array cannot hold any values to copy (TODO confirm intent).
           if eltype(xs) !== Union{}
               copy!(ys, xs)
           end
           ys[I...] = v
           return ys
       end
setindex (generic function with 1 method)

julia> @code_warntype setindex([rand(2) for i in 1:2, j in 1:2], zeros(2), 1, 1)
MethodInstance for setindex(::Matrix{Vector{Float64}}, ::Vector{Float64}, ::Int64, ::Int64)
  from setindex(xs::AbstractArray, v, I...) in Main at REPL[32]:1
Arguments
  #self#::Core.Const(setindex)
  xs::Matrix{Vector{Float64}}
  v::Vector{Float64}
  I::Tuple{Int64, Int64}
Locals
  ys::Matrix{Vector{Float64}}
  T::Type{Vector{Float64}}
Body::Matrix{Vector{Float64}}
1 ─ %1  = Main.eltype(xs)::Core.Const(Vector{Float64})
│   %2  = Main.typeof(v)::Core.Const(Vector{Float64})
│   %3  = Main.eltype(v)::Core.Const(Float64)
│   %4  = Base.to_indices::Core.Const(to_indices)
│   %5  = (%4)(xs, I)::Tuple{Int64, Int64}
│   %6  = Main.select_type(%2, %3, %5)::Core.Const(Vector{Float64})
│         (T = Main.promote_type(%1, %6))
│         (ys = Main.similar(xs, T::Core.Const(Vector{Float64})))
│   %9  = Main.eltype(xs)::Core.Const(Vector{Float64})
│   %10 = Core.apply_type(Main.Union)::Core.Const(Union{})
│   %11 = (%9 !== %10)::Core.Const(true)
└──       goto #3 if not %11
2 ─       Main.copy!(ys, xs)
3 ┄ %14 = I::Tuple{Int64, Int64}
│   %15 = Core.tuple(ys, v)::Tuple{Matrix{Vector{Float64}}, Vector{Float64}}
│         Core._apply_iterate(Base.iterate, Base.setindex!, %15, %14)
└──       return ys

julia> @code_warntype setindex(rand(2,2), zeros(2), 1, :)
MethodInstance for setindex(::Matrix{Float64}, ::Vector{Float64}, ::Int64, ::Colon)
  from setindex(xs::AbstractArray, v, I...) in Main at REPL[32]:1
Arguments
  #self#::Core.Const(setindex)
  xs::Matrix{Float64}
  v::Vector{Float64}
  I::Tuple{Int64, Colon}
Locals
  ys::Matrix{Float64}
  T::Type{Float64}
Body::Matrix{Float64}
1 ─ %1  = Main.eltype(xs)::Core.Const(Float64)
│   %2  = Main.typeof(v)::Core.Const(Vector{Float64})
│   %3  = Main.eltype(v)::Core.Const(Float64)
│   %4  = Base.to_indices::Core.Const(to_indices)
│   %5  = (%4)(xs, I)::Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}
│   %6  = Main.select_type(%2, %3, %5)::Core.Const(Float64)
│         (T = Main.promote_type(%1, %6))
│         (ys = Main.similar(xs, T::Core.Const(Float64)))
│   %9  = Main.eltype(xs)::Core.Const(Float64)
│   %10 = Core.apply_type(Main.Union)::Core.Const(Union{})
│   %11 = (%9 !== %10)::Core.Const(true)
└──       goto #3 if not %11
2 ─       Main.copy!(ys, xs)
3 ┄ %14 = I::Tuple{Int64, Colon}
│   %15 = Core.tuple(ys, v)::Tuple{Matrix{Float64}, Vector{Float64}}
│         Core._apply_iterate(Base.iterate, Base.setindex!, %15, %14)
└──       return ys
jw3126 commented 2 years ago

Yeah, that looks good. This is inspired by the `Base.setindex!` dispatch pipeline, right? I think the `@generated` could be dropped in favor of lispy recursion.

phipsgabler commented 2 years ago

Not really; I just figured that the simplest way to tell whether we are assigning a slice is to inspect the "normalized" indices returned by `to_indices`, as above.

We can do a recursion, yes:

# Index kinds expected after `Base.to_indices` normalization (scalar integers or arrays).
# Kept as a named alias; the `select_type` methods below deliberately do not restrict
# their arguments to it, so that any non-array index is treated as a scalar.
const IndexTail = Vararg{Union{Integer, AbstractArray}}

"""
    select_type(v_type::Type{Tv}, v_eltype::Type{Tve}, I::Tuple)

Return the type that an assigned value `v` contributes to the result's element type
when it is written at the index tuple `I`. Here `Tv` is meant to be `typeof(v)`,
`Tve` is meant to be `eltype(v)`, and `I` is meant to be the output of
`Base.to_indices(xs, I)`.

- If any index in `I` is an `AbstractArray`, the assignment targets a slice. `v` is
  then a collection whose elements get stored, so `Tve` is returned.
- Otherwise every index is a scalar and `v` is stored as one element, so `Tv` is
  returned.

Every non-array index (e.g. an `Integer` or a `CartesianIndex`) counts as a scalar.
"""
select_type(v_type::Type{Tv}, v_eltype::Type{Tve}, I::Tuple{}) where {Tv, Tve} = Tv

# An array index anywhere means slice assignment: only the elements of `v` are stored.
select_type(v_type::Type{Tv}, v_eltype::Type{Tve}, I::Tuple{AbstractArray, Vararg{Any}}) where {Tv, Tve} =
    Tve

# Scalar leading index (anything that is not an array): drop it and keep scanning.
select_type(v_type::Type{Tv}, v_eltype::Type{Tve}, I::Tuple) where {Tv, Tve} =
    select_type(Tv, Tve, Base.tail(I))

Although in this case I think the semantics are clearer when written as a generated function.

I'm going to ask for some additional feedback on Zulip before making a PR.

jw3126 commented 2 years ago

This is solved, @phipsgabler, right? Otherwise, please reopen.