Mikolaj / horde-ad

Higher Order Reverse Derivatives Efficiently - Automatic Differentiation library based on the paper "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation"
BSD 3-Clause "New" or "Revised" License

Investigate "‘p0’ is untouchable" that derails type reconstruction #87

Open Mikolaj opened 1 year ago

Mikolaj commented 1 year ago

Consider the following code:

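nestedGather :: forall r. ADReady r
             => TensorOf 2 r -> TensorOf 2 r
nestedGather t =
  tgather
          (2 :$ 2 :$ ZS)  -- outer result shape
          (tgather
                   (2 :$ 3 :$ 4 :$ 4 :$ ZS) t  -- inner result shape
                   -- inner index function: 3 indices in, 1 index out
                   (\(k1 :. k2 :. k3 :. ZI) -> k1 + k2 + k3 :. ZI))
          -- outer index function: 2 indices in, 4 indices out
          (\(i1 :. i2 :. ZI) -> i1 :. i2 :. i1 + i2 :. i2 :. ZI)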

All the information needed to reconstruct every type is present here, yet type inference fails, and the root cause is probably the following error:

test/simplified/TestGatherSimplified.hs:35:61: error:
    • Couldn't match type ‘p0’ with ‘1’ arising from a use of ‘:.’
    • ‘p0’ is untouchable
        inside the constraints: n4 ~ 0
        bound by a pattern with pattern synonym:
                   ZI :: forall (n :: GHC.Num.Natural.Natural) i.
                         () =>
                         (n ~ 0) => Index n i,
                 in a lambda abstraction
        at test/simplified/TestGatherSimplified.hs:35:41-42
    • In the expression: k1 + k2 + k3 :. ZI
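To see the mechanism the error message points at in isolation, here is a minimal, self-contained sketch. It does not use horde-ad: `N`, `Vec`, `applyIx`, `toListV`, `One` and `Two` are hypothetical stand-ins, and it uses Peano naturals instead of `GHC.TypeLits` so that no type-checker plugin is needed. `ZI` and `(:.)` are re-declared as pattern synonyms that, like the `ZI` in the error above, provide an equality on the length index when matched. Under OutsideIn(X), GHC generally treats such provided equalities as local assumptions and will not unify type variables from outside the match while they are in scope; that is what "untouchable" refers to. The sketch sidesteps this by fixing both lengths from outside with type applications, in the spirit of the `tgather @r @2` workaround further down.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, PatternSynonyms,
             ScopedTypeVariables, TypeApplications, TypeOperators #-}
-- Hypothetical, simplified model; not horde-ad's Index type.
module Main where

import Data.Kind (Type)

-- Peano naturals, promoted to the type level by DataKinds.
data N = Z | S N

type One = 'S 'Z
type Two = 'S One

-- A length-indexed index, standing in for horde-ad's Index n i.
data Vec (n :: N) (i :: Type) where
  Nil  :: Vec 'Z i
  Cons :: i -> Vec m i -> Vec ('S m) i

-- Like ZI: matching it provides the equality n ~ 'Z.
pattern ZI :: forall n i. () => (n ~ 'Z) => Vec n i
pattern ZI = Nil

-- Like (:.): matching it provides n ~ 'S m for an existential m.
pattern (:.) :: forall n i. () => forall m. (n ~ 'S m) => i -> Vec m i -> Vec n i
pattern x :. xs = Cons x xs
infixr 5 :.

toListV :: Vec n i -> [i]
toListV Nil         = []
toListV (Cons x xs) = x : toListV xs

-- Hypothetical stand-in for tgather's index-function argument:
-- the lambda maps an index of length m to an index of length p.
applyIx :: forall m p i. (Vec m i -> Vec p i) -> Vec m i -> Vec p i
applyIx f = f

-- Both lengths are fixed by type applications before the lambda's
-- patterns are checked, so GHC never has to learn p from inside the
-- scope of the equalities provided by (:.) and ZI.
sumPair :: Vec Two Int -> [Int]
sumPair v =
  toListV (applyIx @Two @One (\(k1 :. k2 :. ZI) -> k1 + k2 :. ZI) v)

main :: IO ()
main = print (sumPair (3 :. 4 :. ZI))
```

Running it prints `[7]`. In the real `nestedGather`, the result length of each index function is otherwise only determined inside such a lambda, which matches where the error is reported.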

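Let's investigate where it comes from, e.g., whether the problem is caused by the KnownNat constraint in our :. pattern synonym. If the code type-checks when using ::: instead (together with a newtype wrapper), that would point to this constraint as the cause. Note that the error message itself names, as the constraint under which p0 becomes untouchable, the equality n4 ~ 0 provided by the ZI pattern synonym.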

Fixing this matters because it would reduce the number of type applications that end users of our library need to add. Here is a version that type-checks thanks to type applications. It's not too bad, but the code is now more complex, and the user has to write all the lists in their fully sized form instead of using the simpler IsList syntax, which sadly forgets all type-level Nat information:

nestedGather :: forall r. ADReady r
             => TensorOf 2 r -> TensorOf 2 r
nestedGather t =
  -- @2 and @3 match the number of indices bound by each lambda
  tgather @r @2
          (2 :$ 2 :$ ZS)
          (tgather @r @3
                   (2 :$ 3 :$ 4 :$ 4 :$ ZS) t
                   (\(k1 :. k2 :. k3 :. ZI) -> k1 + k2 + k3 :. ZI))
          (\(i1 :. i2 :. ZI) -> i1 :. i2 :. i1 + i2 :. i2 :. ZI)
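Why IsList syntax loses the type-level length can be illustrated with a toy model (hypothetical again: `N`, `Vec`, `KnownLen`, `fromListV`, `toListV` are not horde-ad names, and Peano naturals replace `GHC.TypeLits`). With OverloadedLists, a literal such as `[1, 2]` is desugared to `fromListN 2 [1, 2]`, whose type is just some `l` with `IsList l`. The element count is a run-time value, so the type-level length must come from a signature or type application, and a literal of the wrong length is only caught at run time.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, OverloadedLists,
             TypeFamilies #-}
-- Hypothetical, simplified model; not horde-ad's Index or Shape types.
module Main where

import Data.Kind (Type)
import GHC.Exts (IsList (..))

data N = Z | S N

type One = 'S 'Z
type Two = 'S One

data Vec (n :: N) (i :: Type) where
  Nil  :: Vec 'Z i
  Cons :: i -> Vec m i -> Vec ('S m) i

toListV :: Vec n i -> [i]
toListV Nil         = []
toListV (Cons x xs) = x : toListV xs

-- Builds a Vec of the statically requested length n, checking the
-- list's length at run time (Nothing on a mismatch).
class KnownLen (n :: N) where
  fromListV :: [i] -> Maybe (Vec n i)

instance KnownLen 'Z where
  fromListV xs = if null xs then Just Nil else Nothing

instance KnownLen m => KnownLen ('S m) where
  fromListV (x : xs) = Cons x <$> fromListV xs
  fromListV _        = Nothing

instance KnownLen n => IsList (Vec n i) where
  type Item (Vec n i) = i
  fromList xs = case fromListV xs of
    Just v  -> v
    Nothing -> error "fromList: literal has the wrong length"
  toList = toListV

-- The length Two comes from the signature, not from the literal.
ix2 :: Vec Two Int
ix2 = [1, 2]

main :: IO ()
main = do
  print (toListV ix2)
  -- A three-element literal at type Vec Two Int would also type-check;
  -- the mismatch only shows up at run time, here as Nothing.
  print (fmap toListV (fromListV [1, 2, 3 :: Int] :: Maybe (Vec Two Int)))
```

Running it prints `[1,2]` and then `Nothing`. The fully sized `:.`/`:$` forms avoid this because their length is visible in the type, at the cost of the verbosity described above.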