haskell-numerics / hmatrix

Linear algebra and numerical computation

Numeric.LinearAlgebra.Static scalar vector multiplication type error #294

Closed: ndemeyembosco closed this issue 5 years ago

ndemeyembosco commented 5 years ago

In Numeric.LinearAlgebra.Static (hmatrix version 0.19.0.0), the following rightly typechecks in GHCi: 3.0 * (vec4 1 2 3 4), producing (vector [3.0,6.0,9.0,12.0] :: R 4).

However, the following does not:

let a = (vec2 1 2) <.> (vec2 3 4) in 
let b = (vec4 2 3 4 5) <.> (vec4 6 7 8 2) in 
let c = a / b in 
(vec4 1 2 3 4) + (c * (vec4 3 4 5 6))

The error is as follows:

Couldn't match type 'Numeric.LinearAlgebra.Static.R 4' with 'Double'
Expected type: ℝ
Actual type: Numeric.LinearAlgebra.Static.R 4
In the second argument of '(*)', namely 'vec4 3 4 5 6'
In the expression: (vec4 1 2 3 4) + (c * (vec4 3 4 5 6))

Am I missing something, or is this indeed a bug in the typechecker?

mstksg commented 5 years ago

Remember that the type of * is:

(*) :: Num a => a -> a -> a

So * is not a 'scalar multiplication operator'; for vectors, it is implemented as component-wise multiplication:

<1, 2, 3> * <4, 5, 6> = <4, 10, 18>
<a, b, c> * <x, y, z> = <a*x, b*y, c*z>
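The component-wise behavior can be checked directly. This is a minimal sketch assuming hmatrix's Numeric.LinearAlgebra.Static module (the 0.19 series mentioned above); the names componentwise and componentwiseList are illustrative only.

```haskell
{-# LANGUAGE DataKinds #-}

import Numeric.LinearAlgebra.Static
import qualified Numeric.LinearAlgebra as LA

-- (*) comes from the Num instance of R n, so both operands must be R 3,
-- and the result is the element-by-element product.
componentwise :: R 3
componentwise = vec3 1 2 3 * vec3 4 5 6

-- Convert to a plain list for inspection: [4.0,10.0,18.0]
componentwiseList :: [Double]
componentwiseList = LA.toList (extract componentwise)
```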

So when you type

3.0 * vec4 1 2 3 4

remember that * expects both sides to have the same type. We know the type of vec4 1 2 3 4: it's R 4. So 3.0 must have type R 4 as well. Remember that in Haskell, numeric literals are polymorphic. That's why we can write things like 3 * x, and 3 will be interpreted as an Int, Integer, Double, etc., depending on what x is. 3.0 is likewise a polymorphic literal: it is interpreted as whatever Fractional type makes the typechecker happy; it is not always Double, or Float, etc.
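The same literal adapting to its context can be seen with plain Prelude types. This is a minimal sketch; the function names are hypothetical. Per the Haskell Report, an integer literal such as 3 stands for fromInteger 3 and a fractional literal such as 3.0 stands for fromRational 3.0, which is why 3.0 needs a Fractional type.

```haskell
-- The literal 3 is read as (fromInteger 3) at whatever Num type
-- the other operand of (*) has.
tripleInt :: Int -> Int
tripleInt x = 3 * x

tripleDouble :: Double -> Double
tripleDouble x = 3 * x

-- The literal 3.0 is read as (fromRational 3.0), so it needs a
-- Fractional type: this works for Double, but would not compile for Int.
scaleDouble :: Double -> Double
scaleDouble x = 3.0 * x

-- tripleDouble with the desugaring written out explicitly.
tripleDesugared :: Double -> Double
tripleDesugared x = fromInteger 3 * x
```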

So now the question is: what does 3.0 :: R 4 even mean? How should a numeric literal be interpreted as a vector?

Well, one sensible law we would want to hold is:

1 * (x :: R 4) = x

From here, the only reasonable interpretation of 1 :: R 4 would be vec4 1 1 1 1, since:

1 * vec4 i j k l
  = vec4 1 1 1 1 * vec4 i j k l
  = vec4 (1 * i) (1 * j) (1 * k) (1 * l)
  = vec4 i j k l
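The argument above can be modelled with a small stand-in type. This is a hypothetical toy V4, not hmatrix's actual R 4 implementation; it only assumes the same design, where (*) is component-wise and a literal fills every component. With that choice, the law 1 * x = x holds.

```haskell
-- Hypothetical toy 4-vector (not hmatrix's R 4) with a component-wise
-- Num instance, used only to illustrate the reasoning above.
data V4 = V4 Double Double Double Double
  deriving (Show, Eq)

instance Num V4 where
  V4 a b c d + V4 w x y z = V4 (a + w) (b + x) (c + y) (d + z)
  V4 a b c d - V4 w x y z = V4 (a - w) (b - x) (c - y) (d - z)
  V4 a b c d * V4 w x y z = V4 (a * w) (b * x) (c * y) (d * z)
  abs (V4 a b c d) = V4 (abs a) (abs b) (abs c) (abs d)
  signum (V4 a b c d) = V4 (signum a) (signum b) (signum c) (signum d)
  -- An integer literal n becomes n in every component,
  -- so the literal 1 is the identity for (*).
  fromInteger n = let k = fromInteger n in V4 k k k k

instance Fractional V4 where
  V4 a b c d / V4 w x y z = V4 (a / w) (b / x) (c / y) (d / z)
  -- A fractional literal such as 3.0 is also spread over all components.
  fromRational r = let k = fromRational r in V4 k k k k

-- The law 1 * x = x, checked for a given vector.
identityHolds :: V4 -> Bool
identityHolds x = 1 * x == x
```

In GHCi, 3.0 * V4 1 2 3 4 evaluates to V4 3.0 6.0 9.0 12.0, mirroring the hmatrix result in the issue.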

So if we interpret 1 :: R 4 as vec4 1 1 1 1, we should likewise interpret 3 :: R 4 as vec4 3 3 3 3.

So when you type

3.0 * vec4 1 2 3 4

It is interpreted as

vec4 3.0 3.0 3.0 3.0 * vec4 1 2 3 4

Which gives:

vec4 3 6 9 12
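This reading can be checked against hmatrix itself. A minimal sketch, assuming Numeric.LinearAlgebra.Static from hmatrix; the binding names are hypothetical. It compares the literal form with an explicit konst 3, which builds an R 4 whose components are all 3.

```haskell
{-# LANGUAGE DataKinds #-}

import Numeric.LinearAlgebra.Static
import qualified Numeric.LinearAlgebra as LA

-- The literal 3 used at type R 4.
threeVec :: R 4
threeVec = 3

-- The expression from the issue: the literal 3.0 is typed as R 4.
viaLiteral :: R 4
viaLiteral = 3.0 * vec4 1 2 3 4

-- The same product with the constant vector written out via konst.
viaKonst :: R 4
viaKonst = konst 3 * vec4 1 2 3 4

-- Compare R 4 values through their contents as plain lists.
asList :: R 4 -> [Double]
asList = LA.toList . extract
```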

Remember, you are never multiplying a Double by anything here. You're always multiplying the vector vec4 3 3 3 3 by vec4 1 2 3 4.

So, what happens in the second case?

In the second case, c :: Double (since <.> returns a scalar), and it is used with * alongside vec4 3 4 5 6 :: R 4. This is a type error: one side is a Double, and the other side is an R 4.

To make this work, you can use konst :: Double -> R n to turn c into a vector: konst c * vec4 3 4 5 6 is vec4 c c c c * vec4 3 4 5 6, which is vec4 (c * 3) (c * 4) (c * 5) (c * 6).
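Putting the fix together, the original expression typechecks once c is lifted with konst. A minimal sketch, assuming hmatrix's Numeric.LinearAlgebra.Static, where <.> returns ℝ (a synonym for Double); the top-level names mirror the issue's let bindings.

```haskell
{-# LANGUAGE DataKinds #-}

import Numeric.LinearAlgebra.Static
import qualified Numeric.LinearAlgebra as LA

-- (<.>) is a dot product, so a, b and c are plain Doubles, not vectors.
a, b, c :: Double
a = vec2 1 2 <.> vec2 3 4          -- 1*3 + 2*4 = 11
b = vec4 2 3 4 5 <.> vec4 6 7 8 2  -- 12 + 21 + 32 + 10 = 75
c = a / b

-- konst lifts the scalar c to an R 4 with every component equal to c,
-- so (*) is again a component-wise product of two R 4 values.
result :: R 4
result = vec4 1 2 3 4 + konst c * vec4 3 4 5 6

-- The components as a plain list, for inspection.
resultList :: [Double]
resultList = LA.toList (extract result)
```

Here c is 11/75, so the first component of result is 1 + 3 * (11/75), approximately 1.44.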

ndemeyembosco commented 5 years ago

Alright, this makes a lot of sense now. Thanks!