haskell / vector

An efficient implementation of Int-indexed arrays (both mutable and immutable), with a powerful loop optimisation framework.

Not fusing unless monadic #438

Open WinstonHartnett opened 2 years ago

WinstonHartnett commented 2 years ago

From this issue on ghc:

The StackOverflow question "Is there any way to inline a recursive function" includes roughly the following trick by Matthew Pickering to inline the recursive function oldNTimes shown below:

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}

module Main (main) where

import GHC.TypeLits
import qualified Data.Vector.Unboxed as V

{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll = V.map (+ 1)

-- Old definition
oldNTimes :: Int -> (a -> a) -> a -> a
oldNTimes 0 f x = x
oldNTimes n f x = f (oldNTimes (n-1) f x)

-- New definition
class Unroll (n :: Nat) where
  nTimes :: (a -> a) -> a -> a

instance Unroll 0 where
  nTimes f x = x

instance {-# OVERLAPPABLE #-} Unroll (p - 1) => Unroll p where
  nTimes f x = f (nTimes @(p - 1) f x)

main :: IO ()
main = do
  let size = 100000000 :: Int
  let array = V.replicate size 0 :: V.Vector Int
  print $ V.sum (nTimes @64 incAll array)
  -- print $ V.sum (oldNTimes 64 incAll array)

On GHC 8.2.2, nTimes takes 38.1ms compared to oldNTimes' 25.5s. But on 9.2.2, the maps don't fuse, and both nTimes and oldNTimes run in 4.3s (about 113x slower than the 38.1ms fused run). GHC is inlining the recursive calls.
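To make concrete what fusion is expected to achieve here, the following is a minimal sketch that hand-unrolls three applications of incAll and compares the result with a single map. The names unrolled3 and singlePass3 are hypothetical and are not part of the reproduction. The sketch only checks that the two forms agree on their output; it does not show whether GHC actually fuses them.

```haskell
module Main (main) where

import qualified Data.Vector.Unboxed as V

{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll = V.map (+ 1)

-- Hypothetical name: roughly what nTimes @3 incAll amounts to once the
-- Unroll instances have been inlined.
unrolled3 :: V.Vector Int -> V.Vector Int
unrolled3 xs = incAll (incAll (incAll xs))

-- Hypothetical name: a single traversal with the same result, which is
-- what successfully fused maps should be equivalent to.
singlePass3 :: V.Vector Int -> V.Vector Int
singlePass3 xs = V.map (+ 3) xs

main :: IO ()
main = do
  let array = V.replicate 10 0 :: V.Vector Int
  -- Both forms produce the same vector.
  print (unrolled3 array == singlePass3 array)
  print (V.sum (unrolled3 array))
```

With 64 applications over 100000000 elements, the gap between one traversal and 64 separate traversals (each allocating an intermediate vector) is plausibly what separates the 38ms and 4.3s timings above.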

And, for some reason, versions of incAll and nTimes lifted into a monad (incAllM and nTimesM below) run in 38ms on 9.2.2.

...
{-# INLINE incAllM #-}
incAllM :: Monad m => V.Vector Int -> m (V.Vector Int)
incAllM = pure . V.map (+ 1)

class UnrollM (n :: Nat) where
  nTimesM :: Monad m => (a -> m a) -> a -> m a

instance UnrollM 0 where
  nTimesM f x = pure x

instance {-# OVERLAPPABLE #-} UnrollM (p - 1) => UnrollM p where
  nTimesM f x = f =<< nTimesM @(p - 1) f x

main :: IO ()
main = do
  let size = 100000000 :: Int
  let array = V.replicate size 0 :: V.Vector Int
  print . V.sum =<< nTimesM @64 incAllM array

Reproduction project here

WinstonHartnett commented 2 years ago

Seems related to #416.

WinstonHartnett commented 2 years ago

Actually, incAll is being inlined before specialization because it's a 0-arity binding (see ghc issue above). Eta-expanding it to incAll x = V.map (+1) x restores performance.
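A minimal sketch of that workaround, assuming the rest of the reproduction (the Unroll class and main) stays unchanged. The name incAllPointFree is hypothetical and only labels the original definition for comparison. Both definitions compute the same values; per the comment above, the reported difference is only in how GHC 9.2.2 optimises them.

```haskell
module Main (main) where

import qualified Data.Vector.Unboxed as V

-- Original point-free form: no arguments on the left-hand side, so this is
-- the 0-arity binding described in the comment above.
{-# INLINE incAllPointFree #-}
incAllPointFree :: V.Vector Int -> V.Vector Int
incAllPointFree = V.map (+ 1)

-- Eta-expanded form: the explicit argument x gives the binding arity 1.
{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll x = V.map (+ 1) x

main :: IO ()
main = do
  let array = V.replicate 1000 0 :: V.Vector Int
  -- Same result either way; only the optimisation behaviour is reported to differ.
  print (incAll array == incAllPointFree array)
```

The number of arguments written on the left-hand side matters because GHC only inlines an INLINE binding once it is applied to at least that many arguments, so the two forms can be inlined at different points in the pipeline.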