clash-lang / ghc-typelits-knownnat

Derive KnownNat constraints from other KnownNat constraints
Other

The plugin sometimes doesn't look through type aliases #53

Open leonschoorl opened 4 months ago

leonschoorl commented 4 months ago
```haskell
{-# LANGUAGE TemplateHaskell,MultiParamTypeClasses,FlexibleInstances,TypeOperators,NoStarIsType #-}
{-# LANGUAGE DataKinds,KindSignatures,ScopedTypeVariables,TypeApplications #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Bug where
import Prelude (Bool(..),undefined)
import Data.Proxy (Proxy(Proxy))
import GHC.TypeNats (natVal)
import GHC.TypeLits (Nat, KnownNat, type (*), type (+))
import GHC.TypeLits.KnownNat (KnownNat1(..), SNatKn(..), nameToSymbol)

-- like in clash-prelude:
data Vec (n::Nat) a
repeat :: KnownNat n => a -> Vec n a
repeat = undefined

type Foo (x::Nat) = x
type NatTimes2 (x :: Nat) = Foo (x * 2)

{- this fails with:
    • Could not deduce (KnownNat (Foo (x * 2)))
        arising from a use of ‘repeat’
      from the context: KnownNat x
-}
times2 :: KnownNat x => Vec (NatTimes2 x) Bool
times2 = repeat False
```

Interestingly, it works fine for `Foo (x + 2)`.

Other examples that work fine:

```haskell
type NatPlus2 (x :: Nat) = Foo (x + 2)
type FooNested (x :: Nat) = Foo x
type NatTimes2' (x :: Nat) = (Foo x) * 2
type NatTimes1 (x :: Nat) = Foo (x * 1)
type NatTimes2Direct (x :: Nat) = x * 2
type NatPlus2Direct (x :: Nat) = x + 2

plus2 :: KnownNat x => Vec (NatPlus2 x) Bool
plus2 = repeat False

onlyFoo :: KnownNat x => Vec (Foo x) Bool
onlyFoo = repeat False

onlyFooNested :: KnownNat x => Vec (FooNested x) Bool
onlyFooNested = repeat False

times2' :: KnownNat x => Vec (NatTimes2' x) Bool
times2' = repeat False

times1 :: KnownNat x => Vec (NatTimes1 x) Bool
times1 = repeat False

times2D :: KnownNat x => Vec (NatTimes2Direct x) Bool
times2D = repeat False

plus2D :: KnownNat x => Vec (NatPlus2Direct x) Bool
plus2D = repeat False
```

As a workaround, defining a `KnownNat1` instance for `Foo` works:

```haskell
instance (KnownNat x) => KnownNat1 $(nameToSymbol ''Foo) x where
  natSing1 = SNatKn (natVal (Proxy @x))
  {-# NOINLINE natSing1 #-}
```

But it seems to me that this shouldn't be necessary: `Foo` is just an alias, and the plugin already looks through it automatically for `Foo (x + 2)`.
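One way to check that `Foo` really is transparent to GHC itself is to ask for a `Refl` proof that the alias and its expansion are the same type, even for a type variable `x`. The sketch below uses only `base` and does not load the plugin; `aliasIsTransparent` is a hypothetical name, and the `Proxy` argument is only there to keep `x` from being ambiguous. If it typechecks, that suggests the problem lies in how the plugin inspects the unexpanded type rather than in GHC's own notion of type equality.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, TypeOperators, NoStarIsType #-}
import Data.Proxy (Proxy (..))
import Data.Type.Equality ((:~:) (Refl))
import GHC.TypeLits (Nat, type (*))

-- Same aliases as in the report above.
type Foo (x :: Nat) = x
type NatTimes2 (x :: Nat) = Foo (x * 2)

-- Hypothetical helper: GHC expands type synonyms when it checks
-- type equality, so a plain Refl is accepted here for any x.
aliasIsTransparent :: Proxy x -> NatTimes2 x :~: (x * 2)
aliasIsTransparent _ = Refl
```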

rowanG077 commented 4 months ago

This plugin isn't the only one with this issue; ghc-typelits-natnormalise has it as well.

But, as you point out, it seems inconsistent when aliases are and aren't handled correctly.

christiaanb commented 3 months ago

Ah, I probably need to add a `coreView` around here, https://github.com/clash-lang/ghc-typelits-knownnat/blob/2e57de3b709dab085fb1657cf73d4f5e833229ee/src-ghc-9.4/GHC/TypeLits/KnownNat/Solver.hs#L385-L387, since we look up the instance by the name of the type: https://github.com/clash-lang/ghc-typelits-knownnat/blob/2e57de3b709dab085fb1657cf73d4f5e833229ee/src-ghc-9.4/GHC/TypeLits/KnownNat/Solver.hs#L392-L396
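The idea of expanding synonyms before a name-based lookup can be illustrated with a toy model. Everything below is hypothetical: `Ty`, `expandHead`, `lookupByName` and the handler table are not the plugin's code or GHC's API. `expandHead` only mimics the role of GHC's `coreView`, which peels off one synonym at the head of a type and returns `Nothing` when there is nothing to look through. The model does not attempt to explain why `Foo (x + 2)` already works today.

```haskell
import Data.Maybe (fromMaybe)

-- Toy representation of types: a variable, or a named
-- constructor/family/synonym applied to arguments.
data Ty = TyVar String | TyCon String [Ty]
  deriving (Eq, Show)

-- Hypothetical synonym table: name -> (parameters, right-hand side).
type Synonyms = [(String, ([String], Ty))]

-- The alias from the report: type Foo (x :: Nat) = x
synonyms :: Synonyms
synonyms = [("Foo", (["x"], TyVar "x"))]

-- Replace the synonym's parameters by the actual arguments.
subst :: [(String, Ty)] -> Ty -> Ty
subst env (TyVar v)    = fromMaybe (TyVar v) (lookup v env)
subst env (TyCon c ts) = TyCon c (map (subst env) ts)

-- Expand one saturated synonym at the head, like coreView;
-- Nothing when the head is not a synonym.
expandHead :: Synonyms -> Ty -> Maybe Ty
expandHead syns (TyCon c args) = do
  (params, rhs) <- lookup c syns
  if length params == length args
    then Just (subst (zip params args) rhs)
    else Nothing
expandHead _ _ = Nothing

-- Keep expanding until the head is no longer a synonym.
expandAll :: Synonyms -> Ty -> Ty
expandAll syns t = maybe t (expandAll syns) (expandHead syns t)

-- Hypothetical stand-in for "look up the instance by the name of the type".
handlers :: [String]
handlers = ["+", "*"]

lookupByName :: Ty -> Maybe String
lookupByName (TyCon c _) | c `elem` handlers = Just c
lookupByName _ = Nothing

-- Foo (x * 2)
fooTimes2 :: Ty
fooTimes2 = TyCon "Foo" [TyCon "*" [TyVar "x", TyCon "2" []]]

naiveLookup, expandedLookup :: Maybe String
naiveLookup    = lookupByName fooTimes2                      -- Nothing: the head is "Foo"
expandedLookup = lookupByName (expandAll synonyms fooTimes2) -- Just "*"
```

Because `expandAll` repeats the head expansion, nested aliases such as `Foo (Foo y)` also reduce to the underlying type in this model.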

lmbollen commented 3 months ago

The current master of clash-prelude contains a few type synonyms where we run into the same issue. This pull request, https://github.com/clash-lang/clash-compiler/pull/2734, introduces:

```haskell
-- | Gets time in 'Picoseconds' from time in 'Seconds'
type Seconds      (s  :: Nat) = Milliseconds (1000 * s)
-- | Gets time in 'Picoseconds' from time in 'Milliseconds'
type Milliseconds (ms :: Nat) = Microseconds (1000 * ms)
-- | Gets time in 'Picoseconds' from time in 'Microseconds'
type Microseconds (us :: Nat) = Nanoseconds  (1000 * us)
-- | Gets time in 'Picoseconds' from time in 'Nanoseconds'
type Nanoseconds  (ns :: Nat) = Picoseconds  (1000 * ns)
-- | Gets time in 'Picoseconds' from time in picoseconds, essentially 'id'
type Picoseconds  (ps :: Nat) = ps
```
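As a sanity check on the arithmetic, the synonyms can be copied into a standalone file and evaluated at concrete sizes. With literal arguments GHC's built-in type-level arithmetic reduces the nested multiplications, so these `KnownNat` constraints are solved without any plugin; the reported failures all involve a type variable argument instead. `oneSecondInPs` and `fiveNsInPs` are illustrative names, and the doc comments are omitted.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, TypeOperators, NoStarIsType #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits (Nat, type (*))
import GHC.TypeNats (natVal)
import Numeric.Natural (Natural)

-- Copies of the synonyms introduced in the PR.
type Seconds      (s  :: Nat) = Milliseconds (1000 * s)
type Milliseconds (ms :: Nat) = Microseconds (1000 * ms)
type Microseconds (us :: Nat) = Nanoseconds  (1000 * us)
type Nanoseconds  (ns :: Nat) = Picoseconds  (1000 * ns)
type Picoseconds  (ps :: Nat) = ps

-- Literal arguments: GHC evaluates 1000 * (1000 * ...) itself,
-- so KnownNat is solved here without the plugin.
oneSecondInPs :: Natural
oneSecondInPs = natVal (Proxy :: Proxy (Seconds 1))     -- 10^12

fiveNsInPs :: Natural
fiveNsInPs = natVal (Proxy :: Proxy (Nanoseconds 5))    -- 5000
```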

With this reproducer:

```haskell
module Bug where

import Clash.Prelude

-- Works
periodPs :: SNat ps -> SNat (Picoseconds ps)
periodPs SNat = SNat

-- Doesn't work
periodNs :: SNat ns -> SNat (Nanoseconds ns)
periodNs SNat = SNat

-- Doesn't work
periodMs :: SNat ms -> SNat (Milliseconds ms)
periodMs SNat = SNat

-- Doesn't work
periodS :: SNat s -> SNat (Seconds s)
periodS SNat = SNat
```

We get the following errors:

src/Bug.hs:11:17: error:
    • Could not deduce (KnownNat (Picoseconds (1000 * ns)))
        arising from a use of ‘SNat’
      from the context: KnownNat ns
        bound by a pattern with constructor:
                   SNat :: forall (n :: Nat). KnownNat n => SNat n,
                 in an equation for ‘periodNs’
        at src/Bug.hs:11:10-13
    • In the expression: SNat
      In an equation for ‘periodNs’: periodNs SNat = SNat
   |
11 | periodNs SNat = SNat
   |                 ^^^^

src/Bug.hs:15:17: error:
    • Could not deduce (KnownNat
                          (Picoseconds (1000 * (1000 * (1000 * ms)))))
        arising from a use of ‘SNat’
      from the context: KnownNat ms
        bound by a pattern with constructor:
                   SNat :: forall (n :: Nat). KnownNat n => SNat n,
                 in an equation for ‘periodMs’
        at src/Bug.hs:15:10-13
    • In the expression: SNat
      In an equation for ‘periodMs’: periodMs SNat = SNat
   |
15 | periodMs SNat = SNat
   |                 ^^^^

src/Bug.hs:19:16: error:
    • Could not deduce (KnownNat
                          (Picoseconds (1000 * (1000 * (1000 * (1000 * s))))))
        arising from a use of ‘SNat’
      from the context: KnownNat s
        bound by a pattern with constructor:
                   SNat :: forall (n :: Nat). KnownNat n => SNat n,
                 in an equation for ‘periodS’
        at src/Bug.hs:19:9-12
    • In the expression: SNat
      In an equation for ‘periodS’: periodS SNat = SNat
   |
19 | periodS SNat = SNat