Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Implement matrix multiplication generalized to tensors #69

Open Mikolaj opened 2 years ago

Mikolaj commented 2 years ago

This is a continuation of https://github.com/Mikolaj/horde-ad/issues/64#issuecomment-1250029728.

The idea is to implement

dot :: (ADModeAndNum d r, OS.Shape sh, KnownNat n1, KnownNat n2)
    => ADVal d (OS.Array '[n1, n2] r)
    -> ADVal d (OS.Array (n2 ': sh) r)
    -> ADVal d (OS.Array (n1 ': sh) r)
dot m t = ...

where the extra dimensions in sh behave as in mini-batches, that is, ordinary matrix multiplication is performed for each array contained within the extra dimensions and the results are embedded in the extra dimensions analogously. E.g., if we have one extra dimension of size 3, then matrix multiplication would be performed three times and we'd get a tensor corresponding to a 3-element vector of the resulting matrices. We'd need @awf to confirm this is the generalization (and the name) that makes the most sense.
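
Spelled out (this is my reading, to be confirmed): the result should satisfy result[i, ix] = sum over c of m[i, c] * t[c, ix] for every multi-index ix into sh. The following reference sketch uses plain nested lists rather than the library's shaped arrays and is specialized to sh of length 2; dotLists and matMulLists are illustrative names, not proposed API:

import Data.List (transpose)

-- Ordinary matrix multiplication on nested lists.
matMulLists :: Num r => [[r]] -> [[r]] -> [[r]]
matMulLists a b = [[sum (zipWith (*) row col) | col <- transpose b] | row <- a]

-- dot specialized to sh ~ '[s1, s2]: m has shape [n1, n2], t has shape
-- [n2, s1, s2] and the result has shape [n1, s1, s2]. Row i of the
-- result is the elementwise sum over c of m[i, c] times the s1 x s2
-- slice t[c] (assumes n2 >= 1).
dotLists :: Num r => [[r]] -> [[[r]]] -> [[[r]]]
dotLists m t =
  [ foldr1 (zipWith (zipWith (+)))
      [map (map (c *)) slice | (c, slice) <- zip row t]
  | row <- m ]

For sh of length 1, t is an ordinary matrix and the specialization is matMulLists itself.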

We already have

(<>$) :: (ADModeAndNum d r, KnownNat m, KnownNat n, KnownNat p)
      => ADVal d (OS.Array '[m, n] r)
      -> ADVal d (OS.Array '[n, p] r)
      -> ADVal d (OS.Array '[m, p] r)
(<>$) d e = from2S $ fromS2 d <>! fromS2 e

which should be generalized ((<>$) would most probably be used in the generalization, though the (<>!) operation from hmatrix could also be used directly). Note that for sh ~ '[p], dot specializes to (<>$).

Sadly, an elegant solution doesn't work: enlarge the m argument using broadcast (see https://hackage.haskell.org/package/orthotope/docs/Data-Array-ShapedS.html) and then, after some uses of transpose, perform matrix multiplication in all the inner matrices using rerank2. The problem is that rerank2 requires both its argument tensors to have exactly equal shapes, while the arguments to matrix multiplication need not. Generalizing rerank2 seems hard, given that we don't even have a dual-number counterpart of rerank2 yet (to be done, if at all possible, in #28).

A plausible solution is to turn the t tensor into a list of lists of lists of matrices (the unravel operation), perform the multiplications and roll the lists back into a tensor (the ravel operation). This is going to be tricky to type, because the length of sh is arbitrary, but at least we already have dual-number counterparts of unravel and ravel.
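
For a single batch dimension the plan would look roughly as below. The signatures of unravel and ravel, and a transpose on ADVal taking the permutation as a type-level list (orthotope-style), are assumptions for the sake of the sketch, not the current API:

-- A sketch for sh ~ '[k, p], i.e., one batch dimension of size k.
dotBatched :: (ADModeAndNum d r, KnownNat n1, KnownNat n2, KnownNat k, KnownNat p)
           => ADVal d (OS.Array '[n1, n2] r)
           -> ADVal d (OS.Array '[n2, k, p] r)
           -> ADVal d (OS.Array '[n1, k, p] r)
dotBatched m t =
  let slices = unravel (transpose @'[1, 0, 2] t)
        -- k slices of shape [n2, p]
      products = map (m <>$) slices
        -- k slices of shape [n1, p]
  in transpose @'[1, 0, 2] (ravel products)
        -- ravel gives shape [k, n1, p]; transpose back to [n1, k, p]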

awf commented 2 years ago

The best name imo is matMul. One sometimes sees groupedMatMul, but it's not needed.

Mikolaj commented 2 years ago

A sensible semantics is described in the link from https://github.com/Mikolaj/horde-ad/issues/64#issuecomment-1323655853

This should be doable in some form: either give it four type-level lists, as in the TensorFlow function, or, orthotope-style, work on the outermost dimensions, transposing as needed, which is almost free in orthotope. However, I'm not sure how to handle multiple dimensions to contract and/or batch in the latter API. That might require nested orthotope arrays, which is doable via Data.Array.Shaped, but these are boxed, so this can be slower than specifying lists of dimensions. Transforming between unboxed and nested boxed arrays is a combination of transpose and ravel/unravel, so it's noisier than transpose alone. Another option is to be less general than TensorFlow, with the user manually ravelling/unravelling to recover the full generality.
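
To illustrate the last option on plain lists (illustrative names again): several contraction dimensions can be reduced to a single one by flattening them beforehand, which is exactly what a user-side ravel would accomplish:

import Data.List (transpose)

-- As in the earlier list sketch.
matMulLists :: Num r => [[r]] -> [[r]] -> [[r]]
matMulLists a b = [[sum (zipWith (*) row col) | col <- transpose b] | row <- a]

-- Contract over two dimensions by flattening them into one first:
-- m has shape [n1, c1, c2], t has shape [c1, c2, s] and the result,
-- of shape [n1, s], sums m[i, a, b] * t[a, b, j] over both a and b.
contractTwo :: Num r => [[[r]]] -> [[[r]]] -> [[r]]
contractTwo m t = matMulLists (map concat m) (concat t)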