clash-lang / ghc-typelits-extra

Extra type-level operations on GHC.TypeLits.Nat and a custom solver

Derive `Max (a + n) (b + n) = Max a b + n` #35

Open isovector opened 3 years ago

isovector commented 3 years ago

I'm attempting to implement this myself by rewriting Max a b + n as Max (a + n) (b + n), but the existing tests fail to compile when I add a normaliseNat case for the plus constructor. For example, here are some of the errors:

/home/sandy/prj/ghc-typelits-extra/tests/Main.hs:172:12: error:
    • Couldn't match type ‘Max (n + 1) 1’ with ‘1 + n’
      Expected type: Proxy (Max (n + 1) 1) -> Proxy (1 + n)
        Actual type: Proxy (1 + n) -> Proxy (1 + n)
    • In the expression: id
      In an equation for ‘test49’: test49 _ = id
    • Relevant bindings include
        test49 :: Proxy n -> Proxy (Max (n + 1) 1) -> Proxy (1 + n)
          (bound at tests/Main.hs:172:1)
    |       
172 | test49 _ = id
    |            ^^

/home/sandy/prj/ghc-typelits-extra/tests/Main.hs:178:12: error:
    • Couldn't match type ‘Max (n + 2) 1’ with ‘Max (2 + n) 2’
      Expected type: Proxy (Max (n + 2) 1) -> Proxy (Max (2 + n) 2)
        Actual type: Proxy (Max (2 + n) 2) -> Proxy (Max (2 + n) 2)
      NB: ‘Max’ is a non-injective type family
    • In the expression: id
      In an equation for ‘test50’: test50 _ = id
    • Relevant bindings include
        test50 :: Proxy n -> Proxy (Max (n + 2) 1) -> Proxy (Max (2 + n) 2)
          (bound at tests/Main.hs:178:1)
    |       
178 | test50 _ = id
    |    

I guess I've broken the normalisation that natnormalise usually does? Is there an easy way to get it back?

Diff: https://github.com/clash-lang/ghc-typelits-extra/compare/master...isovector:plus-max?expand=1&w=1
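A minimal sketch of the rewrite described above, on a hypothetical toy expression type (`Expr` and `pushAdd` are illustrative names, not the plugin's `ExtraOp` or its `normaliseNat`): an addition with a `Max` operand is pushed into both arguments of the `Max`. The rule is sound for naturals because adding `n` to both arguments is monotone and so never changes which one is larger.

```haskell
-- Hypothetical toy expression type; not the plugin's ExtraOp.
data Expr
  = Lit Integer      -- a literal such as 1
  | Var String       -- a type variable such as n
  | Add Expr Expr    -- a + b
  | Max Expr Expr    -- Max a b
  deriving (Eq, Show)

-- Rewrite  Max a b + n  into  Max (a + n) (b + n)  (and  n + Max a b
-- into  Max (n + a) (n + b)), working bottom-up.
pushAdd :: Expr -> Expr
pushAdd (Add x y) =
  case (pushAdd x, pushAdd y) of
    (Max a b, n) -> Max (pushAdd (Add a n)) (pushAdd (Add b n))
    (n, Max a b) -> Max (pushAdd (Add n a)) (pushAdd (Add n b))
    (x', y')     -> Add x' y'
pushAdd (Max x y) = Max (pushAdd x) (pushAdd y)
pushAdd e = e

main :: IO ()
main = print (pushAdd (Add (Max (Var "a") (Var "b")) (Var "n")))
-- prints: Max (Add (Var "a") (Var "n")) (Add (Var "b") (Var "n"))
```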

christiaanb commented 3 years ago

Yeah, do

containsConstants (Add _ _)  = True

instead. That will stop ghc-typelits-extra from making judgements about addition.
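As a rough model of why this helps, assuming the plugin's comparison has three outcomes (proven equal, proven unequal, undecided), the toy below treats any term containing `Add` as opaque, so it only ever answers "undecided" about sums and leaves them to natnormalise. Apart from `containsConstants` and `Add`, every name here (`Expr`, `Result`, `eval`, `decide`) is hypothetical and invented for illustration; this is not the plugin's actual code.

```haskell
-- Hypothetical toy model; only Add and containsConstants come from the thread.
data Expr = Lit Integer | Var String | Add Expr Expr | Max Expr Expr
  deriving (Eq, Show)

-- Proven equal, proven unequal, or "don't know, let another plugin try".
data Result = Win | Lose | Draw
  deriving (Eq, Show)

-- Does the term contain something this solver refuses to reason about?
containsConstants :: Expr -> Bool
containsConstants (Lit _)   = False
containsConstants (Var _)   = False
containsConstants (Add _ _) = True  -- the suggested change: sums are opaque
containsConstants (Max x y) = containsConstants x || containsConstants y

-- Evaluate a closed term; Nothing if it mentions a variable.
eval :: Expr -> Maybe Integer
eval (Lit k)   = Just k
eval (Var _)   = Nothing
eval (Add x y) = (+) <$> eval x <*> eval y
eval (Max x y) = max <$> eval x <*> eval y

-- Only answer Lose when nothing opaque is involved and both sides evaluate.
decide :: Expr -> Expr -> Result
decide x y
  | x == y = Win
  | containsConstants x || containsConstants y = Draw
  | otherwise = case (eval x, eval y) of
      (Just a, Just b) -> if a == b then Win else Lose
      _                -> Draw

main :: IO ()
main = do
  print (decide (Add (Lit 1) (Var "n")) (Add (Var "n") (Lit 1)))  -- Draw
  print (decide (Max (Lit 2) (Lit 3)) (Lit 3))                    -- Win
  print (decide (Lit 2) (Lit 1))                                  -- Lose
```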

isovector commented 3 years ago

That does it, cheers! PR incoming!

christiaanb commented 3 years ago

Does your PR handle this?

test1 :: Proxy a -> Proxy b -> Proxy n -> Proxy (Max (a + n) (b + n)) -> Proxy (n + Max a b)
test1 _ _ _ = id

isovector commented 3 years ago

Nice catch: no, it doesn't. I assumed natnormalise would handle commutativity?

christiaanb commented 3 years ago

Yeah. Sadly natnormalise doesn't rewrite inside equality constraints.

isovector commented 3 years ago

Dang. What do you think is the move here? Use nonDetCmpType to find a canonical ordering of plus terms?

isovector commented 3 years ago

Discovered CType and ordered the summands based on that; your test case passes now!
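The idea can be sketched on a hypothetical toy type: flatten each chain of `+`, sort the operands by a total order, and rebuild the sum. Here a derived `Ord` instance stands in for ordering by `CType`; any total order should work as long as it is applied consistently to both sides of the equation. `Expr`, `summands` and `canon` are illustrative names only.

```haskell
import Data.List (sort)

-- Hypothetical toy type; the derived Ord instance stands in for ordering
-- types via CType in the real plugin.
data Expr = Lit Integer | Var String | Add Expr Expr | Max Expr Expr
  deriving (Eq, Ord, Show)

-- All operands of a nested chain of additions.
summands :: Expr -> [Expr]
summands (Add x y) = summands x ++ summands y
summands e         = [e]

-- Canonical form: canonicalise subterms, then sort the operands of every
-- + chain and rebuild it as a left-nested sum.
canon :: Expr -> Expr
canon e@(Add _ _) = foldl1 Add (sort (map canon (summands e)))
canon (Max x y)   = Max (canon x) (canon y)
canon e           = e

main :: IO ()
main = do
  let maxAB = Max (Var "a") (Var "b")
  -- n + Max a b  and  Max a b + n  get the same canonical form.
  print (canon (Add (Var "n") maxAB) == canon (Add maxAB (Var "n")))  -- True
```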

christiaanb commented 3 years ago

Nice, I can live with that solution for now.

It will still fail for:

test2 :: Proxy a -> Proxy b -> Proxy x -> Proxy y -> Proxy (Max (a + (x * y)) (b + (y * x))) -> Proxy ((y * x) + Max a b)
test2 _ _ _ _ = id

We can call https://hackage.haskell.org/package/ghc-typelits-natnormalise-0.7.6/docs/GHC-TypeLits-Normalise-Unify.html#v:normaliseNatEverywhere to actually use the normalise solver inside of the extra solver where appropriate. But there are some intricacies there, so I'll simply build that on top of your PR once you've submitted it.
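A hedged sketch of why `test2` still fails and what normalising inside the terms buys: sorting only the top-level summands treats `x * y` and `y * x` as unrelated atoms, whereas also canonicalising products (sorting their factors) makes both sides agree. This is a toy on an invented `Expr` type, loosely in the spirit of natnormalise's sum-of-products normal form; it does not call `normaliseNatEverywhere`, and it does not distribute `*` over `+`. All names are hypothetical.

```haskell
import Data.List (sort)

-- Hypothetical toy type with multiplication added.
data Expr
  = Lit Integer
  | Var String
  | Add Expr Expr
  | Mul Expr Expr
  | Max Expr Expr
  deriving (Eq, Ord, Show)

-- Operands of a nested chain of one associative operator.
operands :: (Expr -> Maybe (Expr, Expr)) -> Expr -> [Expr]
operands split e = case split e of
  Just (x, y) -> operands split x ++ operands split y
  Nothing     -> [e]

asAdd, asMul :: Expr -> Maybe (Expr, Expr)
asAdd (Add x y) = Just (x, y)
asAdd _         = Nothing
asMul (Mul x y) = Just (x, y)
asMul _         = Nothing

-- The rule from this issue: push a summand into both arguments of a Max.
pushAdd :: Expr -> Expr
pushAdd (Add x y) = case (pushAdd x, pushAdd y) of
  (Max a b, n) -> Max (pushAdd (Add a n)) (pushAdd (Add b n))
  (n, Max a b) -> Max (pushAdd (Add n a)) (pushAdd (Add n b))
  (x', y')     -> Add x' y'
pushAdd (Mul x y) = Mul (pushAdd x) (pushAdd y)
pushAdd (Max x y) = Max (pushAdd x) (pushAdd y)
pushAdd e         = e

-- Canonicalise everywhere: sort the operands of every + chain and every
-- * chain, recursively, so x * y and y * x coincide.
canon :: Expr -> Expr
canon e@(Add _ _) = foldl1 Add (sort (map canon (operands asAdd e)))
canon e@(Mul _ _) = foldl1 Mul (sort (map canon (operands asMul e)))
canon (Max x y)   = Max (canon x) (canon y)
canon e           = e

normalise :: Expr -> Expr
normalise = canon . pushAdd

-- test2:  Max (a + (x * y)) (b + (y * x))  versus  (y * x) + Max a b
lhs2, rhs2 :: Expr
lhs2 = Max (Add (Var "a") (Mul (Var "x") (Var "y")))
           (Add (Var "b") (Mul (Var "y") (Var "x")))
rhs2 = Add (Mul (Var "y") (Var "x")) (Max (Var "a") (Var "b"))

main :: IO ()
main = print (normalise lhs2 == normalise rhs2)  -- True
```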

isovector commented 3 years ago

tl;dr: This patch doesn't actually help me in the real world. Am I expecting natnormalise to do too much?


My original use case, however, doesn't seem to work yet. The constraint I'm trying to solve:

((Max (((SizeOf word + SizeOf word) + 1) + 1)
     ((SizeOf word + 1) + 1)
 ) + 1)

<=?

((((Max (1 + (n + SizeOf word))
        ((Max (Max (((SizeOf word + SizeOf word) + 1) + 1)
                   ((SizeOf word + 1) + 1) + 1)
            n
        ) + 1
       ) + 1
   )
  ) - 1) - 1)

which, IIUC, natnormalise should simplify to:

((Max (((SizeOf word + SizeOf word) + 2))
     ((SizeOf word + 2))
 ) + 1)

<=?

((((Max (1 + (n + SizeOf word))
        ((Max ((Max (((SizeOf word + SizeOf word) + 2))
                   ((SizeOf word + 2))
               ) + 1)
            n
        ) + 1
       ) + 1
   )
  ) - 2))

and under the new rule, it should then become:

((Max (((SizeOf word + SizeOf word) + 2) + 1)
     ((SizeOf word + 2) + 1)
 ))

<=?

((((Max (1 + (n + SizeOf word) + 1 - 2)
        ((Max ((Max (((SizeOf word + SizeOf word) + 3) + 2 - 2)
                   ((SizeOf word + 3) + 2 - 2)
               ))
            (n + 2 - 2)
        )
       )
   )
  )))

and then I'd expect the subtractions to cancel:


((Max (((SizeOf word + SizeOf word) + 2) + 1)
     ((SizeOf word + 2) + 1)
 ))

<=?

((((Max ((n + SizeOf word))
        ((Max ((Max (((SizeOf word + SizeOf word) + 3))
                   ((SizeOf word + 3))
               ))
            (n)
        )
       )
   )
  )))
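The steps from the natnormalise form down to this one distribute `+ 1` and `- 2` over `Max`. A quick value-level check on small numbers (a sanity test, not a proof) shows that addition always distributes, while subtraction only does when both arguments are at least the amount subtracted; type-level `-` is modelled here as undefined for negative results. Writing each argument as `X + 1 - 2`, as above, keeps every subtraction in this derivation defined. The names `minus`, `addDistributes`, `subDistributes` and `counterexample` are hypothetical.

```haskell
-- Type-level Nat subtraction modelled as partial: Nothing if negative.
minus :: Integer -> Integer -> Maybe Integer
minus a k
  | a >= k    = Just (a - k)
  | otherwise = Nothing

-- Max a b + n == Max (a + n) (b + n): holds for every a, b, n checked.
addDistributes :: Bool
addDistributes =
  and [ max a b + n == max (a + n) (b + n) | a <- r, b <- r, n <- r ]
  where r = [0 .. 6]

-- Max a b - k == Max (a - k) (b - k): checked only where a >= k and b >= k.
subDistributes :: Bool
subDistributes =
  and [ minus (max a b) k == (max <$> minus a k <*> minus b k)
      | a <- r, b <- r, k <- r, a >= k, b >= k ]
  where r = [0 .. 6]

-- Left side defined, right side not: Max 5 1 - 2 vs Max (5 - 2) (1 - 2).
counterexample :: (Maybe Integer, Maybe Integer)
counterexample = (minus (max 5 1) 2, max <$> minus 5 2 <*> minus 1 2)

main :: IO ()
main = do
  print addDistributes   -- True
  print subDistributes   -- True
  print counterexample   -- (Just 3,Nothing)
```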

after some simplification by hand:

Max (SizeOf word + SizeOf word + 3) 
    (SizeOf word + 3)

<=?

Max (n + SizeOf word)
    (Max (Max (SizeOf word + SizeOf word + 3)
              (SizeOf word + 3)
         ) n
   )

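and then, since SizeOf word + SizeOf word + 3 is never smaller than SizeOf word + 3, the inner Max collapses on both sides: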

(SizeOf word + SizeOf word + 3) 

<=?

Max (n + SizeOf word)
    (Max (SizeOf word + SizeOf word + 3) n)

Here, (SizeOf word + SizeOf word + 3) appears on both sides, and x <= Max y (Max x n) holds for any x, y and n, so this should simply be True. But the solver appears to be stuck at the first step?
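To double-check the hand derivation, here is a brute-force sketch with `s` standing in for `SizeOf word` and non-negative `Integer`s for `Nat`. It reads the original right-hand side as `Max X Y + 1 - 1 - 1` with `X = 1 + (n + s)` and `Y = Max (Max A B + 1) n + 1`, which is an interpretation of the original parenthesisation. Over a small grid, the original and final forms agree and the `<=?` holds at every point. `original`, `final` and `grid` are hypothetical names.

```haskell
-- s stands in for SizeOf word; Nat is modelled by non-negative Integers.
-- Returns (left side, right side) of the original <=? constraint.
original :: Integer -> Integer -> (Integer, Integer)
original s n = (lhs, rhs)
  where
    a   = ((s + s) + 1) + 1
    b   = (s + 1) + 1
    lhs = max a b + 1
    x   = 1 + (n + s)
    y   = max (max a b + 1) n + 1
    rhs = ((max x y + 1) - 1) - 1  -- never negative, since x >= 1

-- The form reached by hand at the end of the derivation.
final :: Integer -> Integer -> (Integer, Integer)
final s n = (s + s + 3, max (n + s) (max (s + s + 3) n))

grid :: [(Integer, Integer)]
grid = [(s, n) | s <- [0 .. 20], n <- [0 .. 20]]

main :: IO ()
main = do
  print (all (\(s, n) -> original s n == final s n) grid)          -- True
  print (all (\(s, n) -> let (l, r) = original s n in l <= r) grid) -- True
```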