SciML / LabelledArrays.jl

Arrays which also have a label for each element for easy scientific machine learning (SciML)
https://docs.sciml.ai/LabelledArrays/stable/

Rewrite `SLVector` as a subtype of `StaticVector` #19

Closed MSeeker1340 closed 5 years ago

MSeeker1340 commented 6 years ago

https://github.com/JuliaDiffEq/LabelledArrays.jl/pull/18#issuecomment-436278719

For the @SLVector macro, I modified @YingboMa's implementation so that it returns the type/constructor itself rather than an anonymous constructor-like function. This makes the code much simpler (e.g. see the new slvectors.jl tests).

similar(x) and AbstractFloat{T}(x) still return an unwrapped MArray. This is the behavior intended by StaticArrays (https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/abstractarray.jl#L94). I can probably improve on it by defining an MLVector type, but for now this isn't really an issue.

ChrisRackauckas commented 6 years ago

Looks good to me. Get a PR ready for PuMaS with these changes to make sure it does what we need. This should get merged and tagged with a minor release.

MSeeker1340 commented 5 years ago

@ChrisRackauckas Need some help with metaprogramming.

In the PuMaS update, I tried to use the new @SLVector as

# odevars == (:Depot, :Central)
uType = @SLVector Float64 :($(odevars...,))

However this is what I got:

julia> odevars = (:Depot, :Central);

julia> @SLVector Float64 (:Depot, :Central)
SLVector{2,Float64,(:Depot, :Central)}

julia> @SLVector Float64 :($(odevars...,))
SLVector{1,Float64,(:Depot, :Central)}

The macro definition is

macro SLVector(E,syms)
    quote
        SLVector{$(length(syms.args)),$(esc(E)),$syms}
    end
end

I can of course just use plain constructors instead of @SLVector, but I'm curious as to why I got this behavior.
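
A likely explanation, sketched with a hypothetical `@argshape` macro that is not part of LabelledArrays: a macro receives unevaluated syntax, so `length(syms.args)` counts the pieces of the expression as written, not the elements of the tuple it later evaluates to. The literal `(:Depot, :Central)` arrives as a `:tuple` expression with two arguments, while `:($(odevars...,))` arrives as a `:quote` expression with a single argument, which would account for `N == 1`. The stub type `SLVectorStub` is also an assumption, used only to show the plain type application mentioned above.

```julia
# Hypothetical helper macro: report the head and argument count of the
# syntax that a macro actually receives (it never sees runtime values).
macro argshape(ex)
    return :(($(QuoteNode(ex.head)), $(length(ex.args))))
end

odevars = (:Depot, :Central)

# A literal tuple is an Expr with head :tuple and one argument per label.
literal_shape = @argshape (:Depot, :Central)

# The quoted splat is a single :quote node wrapping the interpolation,
# so counting its args does not count the labels.
quoted_shape = @argshape :($(odevars...,))

# Stand-in for SLVector (assumption: only the type parameters matter here).
struct SLVectorStub{N,T,Syms} end

# Plain type application works on runtime values, so no macro is needed.
uType = SLVectorStub{length(odevars),Float64,odevars}
```

Computing `N` from the value, as in the last line, sidesteps the problem whenever the labels are only known at runtime.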

ChrisRackauckas commented 5 years ago

Yeah I'm not sure why that happens, but I noticed it before...

ChrisRackauckas commented 5 years ago

Before tagging, I want to see if we can get this working on arrays and not just vectors via whatever the new ind2sub is.

YingboMa commented 5 years ago

julia> struct SLVector{A,B,C} end

julia> macro SLVector(E,syms)
           n = syms isa Expr ? length(syms.args) : length(syms)
           quote
               SLVector{$n,$(esc(E)),$(esc(syms))}
           end
       end
@SLVector (macro with 1 method)

julia> odevars = (:Depot, :Central);

julia> @SLVector Float64 (:Depot, :Central)
SLVector{2,Float64,(:Depot, :Central)}

julia> @eval @SLVector Float64 $(odevars...,)
SLVector{2,Float64,(:Depot, :Central)}

YingboMa commented 5 years ago

Maybe we shouldn't use Val for indexing at all. It is not a fast way to do it.

julia> using BenchmarkTools, LabelledArrays

julia> ABC = @SLVector Int (:a,:b,:c)
SLVector{3,Int64,(:a, :b, :c)}

julia> b = ABC(1,2,3)
3-element SLVector{3,Int64,(:a, :b, :c)}:
 1
 2
 3

julia> @btime b[i] setup=(i = (:a, :b, :c)[rand(1:3)])
  3.138 μs (1 allocation: 32 bytes)
1

julia> function Base.getindex(x::SLVector,s::Symbol)
         idx = findfirst(isequal(s), LabelledArrays.symnames(typeof(x)))
         getfield(x, :__x)[idx]
       end

julia> @btime b[i] setup=(i = (:a, :b, :c)[rand(1:3)])
  14.571 ns (0 allocations: 0 bytes)
2

julia> @btime b[i] setup=(i = rand(1:3))
  15.217 ns (0 allocations: 0 bytes)
2

The naive implementation is fast.
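
To make the comparison concrete, here is a minimal self-contained sketch of the two lookup strategies. `Labelled`, `lookup_naive`, and `lookup_val` are hypothetical names, and the struct is a simplified stand-in rather than the actual `SLVector`; the real package may differ in detail.

```julia
# Hypothetical stand-in for SLVector: the labels live in a type parameter.
struct Labelled{Syms,T,N}
    data::NTuple{N,T}
end

labels(::Labelled{Syms}) where {Syms} = Syms

# Naive lookup: a linear search over the label tuple at call time.
function lookup_naive(x::Labelled, s::Symbol)
    idx = findfirst(isequal(s), labels(x))
    idx === nothing && throw(KeyError(s))
    return x.data[idx]
end

# Val-based lookup: the search runs once per (type, label) pair while the
# method body is generated, leaving a constant index in the compiled code.
@generated function lookup_val(x::Labelled{Syms}, ::Val{s}) where {Syms,s}
    idx = findfirst(isequal(s), Syms)
    idx === nothing && return :(throw(KeyError($(QuoteNode(s)))))
    return :(x.data[$idx])
end

# With a runtime Symbol, Val(s) is not known to the compiler, so this call
# falls back to dynamic dispatch on Val{s}.
lookup_val(x::Labelled, s::Symbol) = lookup_val(x, Val(s))

b = Labelled{(:a, :b, :c),Int,3}((1, 2, 3))
```

Both `lookup_naive(b, :b)` and `lookup_val(b, :b)` return the same element; they differ only in when the label search happens.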

ChrisRackauckas commented 5 years ago

What's the generated code like? When I tried something like that, constant propagation didn't work through findfirst. The generated function makes sure the lookup compiles away.

YingboMa commented 5 years ago

julia> @btime b[i] setup=(i = (:a, :b, :c)[rand(1:3)])
  2.828 μs (1 allocation: 32 bytes)
1

julia> @btime b[:a]
  2.985 μs (1 allocation: 32 bytes)
1

julia> @btime b[Val(:a)]
  13.985 ns (0 allocations: 0 bytes)
1

ChrisRackauckas commented 5 years ago

You might be timing something odd in the global scope there? Interpolate it in?

This is a good case for checking the generated code though. It's either running findfirst or it's just using the scalar index at runtime. If it's not just compiling down to a scalar indexing, that would be an issue for larger SArray operations.

YingboMa commented 5 years ago

I don't need to interpolate when I do @btime b[i] setup=(i = (:a, :b, :c)[rand(1:3)]). With "Val indexing", I got

julia> @code_typed b[:a]
CodeInfo(
22 1 ─ %1 = invoke LabelledArrays.Val(_3::Symbol)::Val{_1} where _1                                    │
   │   %2 = (Base.getindex)(x, %1)::Any                                                                │
   └──      return %2                                                                                  │
) => Any

julia> goo(b) = b[:a]
goo (generic function with 1 method)

julia> @code_typed goo(b)
CodeInfo(
1 1 ─ %1 = (LabelledArrays.getfield)(b, :__x)::SArray{Tuple{3},Int64,1,3}        │╻╷╷╷ getindex
  │   %2 = (Base.getfield)(%1, :data)::Tuple{Int64,Int64,Int64}                  ││╻    getindex
  │   %3 = (Base.getfield)(%2, 1, true)::Int64                                   │││╻    macro expansion
  └──      return %3                                                             │
) => Int64

With the naive implementation, I got

julia> @code_typed b[:a]
CodeInfo(
2 1 ── %1  = (Base.getfield)((:a, :b, :c), 1, true)::Symbol                                                                                │╻╷╷  findfirst
  └───       goto #12 if not true                                                                                                          ││
  2 ┄─ %3  = φ (#1 => 1, #11 => %22)::Int64                                                                                                ││
  │    %4  = φ (#1 => %1, #11 => %23)::Symbol                                                                                              ││
  │    %5  = φ (#1 => 1, #11 => %24)::Int64                                                                                                ││
  │    %6  = (%4 === s)::Bool                                                                                                              ││╻╷╷  Fix2
  └───       goto #4 if not %6                                                                                                             ││
  3 ──       goto #13                                                                                                                      ││
  4 ── %9  = (%5 === 3)::Bool                                                                                                              │││╻╷   iterate
  └───       goto #6 if not %9                                                                                                             ││││
  5 ──       goto #7                                                                                                                       ││││
  6 ── %12 = (Base.add_int)(%5, 1)::Int64                                                                                                  ││││╻    +
  └───       goto #7                                                                                                                       │││╻    iterate
  7 ┄─ %14 = φ (#5 => true, #6 => false)::Bool                                                                                             │││
  │    %15 = φ (#6 => %12)::Int64                                                                                                          │││
  │    %16 = φ (#6 => %12)::Int64                                                                                                          │││
  │    %17 = φ (#5 => true)::Bool                                                                                                          │││
  └───       goto #9 if not %14                                                                                                            │││
  8 ──       goto #10                                                                                                                      │││
  9 ── %20 = (Base.getfield)((:a, :b, :c), %15, true)::Symbol                                                                              │││╻    getindex
  └───       goto #10                                                                                                                      ││╻    iterate
  10 ┄ %22 = φ (#9 => %15)::Int64                                                                                                          ││
  │    %23 = φ (#9 => %20)::Symbol                                                                                                         ││
  │    %24 = φ (#9 => %16)::Int64                                                                                                          ││
  │    %25 = φ (#8 => %17, #9 => false)::Bool                                                                                              ││
  │    %26 = (Base.not_int)(%25)::Bool                                                                                                     ││
  └───       goto #12 if not %26                                                                                                           ││
  11 ─       goto #2                                                                                                                       ││
  12 ┄ %29 = Base.nothing::Const(nothing, false)                                                                                           ││
  └───       goto #13                                                                                                                      ││
  13 ┄ %31 = φ (#3 => %3, #12 => %29)::Union{Nothing, Int64}                                                                               │
3 │    %32 = (Main.getfield)(x, :__x)::SArray{Tuple{3},Int64,1,3}                                                                          │
  │    %33 = (isa)(%31, Int64)::Bool                                                                                                       │
  └───       goto #15 if not %33                                                                                                           │
  14 ─ %35 = π (%31, Int64)                                                                                                                │
  │    %36 = (Base.getfield)(%32, :data)::Tuple{Int64,Int64,Int64}                                                                         ││╻    getproperty
  │    %37 = (Base.getfield)(%36, %35, true)::Int64                                                                                        ││╻    getindex
  └───       goto #18                                                                                                                      │
  15 ─ %39 = (isa)(%31, Nothing)::Bool                                                                                                     │
  └───       goto #17 if not %39                                                                                                           │
  16 ─ %41 = π (%31, Nothing)                                                                                                              │
  │          invoke Base.to_index(%32::SArray{Tuple{3},Int64,1,3}, %41::Nothing)::Union{}                                                  ││╻╷   to_indices
  │          $(Expr(:unreachable))::Union{}                                                                                                │││┃    to_indices
  │          φ ()::Union{}                                                                                                                 │││
  │          $(Expr(:unreachable))::Union{}                                                                                                │││
  │          φ ()::Union{}                                                                                                                 ││
  │          $(Expr(:unreachable))::Union{}                                                                                                ││
  └───       $(Expr(:unreachable))::Union{}                                                                                                │
  17 ┄       (Core.throw)(ErrorException("fatal error in type inference (type bound)"))::Union{}                                           │
  └───       $(Expr(:unreachable))::Union{}                                                                                                │
  18 ┄       return %37                                                                                                                    │
) => Int64

julia> goo(b) = b[:a]
goo (generic function with 1 method)

julia> @code_typed goo(b)
CodeInfo(
1 1 ─ %1 = (Main.getfield)(b, :__x)::SArray{Tuple{3},Int64,1,3}                                                                               │╻   getindex
  │   %2 = (Base.getfield)(%1, :data)::Tuple{Int64,Int64,Int64}                                                                               ││╻   getindex
  │   %3 = (Base.getfield)(%2, 1, true)::Int64                                                                                                │││╻   getindex
  └──      return %3                                                                                                                          │
) => Int64

ChrisRackauckas commented 5 years ago

So the naive implementation still does constant prop; it just doesn't overdo the compilation when used from the global scope?

YingboMa commented 5 years ago

The naive implementation still compiles quite well in the global scope, but with Val, if the compiler cannot do constant prop, performance degrades badly.
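
One way to check this without reading full `@code_typed` dumps is to compare inferred return types for a constant label and a runtime `Symbol`. The sketch below uses `Base.return_types` on a hypothetical stand-in type (`Labelled`, `byval`, and `goo_val` are illustrative names, not the LabelledArrays API). In the output earlier in this thread, the Val path called with a non-constant symbol inferred as `Any`; results may differ across Julia versions, so no expected output is shown.

```julia
# Hypothetical stand-in for SLVector (not the LabelledArrays type).
struct Labelled{Syms,T,N}
    data::NTuple{N,T}
end

# "Val indexing": the index is resolved while the method body is generated.
@generated function byval(x::Labelled{Syms}, ::Val{s}) where {Syms,s}
    idx = findfirst(isequal(s), Syms)
    idx === nothing && return :(throw(KeyError($(QuoteNode(s)))))
    return :(x.data[$idx])
end
byval(x::Labelled, s::Symbol) = byval(x, Val(s))

b = Labelled{(:a, :b, :c),Int,3}((1, 2, 3))

# Literal label inside a function: the compiler can see the constant `:a`.
goo_val(x) = byval(x, :a)

# Inferred return types with a constant label vs. a runtime Symbol.
inferred_const = Base.return_types(goo_val, (typeof(b),))
inferred_runtime = Base.return_types(byval, (typeof(b), Symbol))

println("constant label: ", inferred_const)
println("runtime label:  ", inferred_runtime)
```

If the second result is abstract, any caller that passes a label it computed at runtime pays for dynamic dispatch, which is the slowdown described above.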

ChrisRackauckas commented 5 years ago

Pick this up post https://github.com/JuliaDiffEq/LabelledArrays.jl/pull/20