tweag / linear-base

Standard library for linear types in Haskell.

Optimizing linear arrays #328

Open utdemir opened 3 years ago

utdemir commented 3 years ago

I spent some time benchmarking our linear collections, and wanted to create an issue with my findings.

Reading from our linear arrays seems to be almost an order of magnitude slower than reading from a vector.

I benchmarked these two functions:

   linear :: Array.Linear.Array Int %1-> ()
   linear hm =
     hm
       Linear.& Array.Linear.size
       Linear.& \(Linear.Ur sz, arr) -> arr
       Linear.& go 0 sz
    where
     go :: Int -> Int -> Array.Linear.Array Int %1-> ()
     go start end arr
       | start < end =
           Array.Linear.unsafeGet start arr
             Linear.& \(Linear.Ur i, arr) -> i `Linear.seq` go (start + 1) end arr
       | otherwise = arr `Linear.lseq` ()

   dataVector :: Data.Vector.Vector Int -> ()
   dataVector v =
     let sz = Data.Vector.length v
      in go 0 sz
    where
     go :: Int -> Int -> ()
     go start end
       | start < end =
           (v Data.Vector.! start) `seq`  go (start + 1) end
       | otherwise = ()

And here's the result on my system (this also includes allocating the array/vector with an initial element, but I measured that allocation does not dominate the runtime):

benchmarked arrays/reads/Data.Array.Mutable.Linear
time                 27.26 ms   (26.39 ms .. 28.05 ms)
                     0.997 R²   (0.996 R² .. 0.999 R²)
mean                 29.13 ms   (28.44 ms .. 30.28 ms)
std dev              1.972 ms   (1.178 ms .. 2.606 ms)
variance introduced by outliers: 26% (moderately inflated)

benchmarked arrays/reads/Data.Vector
time                 3.841 ms   (3.796 ms .. 3.887 ms)
                     0.999 R²   (0.998 R² .. 0.999 R²)
mean                 3.971 ms   (3.943 ms .. 4.011 ms)
std dev              107.8 μs   (80.19 μs .. 173.3 μs)

When looking at the relevant -ddump-simpl output (with -O), it can be seen that the generated code of Data.Vector is a pretty efficient loop:

$wbReads1_raMu
  :: GHC.Prim.Int# -> GHC.Prim.Int# -> GHC.Prim.Array# Int -> ()
[GblId, Arity=3, Str=<L,U><L,U><L,U>, Unf=OtherCon []]
$wbReads1_raMu
  = \ (ww_sat7 :: GHC.Prim.Int#)
      (ww1_sat8 :: GHC.Prim.Int#)
      (ww2_sat9 :: GHC.Prim.Array# Int) ->
      join {
        exit_X0 [Dmd=<L,C(U)>] :: GHC.Prim.Int# -> ()
        [LclId[JoinId(1)], Arity=1, Str=<B,U>b]
        exit_X0 (ww3_sasX [OS=OneShot] :: GHC.Prim.Int#)
          = case lvl4_raM6 ww3_sasX ww1_sat8 of wild_00 { } } in
      join {
        exit1_X1 [Dmd=<L,C(U)>] :: GHC.Prim.Int# -> ()
        [LclId[JoinId(1)], Arity=1, Str=<B,U>b]
        exit1_X1 (ww3_sasX [OS=OneShot] :: GHC.Prim.Int#)
          = case lvl4_raM6 ww3_sasX ww1_sat8 of wild_00 { } } in
      joinrec {
        $wgo2_sat3 [InlPrag=NOUSERINLINE[2], Occ=LoopBreaker]
          :: GHC.Prim.Int# -> GHC.Prim.Int# -> ()
        [LclId[JoinId(2)], Arity=2, Str=<L,U><L,U>, Unf=OtherCon []]
        $wgo2_sat3 (ww3_sasX :: GHC.Prim.Int#) (ww4_sat1 :: GHC.Prim.Int#)
          = case GHC.Prim.<# ww3_sasX ww4_sat1 of {
              __DEFAULT -> GHC.Tuple.();
              1# ->
                case GHC.Prim.>=# ww3_sasX 0# of {
                  __DEFAULT -> jump exit_X0 ww3_sasX;
                  1# ->
                    case GHC.Prim.<# ww3_sasX ww1_sat8 of {
                      __DEFAULT -> jump exit1_X1 ww3_sasX;
                      1# ->
                        case GHC.Prim.indexArray#
                               @Int ww2_sat9 (GHC.Prim.+# ww_sat7 ww3_sasX)
                        of
                        { (# ipv_a97a #) ->
                        case ipv_a97a of { GHC.Types.I# ipv1_s8LP ->
                        jump $wgo2_sat3 (GHC.Prim.+# ww3_sasX 1#) ww4_sat1
                        }
                        }
                    }
                }
            }; } in
      jump $wgo2_sat3 0# ww1_sat8

However, the generated Core for our linear arrays is not as pretty:

Rec {
-- RHS size: {terms: 47, types: 183, coercions: 8, joins: 0/0}
$wgo1_raMr
  :: GHC.Prim.Int#
     -> GHC.Prim.Int#
     -> Data.Array.Mutable.Unlifted.Linear.Array# Int
     %1 -> ()
[GblId, Arity=3, Str=<L,U><L,U><L,U>, Unf=OtherCon []]
$wgo1_raMr
  = \ (ww_sasC :: GHC.Prim.Int#)
      (ww1_sasG :: GHC.Prim.Int#)
      (ww2_sasK :: Data.Array.Mutable.Unlifted.Linear.Array# Int) ->
      case GHC.Prim.<# ww_sasC ww1_sasG of {
        __DEFAULT ->
          case Unsafe.Coerce.unsafeEqualityProof
                 @(*)
                 @(GHC.Types.Any -> GHC.Types.Any)
                 @((Data.Array.Mutable.Unlifted.Linear.Array# Int -> () -> ())
                   %1 -> Data.Array.Mutable.Unlifted.Linear.Array# Int
                   %1 -> ()
                   %1 -> ())
          of
          { Unsafe.Coerce.UnsafeRefl co_a8Ic ->
          ((\ (x_a8Ib [OS=OneShot] :: GHC.Types.Any) -> x_a8Ib)
           `cast` (Sub (Sym co_a8Ic)
                   :: (GHC.Types.Any -> GHC.Types.Any)
                      ~R# ((Data.Array.Mutable.Unlifted.Linear.Array# Int -> () -> ())
                           %1 -> Data.Array.Mutable.Unlifted.Linear.Array# Int
                           %1 -> ()
                           %1 -> ())))
            (Data.Array.Mutable.Unlifted.Linear.lseq1 @Int @())
            ww2_sasK
            GHC.Tuple.()
          };
        1# ->
          case Unsafe.Coerce.unsafeEqualityProof
                 @(*)
                 @(GHC.Types.Any -> GHC.Types.Any)
                 @((Data.Array.Mutable.Unlifted.Linear.Array# Int
                    -> (# Linear.Ur Int,
                          Data.Array.Mutable.Unlifted.Linear.Array# Int #))
                   %1 -> Data.Array.Mutable.Unlifted.Linear.Array# Int
                   %1 -> (# Linear.Ur Int,
                            Data.Array.Mutable.Unlifted.Linear.Array# Int #))
          of
          { Unsafe.Coerce.UnsafeRefl co_a8py ->
          case ((\ (x_a8px [OS=OneShot] :: GHC.Types.Any) -> x_a8px)
                `cast` (Sub (Sym co_a8py)
                        :: (GHC.Types.Any -> GHC.Types.Any)
                           ~R# ((Data.Array.Mutable.Unlifted.Linear.Array# Int
                                 -> (# Linear.Ur Int,
                                       Data.Array.Mutable.Unlifted.Linear.Array# Int #))
                                %1 -> Data.Array.Mutable.Unlifted.Linear.Array# Int
                                %1 -> (# Linear.Ur Int,
                                         Data.Array.Mutable.Unlifted.Linear.Array# Int #))))
                 (\ (ds1_a8pC :: Data.Array.Mutable.Unlifted.Linear.Array# Int) ->
                    GHC.Magic.runRW#
                      @('GHC.Types.TupleRep
                          '[ 'GHC.Types.LiftedRep, 'GHC.Types.UnliftedRep])
                      @(# Linear.Ur Int, Data.Array.Mutable.Unlifted.Linear.Array# Int #)
                      (\ (s_a8pK [OS=OneShot] :: GHC.Prim.State# GHC.Prim.RealWorld) ->
                         case GHC.Prim.readArray#
                                @GHC.Prim.RealWorld
                                @Int
                                (ds1_a8pC
                                 `cast` (Data.Array.Mutable.Unlifted.Linear.N:Array#[0] <Int>_N
                                         :: Data.Array.Mutable.Unlifted.Linear.Array# Int
                                            ~R# GHC.Prim.MutableArray# GHC.Prim.RealWorld Int))
                                ww_sasC
                                s_a8pK
                         of
                         { (# ipv_a8pM, ipv1_a8pN #) ->
                         (# Data.Unrestricted.Internal.Ur.Ur @Int ipv1_a8pN, ds1_a8pC #)
                         }))
                 ww2_sasK
          of
          { (# ipv_a8pQ, ipv1_a8pR #) ->
          case ipv_a8pQ of { Linear.Ur i_a89p ->
          $wgo1_raMr (GHC.Prim.+# ww_sasC 1#) ww1_sasG ipv1_a8pR
          }
          }
          }
      }
end Rec }

I cannot read Core well, but I think I can see a few things:

I'll try to look into these further.

aspiwack commented 3 years ago

The coercions from non-linear functions to linear functions may be blocking optimisations as well. The calls to runRW# should be removed by the time we get to a low enough level, but in the meantime they may also hamper optimisation.

utdemir commented 3 years ago

> The coercions from non-linear functions to linear functions may be blocking optimisations as well.

I think this is the current major issue. As an example, this simple code:

get :: Int -> Array# a %1-> (# Ur a, Array# a #)
get (GHC.I# i) = Unsafe.toLinear go
  where
    go :: Array# a -> (# Ur a, Array# a #)
    go (Array# arr) =
      case GHC.runRW# (GHC.readArray# arr i) of
        (# _, ret #) -> (# Ur ret, Array# arr #)

Ends up becoming pretty big. As far as I can see, the call to Unsafe.toLinear ends up forcing its parameter to be allocated.

I also saw another similar instance caused by the use of Unsafe.toLinear in the linear seq, but worked around it by reimplementing seq without using coerce (PR #329).

utdemir commented 3 years ago

After #329 and #330, Data.Array.Mutable.Linear.get is only three times slower than its Data.Vector counterpart on GHC HEAD. I don't see a way to work around the unsafeCoerce's blocking optimisations. I am not sure, but https://gitlab.haskell.org/ghc/ghc/-/issues/19542 might improve this.

There still are some very slow functions in the Array API that I will investigate next. The first one is map, which is an order of magnitude slower than the vector counterpart. I think that is also because of the few coercions we have there.
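For context, a map benchmark in the same shape could look roughly like the sketch below. It assumes Data.Array.Mutable.Linear.map and reuses the Impls wrapper and qualified imports of the benchmark module quoted in full later in this thread; it is an illustration only, not the actual benchmark code.

-- A hedged sketch: map both arrays and merely consume the results, so the
-- measurement is dominated by the traversal and the allocation of the output.
bMap :: Impls
bMap = Impls linear dataVector
  where
    linear :: Array.Linear.Array Int %1 -> ()
    linear arr = Array.Linear.map (+ 1) arr `Linear.lseq` ()

    dataVector :: Data.Vector.Vector Int -> ()
    dataVector v = Data.Vector.map (+ 1) v `seq` ()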

aspiwack commented 3 years ago

Maybe you should commit your benchmarks in the repo in the meantime. They've already proved useful, so let's.

Divesh-Otwani commented 2 years ago

@utdemir Here are the benchmark results I get on my machine:

benchmarking arrays/toList/Data.Array.Mutable.Linear ... took 10.09 s, total 56 iterations
benchmarked arrays/toList/Data.Array.Mutable.Linear
time                 168.9 ms   (162.7 ms .. 175.4 ms)
                     0.998 R²   (0.994 R² .. 1.000 R²)
mean                 188.3 ms   (181.4 ms .. 198.8 ms)
std dev              14.52 ms   (7.924 ms .. 22.36 ms)
variance introduced by outliers: 19% (moderately inflated)

benchmarking arrays/toList/Data.Vector ... took 10.23 s, total 56 iterations
benchmarked arrays/toList/Data.Vector
time                 173.6 ms   (169.4 ms .. 176.2 ms)
                     0.999 R²   (0.998 R² .. 1.000 R²)
mean                 190.5 ms   (184.5 ms .. 202.4 ms)
std dev              13.83 ms   (7.228 ms .. 21.45 ms)
variance introduced by outliers: 19% (moderately inflated)

benchmarking arrays/map/Data.Array.Mutable.Linear ... took 45.81 s, total 56 iterations
benchmarked arrays/map/Data.Array.Mutable.Linear
time                 788.9 ms   (775.1 ms .. 803.4 ms)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 854.8 ms   (826.4 ms .. 919.7 ms)
std dev              71.13 ms   (25.08 ms .. 111.6 ms)
variance introduced by outliers: 28% (moderately inflated)

benchmarking arrays/map/Data.Vector ... took 45.97 s, total 56 iterations
benchmarked arrays/map/Data.Vector
time                 800.0 ms   (781.5 ms .. 812.7 ms)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 846.4 ms   (831.0 ms .. 874.0 ms)
std dev              34.90 ms   (17.92 ms .. 57.73 ms)

benchmarking arrays/reads/Data.Array.Mutable.Linear ... took 8.839 s, total 56 iterations
benchmarked arrays/reads/Data.Array.Mutable.Linear
time                 145.3 ms   (139.3 ms .. 149.5 ms)
                     0.998 R²   (0.997 R² .. 1.000 R²)
mean                 166.0 ms   (159.1 ms .. 177.7 ms)
std dev              15.09 ms   (8.732 ms .. 23.39 ms)
variance introduced by outliers: 28% (moderately inflated)

benchmarked arrays/reads/Data.Vector
time                 50.73 ms   (50.33 ms .. 51.21 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 58.73 ms   (55.92 ms .. 67.09 ms)
std dev              7.917 ms   (3.382 ms .. 13.57 ms)
variance introduced by outliers: 52% (severely inflated)

Benchmark mutable-data: FINISH

I think the remaining focus is get, as you've pointed out.

aspiwack commented 1 year ago

I wanted to have a look at this today, to understand better what's going on. Here are the array benchmarks today, on my machine (Ryzen something), with the recently released GHC 9.6.1:

  arrays
    toList
      Data.Array.Mutable.Linear:                                                 OK (3.78s)
        116  ms ± 5.1 ms
      Data.Vector:                                                               OK (0.92s)
        122  ms ± 1.7 ms
    map
      Data.Array.Mutable.Linear:                                                 OK (1.51s)
        488  ms ±  15 ms
      Data.Vector:                                                               OK (7.69s)
        489  ms ±  19 ms
    reads
      Data.Array.Mutable.Linear:                                                 OK (0.82s)
        112  ms ± 3.6 ms
      Data.Vector:                                                               OK (7.14s)
        26.5 ms ± 1.3 ms

We see that successive reads are many times slower in our implementation than with Vector. The rest is on par, as expected (map allocates a single (initially mutable) array in the Vector library too). Our times look a tiny bit better; it's possibly not a fluke. There are no tests for successive writes, which may be unfair to Vector if we find a pattern that inhibits fusion (but we may want to add one nevertheless, to make a point; a sketch of what it could look like follows).
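For what it's worth, a successive-writes test in the same Impls style could look roughly like the sketch below. It assumes Array.Linear.unsafeSet (the write analogue of unsafeGet) and, on the Vector side, Data.Vector.modify with additional qualified imports of Data.Vector.Mutable and Control.Monad.ST; this is only an illustration, not code that exists in the repository.

-- A hedged sketch of a successive-writes benchmark, mirroring bReads below.
bWrites :: Impls
bWrites = Impls linear dataVector
  where
    linear :: Array.Linear.Array Int %1 -> ()
    linear hm =
      hm
        Linear.& Array.Linear.size
        Linear.& \(Linear.Ur sz, arr) ->
          arr
            Linear.& go 0 sz
      where
        go :: Int -> Int -> Array.Linear.Array Int %1 -> ()
        go start end arr
          | start < end =
              -- write the index into each slot, in place
              Array.Linear.unsafeSet start start arr
                Linear.& go (start + 1) end
          | otherwise = arr `Linear.lseq` ()

    dataVector :: Data.Vector.Vector Int -> ()
    dataVector v = Data.Vector.modify fill v `seq` ()
      where
        fill :: Data.Vector.Mutable.MVector s Int -> Control.Monad.ST.ST s ()
        fill mv = go 0 (Data.Vector.Mutable.length mv)
          where
            go start end
              | start < end =
                  Data.Vector.Mutable.unsafeWrite mv start start >> go (start + 1) end
              | otherwise = pure ()

Whether Data.Vector.modify is the fairest comparison point is debatable, but it at least avoids the quadratic copying that persistent updates with (//) would incur.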

aspiwack commented 1 year ago

So, let's look at the benchmark for reads, which reads every element in an array.

The implementation of the benchmark is the following:

bReads :: Impls
bReads = Impls linear dataVector
  where
    linear :: Array.Linear.Array Int %1 -> ()
    linear hm =
      hm
        Linear.& Array.Linear.size
        Linear.& \(Linear.Ur sz, arr) ->
          arr
            Linear.& go 0 sz
      where
        go :: Int -> Int -> Array.Linear.Array Int %1 -> ()
        go start end arr
          | start < end =
              Array.Linear.unsafeGet start arr
                Linear.& \(Linear.Ur i, arr') -> i `Linear.seq` go (start + 1) end arr'
          | otherwise = arr `Linear.lseq` ()

    dataVector :: Data.Vector.Vector Int -> ()
    dataVector v =
      let sz = Data.Vector.length v
       in go 0 sz
      where
        go :: Int -> Int -> ()
        go start end
          | start < end =
              (v Data.Vector.! start) `seq` go (start + 1) end
          | otherwise = ()

Here is the optimised Core for the linear array version:

Rec {
$wgo1
  = \ ww ww1 ww2 ->
      case <# ww ww1 of {
        __DEFAULT ->
          case unsafeEqualityProof of { UnsafeRefl co ->
          case ((\ _ b1 -> b1) `cast` <Co:12> :: ...) ww2 () of { () ->
          (##)
          }
          };
        1# ->
          case $wget ww ww2 of { (# ipv, ipv1 #) ->
          case ipv of { Ur i ->
          case i of { I# ipv2 -> $wgo1 (+# ww 1#) ww1 ipv1 }
          }
          }
      }
end Rec }

$wbReads
  = \ ww ->
      case unsafeEqualityProof of { UnsafeRefl co ->
      case ((\ ds1 ->
               (# Ur (I# (sizeofMutableArray# (ds1 `cast` <Co:2> :: ...))),
                  ds1 #))
            `cast` <Co:12> :: ...)
             ww
      of
      { (# ipv, ipv1 #) ->
      case ipv of { Ur sz -> case sz of { I# ww1 -> $wgo1 0# ww1 ipv1 } }
      }
      }

bReads1
  = \ hm ->
      case hm of { Array ww -> case $wbReads ww of { (# #) -> () } }

Here is the optimised Core for the Vector version:

bReads2
  = \ v ->
      case v of { Vector ww ww1 ww2 ->
      join { exit ww3 = case lvl43 ww3 ww1 of wild1 { } } in
      joinrec {
        $wgo2 ww3 ww4
          = case <# ww3 ww4 of {
              __DEFAULT -> ();
              1# ->
                case ltWord# (int2Word# ww3) (int2Word# ww1) of {
                  __DEFAULT -> jump exit ww3;
                  1# ->
                    case indexArray# ww2 (+# ww ww3) of { (# ipv #) ->
                    case ipv of { I# ipv1 -> jump $wgo2 (+# ww3 1#) ww4 }
                    }
                }
            }; } in
      jump $wgo2 0# ww1
      }

My immediate reaction is: the vector version has go implemented as a recursive join point, and the linear array version doesn't. I'm actually not sure why. It certainly has a performance implication, but I doubt it accounts for the 3-4x factor that we're seeing.
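Side note for readers: a join point is Core's representation of a tail call that compiles to a plain jump with no closure allocation, and joinrec is its recursive form. A minimal, unrelated example of the kind of local loop that GHC typically turns into a joinrec, as in the Vector Core above:

-- GHC typically compiles this tail-recursive local loop to a Core `joinrec`
-- (a recursive join point): a plain jump with unboxed Int# arguments and no
-- closure allocation, which is the shape visible in bReads2 above.
sumTo :: Int -> Int
sumTo n = go 0 0
  where
    go acc i
      | i < n = go (acc + i) (i + 1)
      | otherwise = acc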

aspiwack commented 1 year ago

Besides the case on Ur, the main possible difference is the call to $wget. Here is the Core for it:

$wget
  = \ @a ww eta ->
      case unsafeEqualityProof of { UnsafeRefl co ->
      ((\ ds ->
          runRW#
            (\ s ->
               case readArray# (ds `cast` <Co:2> :: ...) ww s of
               { (# ipv, ipv1 #) ->
               (# Ur ipv1, ds #)
               }))
       `cast` <Co:12> :: ...)
        eta
      }

This function is originally Data.Array.Mutable.Unlifted.Linear.get (the intermediary calls get simplified away), whose source is:

get :: Int -> Array# a %1 -> (# Ur a, Array# a #)
get (GHC.I# i) = Unsafe.toLinear go
  where
    go :: Array# a -> (# Ur a, Array# a #)
    go (Array# arr) =
      case GHC.runRW# (GHC.readArray# arr i) of
        (# _, ret #) -> (# Ur ret, Array# arr #)
{-# NOINLINE get #-} -- prevents the runRW# effect from being reordered

aspiwack commented 1 year ago

Allowing get to inline doesn't affect timing significantly. The generated Core for the benchmark doesn't look better (I'm not even sure why $wget got inlined, to be honest):

Rec {
$wgo1
  = \ ww ww1 ww2 ->
      case <# ww ww1 of {
        __DEFAULT ->
          case unsafeEqualityProof of { UnsafeRefl co ->
          case ((\ _ b1 -> b1) `cast` <Co:12> :: ...) ww2 () of { () ->
          (##)
          }
          };
        1# ->
          case unsafeEqualityProof of { UnsafeRefl co ->
          case ((\ ds1 ->
                   runRW#
                     (\ s ->
                        case readArray# (ds1 `cast` <Co:2> :: ...) ww s of
                        { (# ipv, ipv1 #) ->
                        (# Ur ipv1, ds1 #)
                        }))
                `cast` <Co:12> :: ...)
                 ww2
          of
          { (# ipv, ipv1 #) ->
          case ipv of { Ur i ->
          case i of { I# ipv2 -> $wgo1 (+# ww 1#) ww1 ipv1 }
          }
          }
          }
      }
end Rec }

I'm not liking this cast though, having a look.

aspiwack commented 1 year ago

Honestly, it looks like a bug: this is a representational coercion between an unrestricted arrow and a linear arrow, which I don't think we've programmed (it probably doesn't pass the linter). Besides, why is there a representational coercion here? Where does it come from? Maybe the eta-expansion creates it?

Full Core:

-- RHS size: {terms: 20, types: 55, coercions: 14, joins: 0/0}
Data.Array.Mutable.Unlifted.Linear.$wget [InlPrag=NOINLINE]
  :: forall {a}. GHC.Int# -> Array# a %1 -> (# Ur a, Array# a #)
[GblId, Arity=2, Str=<L><L>, Unf=OtherCon []]
Data.Array.Mutable.Unlifted.Linear.$wget
  = \ (@a) (ww :: GHC.Int#) (eta :: Array# a) ->
      case Unsafe.Coerce.unsafeEqualityProof @GHC.Multiplicity @Many @One
      of
      { Unsafe.Coerce.UnsafeRefl co ->
      ((\ (ds [OS=OneShot] :: Array# a) ->
          GHC.runRW#
            @(GHC.TupleRep [GHC.LiftedRep, GHC.UnliftedRep])
            @(# Ur a, Array# a #)
            (\ (s [OS=OneShot] :: GHC.State# GHC.RealWorld) ->
               case GHC.readArray#
                      @GHC.Lifted
                      @GHC.RealWorld
                      @a
                      (ds
                       `cast` (Data.Array.Mutable.Unlifted.Linear.N:Array#[0] <a>_N
                               :: Array# a ~R# GHC.MutableArray# GHC.RealWorld a))
                      ww
                      s
               of
               { (# ipv, ipv1 #) ->
               (# Data.Unrestricted.Linear.Internal.Ur.Ur @a ipv1, ds #)
               }))
       `cast` (<Array# a>_R %(Sym co) ->_R <(# Ur a, Array# a #)>_R
               :: (Array# a -> (# Ur a, Array# a #))
                  ~R# (Array# a %1 -> (# Ur a, Array# a #))))
        eta
      }

aspiwack commented 1 year ago

Manually eta-expanding doesn't remove this dubious cast (the outer cast: I've just realised that there is an inner cast too, which is normal).

I'm noticing that calls to Unsafe.toLinear in that file seem to produce an awful lot of these casts. This is strange; here is the definition of Unsafe.toLinear:

-- | Converts an unrestricted function into a linear function
toLinear ::
  forall
    (r1 :: RuntimeRep)
    (r2 :: RuntimeRep)
    (a :: TYPE r1)
    (b :: TYPE r2)
    p
    x.
  (a %p -> b) %1 ->
  (a %x -> b)
toLinear f = case unsafeEqualityProof @p @x of
  UnsafeRefl -> f

Where does this representational cast come from?


Maybe part of the answer is that there is code in GHC blocking some optimisations around coercion pushing for linear types. But since the current policy is not to check linearity in Core beyond the desugarer, I should be able to remove that cast.

Everything else looks perfectly reasonable, so I'll postpone figuring out more until I have improved on this optimisation in GHC. I think that part of the answer is that "read" is so fast that the slightest extra work can be quite significant.

One last thought: it seems that allocating the array is part of the test; I assume that allocating the array takes longer than reading from it (also note that the array will always be in cache, since we read immediately after allocating, making reads even faster). That being said, we can see in the other tests that the allocation cost is probably comparable. It's definitely something that we can test, though.

aspiwack commented 1 year ago

I take it back: I somehow thought that the cast was outside of the unsafeCoerce. But it's inside the scope, and it's applying the coercion (co) from the unsafe coercion. It's perfectly normal and reasonable. It is likely to have a non-trivial cost, though. I'll see if I can deal with it, but it will be from within GHC anyway.

aspiwack commented 1 year ago

I have a Merge Request in flight for GHC (!10384). This Merge Request really improves the code generated for linear-base, though it doesn't seem to affect the benchmarks significantly.

When it lands, I want to play with :scream: inlining the alloc function for arrays (and maybe check if there are some uses of unsafePerformIO that could be unsafeDupablePerformIO as well). When that's done, I think we'll have done about the best we can for the quality of code generation, and I'll close this issue.
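For reference on that last point: unsafePerformIO is unsafeDupablePerformIO plus a noDuplicate# barrier that prevents two threads from running the same computation concurrently, and that barrier can be dropped for work that is harmless to duplicate, such as allocating a fresh array. A minimal sketch with a hypothetical helper (using Data.Primitive.Array for concreteness; this is not the linear-base internals):

import Control.Monad.Primitive (RealWorld)
import Data.Primitive.Array (MutableArray, newArray)
import System.IO.Unsafe (unsafeDupablePerformIO)

-- Hypothetical helper for illustration only: allocating a fresh array is
-- harmless to run twice, so it does not need the noDuplicate# barrier that
-- unsafePerformIO adds on top of unsafeDupablePerformIO.
allocDupable :: Int -> a -> MutableArray RealWorld a
allocDupable n x = unsafeDupablePerformIO (newArray n x)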