JuliaArrays / StaticArrays.jl

Statically sized arrays for Julia

Inconsistent output types of SMatrix solve with Vector/Matrix rhs #796

Open thchr opened 4 years ago

thchr commented 4 years ago

Consider the following example:

using StaticArrays

P = @SMatrix [-0.5 0.5 0.5; 0.5 -0.5 0.5; 0.5 0.5 -0.5] # SMatrix{3,3,Float64,9}
w = rand(Float64, 3)    # Vector{Float64}
W = rand(Float64, 3, 3) # Matrix{Float64}

println(typeof(P\w)) # => SArray{Tuple{3},Float64,1,3}
println(typeof(P\W)) # => Array{Float64,2}

i.e. doing a matrix solve with \(::SMatrix, ::Matrix) returns an ordinary matrix but a matrix solve with \(::SMatrix, ::Vector) returns a static vector. This seems surprising to me: I had expected both calls to return non-static/ordinary vectors/matrices (or at least the same "base" array type).

Here, \ calls into the generic (\)(A::AbstractMatrix, B::AbstractVecOrMat) from LinearAlgebra (at https://github.com/JuliaLang/julia/blob/cf410dc9e81cfa14a18ca3aef50346000b4615c7/stdlib/LinearAlgebra/src/generic.jl#L1102) which in turn calls into StaticArrays' \ method for LU decompositions at https://github.com/JuliaArrays/StaticArrays.jl/blob/64c64b2d808eed0b5751d080205ae2efe218f57e/src/lu.jl#L186-L187

So the issue, it seems to me, is that v[F.p] in the linked method converts v to an SVector (because indexing with the static permutation F.p yields a static result), whereas B[F.p,:] remains an ordinary matrix because of the combination with :.
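
The indexing behavior described here can be reproduced without going through lu at all. The following is a minimal sketch, assuming that F.p behaves like a static vector of integers; the stand-in `p` and the names `w`, `W`, and `long` are illustrative only, and the printed types may differ between StaticArrays versions.

```julia
using StaticArrays

# Hypothetical stand-in for the permutation `F.p` of a 3×3 static LU factorization.
p = SVector(2, 3, 1)

w = rand(3)     # Vector{Float64}
W = rand(3, 3)  # Matrix{Float64}

wp = w[p]     # static index into a Vector: the result length comes from the type of `p`
Wp = W[p, :]  # `:` spans a dimension whose length is only known at runtime

println(typeof(wp))
println(typeof(Wp))

# The same indexing picks three entries out of an over-long vector without any error,
# which is how a size mismatch can go unnoticed.
long = rand(500)
println(length(long[p]))
```

Because the index alone determines the length of `w[p]`, nothing in that expression checks that `w` has exactly three entries.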

c42f commented 4 years ago

This is intentional: the static size of the second dimension of P\W cannot be inferred from the types of P and W. In principle the result has a size like Size(3,Dynamic()), a partially dynamic size that cannot be represented in a StaticMatrix subtype. (For that you'd need a type from @mateuszbaran's https://github.com/mateuszbaran/HybridArrays.jl or something similar.)

In contrast, the static size of P\w is always Size(3) (if not, an exception will occur).

Of course it would be possible to fall back to constructing the output Size at runtime, but this would make the return type unstable, and performance would probably suffer badly.

This is a fairly fundamental problem which I'm not sure we can do much about (except perhaps document better).
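
One way to see the distinction is to compare what is known from the types alone. The sketch below assumes a reasonably recent StaticArrays release in which solving against a static matrix right-hand side is supported; it converts W to an SMatrix explicitly so that both dimensions are part of the type. The names `Ws` and `X` are illustrative.

```julia
using StaticArrays

P = @SMatrix [-0.5 0.5 0.5; 0.5 -0.5 0.5; 0.5 0.5 -0.5]
W = rand(3, 3)

# The size of P is part of its type, so it is available at compile time.
println(Size(typeof(P)))

# The size of W is a runtime value only; its type does not record it.
println(typeof(W), " has runtime size ", size(W))

# Supplying the size explicitly moves that information into the type of the right-hand side.
Ws = SMatrix{3,3}(W)
X = P \ Ws
println(typeof(X))
```

This only helps when the caller already knows the size of the right-hand side; if the size has to be read from W at runtime, the conversion itself becomes the type-unstable step.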

thchr commented 4 years ago

I see, thanks for explaining this to me!

I see now that the same holds for e.g. multiplication of an SMatrix with a Vector/Matrix. I had naively thought the rule would be to "promote down" to the type with less structural information, i.e. to Vector/Matrix.

Coincidentally, after reading your explanation, I realized one can do this:

using StaticArrays, LinearAlgebra

P = @SMatrix [-0.5 0.5 0.5; 0.5 -0.5 0.5; 0.5 0.5 -0.5]
F = lu(P)

F\rand(500)      # => 3-element SArray{Tuple{3},Float64,1,3}
F\rand(500, 500) # => 3×500 Array{Float64,2}

which probably ought to throw a DimensionMismatch error?

c42f commented 4 years ago
> F\rand(500)      # => 3-element SArray{Tuple{3},Float64,1,3}

Yikes, that's certainly a bug, thanks for noting that. Seems likely that indexing with F.p hides this bug at https://github.com/JuliaArrays/StaticArrays.jl/blob/64c64b2d808eed0b5751d080205ae2efe218f57e/src/lu.jl#L186
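
Until the library itself rejects such inputs, a caller can guard against the silent truncation. The following is a sketch only: `checked_solve` is a hypothetical helper, not part of StaticArrays, and it assumes that F.p has one entry per row of the factorized matrix.

```julia
using StaticArrays, LinearAlgebra

# Hypothetical helper (not part of StaticArrays): verify the row count before solving.
function checked_solve(F, b::AbstractVecOrMat)
    n = length(F.p)  # assumed: one permutation entry per row of the factorized matrix
    size(b, 1) == n || throw(DimensionMismatch(
        "right-hand side has $(size(b, 1)) rows, but the factorization has $n"))
    return F \ b
end

P = @SMatrix [-0.5 0.5 0.5; 0.5 -0.5 0.5; 0.5 0.5 -0.5]
F = lu(P)

x = checked_solve(F, rand(3))  # row counts agree, so this solves as usual

# A mismatched right-hand side is rejected by the helper before any indexing happens.
try
    checked_solve(F, rand(500))
catch err
    println(typeof(err))
end
```

The check runs before `F \ b`, so it behaves the same whether or not the installed StaticArrays version performs its own size validation.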

thchr commented 4 years ago

For what it's worth, I still find it somewhat surprising that e.g. the product/solve of an SMatrix with an MVector yields an MVector, i.e. mutable container, but the product/solve with an ordinary Vector yields an SVector, i.e. an immutable container.

It's consistent in the sense that all of them yield sized vectors, but the mutability point can be a bit of a gotcha, I think.

I made a small check to see what the behavior is, also including addition; I thought some of the mutual behaviors were rather surprising:

# Multiplication (or division) of mixed types
*(::Matrix,  ::MVector) = ::Vector
*(::Matrix,  ::SVector) = ::Vector
*(::MMatrix, ::Vector)  = ::SVector # !
*(::MMatrix, ::SVector) = ::SVector
*(::SMatrix, ::Vector)  = ::SVector
*(::SMatrix, ::MVector) = ::MVector

# Addition (subtraction) of mixed types
+(::Vector,  ::MVector) = ::SVector # !
+(::MVector, ::Vector)  = ::MVector
+(::MVector, ::SVector) = ::MVector
+(::SVector, ::MVector) = ::SVector
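
A check like the one summarized above can be scripted along the following lines. This is a sketch with illustrative variable names, not the original script; the printed types depend on the installed StaticArrays version, so no expected output is shown.

```julia
using StaticArrays

A, v = rand(3, 3), rand(3)
MA, SA = MMatrix{3,3}(A), SMatrix{3,3}(A)
mv, sv = MVector{3}(v), SVector{3}(v)

# A tuple of label => result pairs keeps the concrete type of every result.
cases = (
    "Matrix  * MVector" => A * mv,
    "Matrix  * SVector" => A * sv,
    "MMatrix * Vector"  => MA * v,
    "MMatrix * SVector" => MA * sv,
    "SMatrix * Vector"  => SA * v,
    "SMatrix * MVector" => SA * mv,
    "Vector  + MVector" => v + mv,
    "MVector + Vector"  => mv + v,
    "MVector + SVector" => mv + sv,
    "SVector + MVector" => sv + mv,
)

for (label, result) in cases
    println(label, " => ", typeof(result))
end
```

Extending the tuple with `\` or `-` entries would cover the division and subtraction cases mentioned in the headings.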