clash-lang / clash-compiler

Haskell to VHDL/Verilog/SystemVerilog compiler
https://clash-lang.org/

Behaves differently: `zipWith3 f a b c` / `f <$> a <*> b <*> c` #2723

Open martijnbastiaan opened 1 month ago

martijnbastiaan commented 1 month ago

That is, the following causes Clash to generate one module for bbWrapper:

topEntity :: Vec 4 (Signal System Bit)
topEntity = zipWith3 bbWrapper indicesI indicesI indicesI
{-# CLASH_OPAQUE topEntity #-}

But the following tries to generate 4 of them:

topEntity :: Vec 4 (Signal System Bit)
topEntity = bbWrapper <$> indicesI <*> indicesI <*> indicesI
{-# CLASH_OPAQUE topEntity #-}
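Semantically the two definitions should agree. In clash-prelude the `Applicative` instance for `Vec` applies functions elementwise (`<*>` is `zipWith ($)`), so `f <$> a <*> b <*> c` denotes the same vector as `zipWith3 f a b c`. The sketch below runs with only `base`: lists and `ZipList` stand in for `Vec` (an assumption, so that no Clash installation is needed), and `wrapper` and `indices4` are hypothetical stand-ins for `bbWrapper` and `indicesI`. It shows that the applicative form is assembled from `map` plus nested `zipWith ($)` and passes through a vector of partially applied functions, whereas `zipWith3` is a single call.

```haskell
import Control.Applicative (ZipList (..))

-- Hypothetical stand-in for indicesI :: Vec 4 (Index 4).
indices4 :: [Int]
indices4 = [0 .. 3]

-- Hypothetical stand-in for bbWrapper. Unlike bbWrapper it uses its
-- arguments, so the comparisons below are not trivially true.
wrapper :: Int -> Int -> Int -> Int
wrapper a b c = a * 100 + b * 10 + c

-- Formulation 1: a single three-way zip.
viaZipWith3 :: [Int]
viaZipWith3 = zipWith3 wrapper indices4 indices4 indices4

-- Formulation 2: fmap followed by two elementwise applications.
viaApplicative :: [Int]
viaApplicative =
  getZipList (wrapper <$> ZipList indices4 <*> ZipList indices4 <*> ZipList indices4)

-- Formulation 2 desugared by hand. The intermediate results are lists of
-- partially applied functions (Int -> Int -> Int, then Int -> Int),
-- analogous to the Vec of (Index 4 -> Signal System Bit) in the core dump.
viaZipWithApply :: [Int]
viaZipWithApply =
  zipWith ($) (zipWith ($) (map wrapper indices4) indices4) indices4
```

All three evaluate to `[0, 111, 222, 333]`; the difference reported in this issue is therefore in how the compiler processes the two shapes, not in what they mean.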

Reproducer (note that this reproducer will no longer error if https://github.com/clash-lang/clash-compiler/issues/2722 gets fixed):

{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Test @zipWith3 f a b c@ behaves the same as @f <$> a <*> b <*> c@ when it
-- comes to specialization caches.
module T2723 where

import Clash.Explicit.Prelude

import Clash.Annotations.Primitive (Primitive(..))
import Clash.Backend (blockDecl)
import Data.Monoid (Ap(getAp))

import qualified Clash.Netlist.Types as N
import qualified Clash.Netlist.Id as Id

bbTF :: N.TemplateFunction
bbTF = N.TemplateFunction used valid $ \_bbCtx -> do
  x <- Id.make "x"

  () <- case Id.toText x of
    "x" -> pure ()
    xName -> error $ "Unexpected name: " <> show xName <> ". Expected: x."

  getAp $ blockDecl x [N.NetDecl Nothing x N.Bit]
 where
  used    = [0,1]
  valid _ = True

{-# ANN bb (InlinePrimitive [minBound..] "[ { \"BlackBox\" : { \"name\" : \"T2723.bb\", \"kind\": \"Declaration\", \"workInfo\": \"Always\", \"format\": \"Haskell\", \"templateFunction\": \"T2723.bbTF\"}} ]") #-}
bb :: Signal System Bit
bb = pure low
{-# CLASH_OPAQUE bb #-}

bbWrapper :: Index n -> Index n -> Index n -> Signal System Bit
bbWrapper !_ !_ !_ = bb
{-# CLASH_OPAQUE bbWrapper #-}

-- FAILS:
topEntity :: Vec 4 (Signal System Bit)
topEntity = bbWrapper <$> indicesI <*> indicesI <*> indicesI
{-# CLASH_OPAQUE topEntity #-}

-- OK:
-- topEntity :: Vec 4 (Signal System Bit)
-- topEntity = zipWith3 bbWrapper indicesI indicesI indicesI
-- {-# CLASH_OPAQUE topEntity #-}

Gives:

$ rm -rf vhdl/ && cabal run clash -- -itests/shouldwork/Issues T2723 --vhdl -Wall -Werror -DCLASH_OPAQUE=OPAQUE
Resolving dependencies...
Loaded package environment from /home/martijn/code/clash-compiler/.ghc.environment.x86_64-linux-9.4.8
GHC: Setting up GHC took: 0.250s
GHC: Compiling and loading modules took: 1.229s
Hint: Interpreting T2723.bbTF
Clash: Parsing and compiling primitives took 1.550s
GHC+Clash: Loading modules cumulatively took 3.170s
Clash: Compiling T2723.topEntity
Clash: Normalization took 0.005s
Clash: Netlist generation took 0.000s

<no location info>: error:
    Clash error call:
    Unexpected name: "x_0". Expected: x.
    CallStack (from HasCallStack):
      error, called at tests/shouldwork/Issues/T2723.hs:23:14 in main:T2723
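The check in `bbTF` fails because `Id.make` returns a fresh identifier: it only yields `x_0` when `x` is already taken in the scope being generated, presumably because several instances of `bb` end up rendered into the same component. The toy model below illustrates that mechanism. `Scope`, `makeId`, `sharedScope` and `separateScopes` are hypothetical and only approximate `Clash.Netlist.Id`; in particular, the suffix scheme beyond `x_0` is an assumption.

```haskell
-- Hypothetical toy model of a per-component identifier supply. It is NOT
-- Clash's Clash.Netlist.Id implementation; it only reproduces the observed
-- behaviour that a second "x" in one scope comes back as "x_0".
type Scope = [String]

-- Return the first candidate name not yet used in the scope, and record it.
makeId :: String -> Scope -> (String, Scope)
makeId base seen = (name, name : seen)
  where
    candidates = base : [base ++ "_" ++ show i | i <- [0 :: Int ..]]
    name = case filter (`notElem` seen) candidates of
      (c : _) -> c
      []      -> error "unreachable: the candidate list is infinite"

-- n black box instances rendered into ONE component (one shared scope),
-- presumably what happens when bbWrapper is inlined into topEntity.
sharedScope :: Int -> [String]
sharedScope n = go n []
  where
    go :: Int -> Scope -> [String]
    go 0 _    = []
    go k seen = let (x, seen') = makeId "x" seen in x : go (k - 1) seen'

-- n instances, each in its own component with a fresh scope, as when
-- bbWrapper stays a separate module that is instantiated n times.
separateScopes :: Int -> [String]
separateScopes n = replicate n (fst (makeId "x" []))
```

In this model `sharedScope 4` gives `["x", "x_0", "x_1", "x_2"]`, so the second instantiation already trips the check, while `separateScopes 4` gives `"x"` four times.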

I'm somewhat undecided whether this is a bug or an enhancement, so I've labelled it as both.

christiaanb commented 1 month ago

No specialisation takes place, so it cannot be a specialisation cache issue.

The "problem" is caseCon:

case Clash.Sized.Vector.zipWith @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
       @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
       @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4
       -> Clash.Signal.Internal.Signal[8214565720323788347]
            "System"
            Clash.Sized.Internal.BitVector.Bit[8214565720323788378])
       @4
       (T2723.bbWrapper[8214565720323817153][GlobalId] @4)
       (Clash.Sized.Vector.imap @4 @GHC.Tuple.()[3746994889972252672]
          @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
          T2723.topEntity1[8214565720323827100][GlobalId]
          (GHC.Base.const[8214565720323786588][GlobalId]
             @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
             @GHC.Tuple.()[3746994889972252672])
          (Clash.Sized.Vector.replicate @4 @GHC.Tuple.()[3746994889972252672]
             (Clash.Promoted.Nat.SNat[8214565720323788264] @4 T2723.topEntity1[8214565720323827100][GlobalId])
             GHC.Tuple.()[3963167672086036480]))
       (Clash.Sized.Vector.imap @4 @GHC.Tuple.()[3746994889972252672]
          @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
          T2723.topEntity1[8214565720323827100][GlobalId]
          (GHC.Base.const[8214565720323786588][GlobalId]
             @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
             @GHC.Tuple.()[3746994889972252672])
          (Clash.Sized.Vector.replicate @4 @GHC.Tuple.()[3746994889972252672]
             (Clash.Promoted.Nat.SNat[8214565720323788264] @4 T2723.topEntity1[8214565720323827100][GlobalId])
             GHC.Tuple.()[3963167672086036480])) of
  Clash.Sized.Vector.Cons[8214565720323788392] m[2]
    (_co_[5] :: GHC.Prim.~#[3674937295934324842]
                  GHC.Num.Natural.Natural[3674937295934324786]
                  GHC.Num.Natural.Natural[3674937295934324786]
                  4
                  (GHC.TypeNats.+[3674937295934325540] 3 1))
      (el[6] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4
            -> Clash.Signal.Internal.Signal[8214565720323788347]
                 "System"
                 Clash.Sized.Internal.BitVector.Bit[8214565720323788378])
      (res[3] :: Clash.Sized.Vector.Vec[8214565720323788389]
                   3
                   (Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                   -> Clash.Signal.Internal.Signal[8214565720323788347]
                        "System"
                        Clash.Sized.Internal.BitVector.Bit[8214565720323788378])) ->
    case res[3][LocalId] of
      Clash.Sized.Vector.Cons[8214565720323788392] m[2]
        (_co_[5] :: GHC.Prim.~#[3674937295934324842]
                      GHC.Num.Natural.Natural[3674937295934324786]
                      GHC.Num.Natural.Natural[3674937295934324786]
                      3
                      (GHC.TypeNats.+[3674937295934325540] 2 1))
          (el[6] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                -> Clash.Signal.Internal.Signal[8214565720323788347]
                     "System"
                     Clash.Sized.Internal.BitVector.Bit[8214565720323788378])
          (res[3] :: Clash.Sized.Vector.Vec[8214565720323788389]
                       2
                       (Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                       -> Clash.Signal.Internal.Signal[8214565720323788347]
                            "System"
                            Clash.Sized.Internal.BitVector.Bit[8214565720323788378])) ->
        case res[3][LocalId] of
          Clash.Sized.Vector.Cons[8214565720323788392] m[2]
            (_co_[5] :: GHC.Prim.~#[3674937295934324842]
                          GHC.Num.Natural.Natural[3674937295934324786]
                          GHC.Num.Natural.Natural[3674937295934324786]
                          2
                          (GHC.TypeNats.+[3674937295934325540] 1 1))
              (el[6] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                    -> Clash.Signal.Internal.Signal[8214565720323788347]
                         "System"
                         Clash.Sized.Internal.BitVector.Bit[8214565720323788378])
              (res[3] :: Clash.Sized.Vector.Vec[8214565720323788389]
                           1
                           (Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                           -> Clash.Signal.Internal.Signal[8214565720323788347]
                                "System"
                                Clash.Sized.Internal.BitVector.Bit[8214565720323788378])) ->
            case res[3][LocalId] of
              Clash.Sized.Vector.Cons[8214565720323788392] m[2]
                (_co_[5] :: GHC.Prim.~#[3674937295934324842]
                              GHC.Num.Natural.Natural[3674937295934324786]
                              GHC.Num.Natural.Natural[3674937295934324786]
                              1
                              (GHC.TypeNats.+[3674937295934325540] 0 1))
                  (el[6] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                        -> Clash.Signal.Internal.Signal[8214565720323788347]
                             "System"
                             Clash.Sized.Internal.BitVector.Bit[8214565720323788378])
                  (res[3] :: Clash.Sized.Vector.Vec[8214565720323788389]
                               0
                               (Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                               -> Clash.Signal.Internal.Signal[8214565720323788347]
                                    "System"
                                    Clash.Sized.Internal.BitVector.Bit[8214565720323788378])) ->
                el[6][LocalId] x[1][LocalId]
Result:
let
  res[13] :: Clash.Sized.Vector.Vec[8214565720323788389]
               3
               (Clash.Sized.Internal.Index.Index[8214565720323788380] 4
               -> Clash.Signal.Internal.Signal[8214565720323788347]
                    "System"
                    Clash.Sized.Internal.BitVector.Bit[8214565720323788378])
  = Clash.Sized.Vector.zipWith @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
      @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
      @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4
      -> Clash.Signal.Internal.Signal[8214565720323788347]
           "System"
           Clash.Sized.Internal.BitVector.Bit[8214565720323788378])
      @3
      ((ds[7205759403792831058] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4) ->
      (ds1[7205759403792831059] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4) ->
      (ds2[7205759403792831060] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4) ->
      case ds[7205759403792831058][LocalId] 
          Clash.Sized.Internal.Index.Index[8214565720323788380] 4
          ~
          GHC.Num.Integer.Integer[3674937295934324784] of
        _ ->
          case ds1[7205759403792831059][LocalId] 
              Clash.Sized.Internal.Index.Index[8214565720323788380] 4
              ~
              GHC.Num.Integer.Integer[3674937295934324784] of
            _ ->
              case ds2[7205759403792831060][LocalId] 
                  Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                  ~
                  GHC.Num.Integer.Integer[3674937295934324784] of
                _ ->
                  T2723.bb)
      (Clash.Sized.Vector.imap_go @4 @3 @GHC.Tuple.()[3746994889972252672]
         @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
         ((x[8214565720323828149] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4) ->
         (ds[8214565720323828150] :: GHC.Tuple.()[3746994889972252672]) ->
         x[8214565720323828149][LocalId])
         (Clash.Sized.Vector.Cons[8214565720323788392] @3 @GHC.Tuple.()[3746994889972252672] @2 _CO_
            GHC.Tuple.()[3963167672086036480]
            (Clash.Sized.Vector.Cons[8214565720323788392] @2 @GHC.Tuple.()[3746994889972252672] @1 _CO_
               GHC.Tuple.()[3963167672086036480]
               (Clash.Sized.Vector.Cons[8214565720323788392] @1 @GHC.Tuple.()[3746994889972252672] @0 _CO_
                  GHC.Tuple.()[3963167672086036480]
                  (Clash.Sized.Vector.Nil[8214565720323788393] @0 @GHC.Tuple.()[3746994889972252672] _CO_))))
         (Clash.Sized.Internal.Index.+# @4 4 (Clash.Sized.Internal.Index.fromInteger# @4 4 0)
            (Clash.Sized.Internal.Index.fromInteger# @4 4 1)))
      (Clash.Sized.Vector.tail @3 @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
         (Clash.Sized.Vector.Cons[8214565720323788392] @4
            @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
            @3
            _CO_
            (((x[8214565720323828149] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4) ->
               (ds[8214565720323828150] :: GHC.Tuple.()[3746994889972252672]) ->
               x[8214565720323828149][LocalId])
               (Clash.Sized.Internal.Index.fromInteger# @4 4 0)
               GHC.Tuple.()[3963167672086036480])
            (Clash.Sized.Vector.imap_go @4 @3 @GHC.Tuple.()[3746994889972252672]
               @(Clash.Sized.Internal.Index.Index[8214565720323788380] 4)
               ((x[8214565720323828149] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4) ->
               (ds[8214565720323828150] :: GHC.Tuple.()[3746994889972252672]) ->
               x[8214565720323828149][LocalId])
               (Clash.Sized.Vector.Cons[8214565720323788392] @3 @GHC.Tuple.()[3746994889972252672] @2 _CO_
                  GHC.Tuple.()[3963167672086036480]
                  (Clash.Sized.Vector.Cons[8214565720323788392] @2 @GHC.Tuple.()[3746994889972252672] @1 _CO_
                     GHC.Tuple.()[3963167672086036480]
                     (Clash.Sized.Vector.Cons[8214565720323788392] @1 @GHC.Tuple.()[3746994889972252672] @0 _CO_
                        GHC.Tuple.()[3963167672086036480]
                        (Clash.Sized.Vector.Nil[8214565720323788393] @0 @GHC.Tuple.()[3746994889972252672] _CO_))))
               (Clash.Sized.Internal.Index.+# @4 4 (Clash.Sized.Internal.Index.fromInteger# @4 4 0)
                  (Clash.Sized.Internal.Index.fromInteger# @4 4 1)))))
in case res[13][LocalId] of
     Clash.Sized.Vector.Cons[8214565720323788392] m[24]
       (_co_[17] :: GHC.Prim.~#[3674937295934324842]
                      GHC.Num.Natural.Natural[3674937295934324786]
                      GHC.Num.Natural.Natural[3674937295934324786]
                      3
                      (GHC.TypeNats.+[3674937295934325540] 2 1))
         (el[19] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                -> Clash.Signal.Internal.Signal[8214565720323788347]
                     "System"
                     Clash.Sized.Internal.BitVector.Bit[8214565720323788378])
         (res[31] :: Clash.Sized.Vector.Vec[8214565720323788389]
                       2
                       (Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                       -> Clash.Signal.Internal.Signal[8214565720323788347]
                            "System"
                            Clash.Sized.Internal.BitVector.Bit[8214565720323788378])) ->
       case res[31][LocalId] of
         Clash.Sized.Vector.Cons[8214565720323788392] m[32]
           (_co_[21] :: GHC.Prim.~#[3674937295934324842]
                          GHC.Num.Natural.Natural[3674937295934324786]
                          GHC.Num.Natural.Natural[3674937295934324786]
                          2
                          (GHC.TypeNats.+[3674937295934325540] 1 1))
             (el[23] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                    -> Clash.Signal.Internal.Signal[8214565720323788347]
                         "System"
                         Clash.Sized.Internal.BitVector.Bit[8214565720323788378])
             (res[39] :: Clash.Sized.Vector.Vec[8214565720323788389]
                           1
                           (Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                           -> Clash.Signal.Internal.Signal[8214565720323788347]
                                "System"
                                Clash.Sized.Internal.BitVector.Bit[8214565720323788378])) ->
           case res[39][LocalId] of
             Clash.Sized.Vector.Cons[8214565720323788392] m[40]
               (_co_[25] :: GHC.Prim.~#[3674937295934324842]
                              GHC.Num.Natural.Natural[3674937295934324786]
                              GHC.Num.Natural.Natural[3674937295934324786]
                              1
                              (GHC.TypeNats.+[3674937295934325540] 0 1))
                 (el[27] :: Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                        -> Clash.Signal.Internal.Signal[8214565720323788347]
                             "System"
                             Clash.Sized.Internal.BitVector.Bit[8214565720323788378])
                 (res[47] :: Clash.Sized.Vector.Vec[8214565720323788389]
                               0
                               (Clash.Sized.Internal.Index.Index[8214565720323788380] 4
                               -> Clash.Signal.Internal.Signal[8214565720323788347]
                                    "System"
                                    Clash.Sized.Internal.BitVector.Bit[8214565720323788378])) ->
               el[27][LocalId] x[1]

where it inlines the definition of bbWrapper. The issue is actually that the custom evaluator for zipWith is too strict in its function argument. I guess we need to extend https://github.com/clash-lang/clash-compiler/blob/a245bb8fd9bf589624ed83b55ef678e0fdb89b04/clash-ghc/src-ghc/Clash/GHC/Evaluator/Primitive.hs#L192-L216 to also work for zipWith.
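To make "too strict in its function argument" concrete, here is a toy model. The `Term` type, `globals`, `unfold`, `zipWithPrim` and `headIsGlobal` are hypothetical and far simpler than Clash's core language and the evaluator in `Clash.GHC.Evaluator.Primitive`. If the primitive forces its function argument before building its result, the reference to the global `bbWrapper` is replaced by its definition (a lambda) in every element, which matches the lambda visible in the caseCon result above; keeping the argument as an unevaluated reference leaves a single `bbWrapper` for the netlist to instantiate.

```haskell
-- Hypothetical toy core language, far simpler than Clash's own.
data Term
  = Global String    -- reference to a top-level binder, e.g. bbWrapper
  | Lam String Term  -- lambda abstraction
  | App Term Term    -- application
  | Lit Int          -- literal argument, e.g. an Index value
  deriving (Eq, Show)

-- Top-level definitions the toy evaluator is allowed to unfold.
-- In this toy, bbWrapper takes two arguments, ignores them and returns bb.
globals :: [(String, Term)]
globals = [("bbWrapper", Lam "a" (Lam "b" (Global "bb")))]

-- Forcing a function argument: unfold a global to its definition.
unfold :: Term -> Term
unfold (Global g) | Just def <- lookup g globals = def
unfold t = t

-- Toy zipWith primitive producing one application per element pair.
-- strict = True models forcing the function argument first, so every
-- element receives a copy of bbWrapper's body; strict = False keeps
-- the reference to the global binder.
zipWithPrim :: Bool -> Term -> [Term] -> [Term] -> [Term]
zipWithPrim strict f = zipWith (\x y -> App (App f' x) y)
  where
    f' = if strict then unfold f else f

-- Is the head of an application chain a reference to the given global?
headIsGlobal :: String -> Term -> Bool
headIsGlobal g (App fun _) = headIsGlobal g fun
headIsGlobal g (Global h)  = g == h
headIsGlobal _ _           = False
```

In this model, making the evaluator lazier in the function argument of `zipWith` corresponds to `strict = False`: every result element still applies `Global "bbWrapper"`.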

martijnbastiaan commented 1 month ago

Why do we have a blackbox for zipWith to begin with? Seems like just another code path we need to maintain.

In this specific case the behavior of the bb is better btw, though correctness obviously beats prettiness.

christiaanb commented 1 month ago

Haskell-to-HDL compile speed. Without them, the Clash compiler would have to unfold/unroll all functions over vectors. That could increase compile times by 100x.