haskell / containers

Assorted concrete container types
https://hackage.haskell.org/package/containers

Rewrite rules for IntMap: intersectionWithKey/mapWithKey #1014

Open ruifengx opened 1 month ago

ruifengx commented 1 month ago

I checked the source code and did not see these rules defined, but I imagine they should be beneficial for performance. Below are the rewrite rules I mean:

{-# RULES
"intersectionWithKey/mapWithKey/1"
  forall f g m n.
    intersectionWithKey f (mapWithKey g m) n
      = intersectionWithKey (\i a b -> f i (g i a) b) m n
"intersectionWithKey/mapWithKey/2"
  forall f g m n.
    intersectionWithKey f m (mapWithKey g n)
      = intersectionWithKey (\i a b -> f i a (g i b)) m n
  #-}

In theory, they should eliminate one intermediate IntMap, but I do not know enough about the internals of this library or about GHC's optimisation passes to be sure.
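For readers who want to convince themselves that the two sides of each proposed rule agree, here is a minimal sketch that compares them on sample data. The functions `f` and `g` and the sample maps are made up for illustration. The check only covers the result, not whether GHC would actually fire such a rule. It uses the lazy `Data.IntMap` API; with `Data.IntMap.Strict` the two sides may differ in how much they force, since the left-hand side evaluates `g` on every key of the mapped argument and the right-hand side only on keys that survive the intersection.

```haskell
import qualified Data.IntMap as M

-- Hypothetical combining functions, chosen only so results are easy to compare.
f :: Int -> Int -> Int -> Int
f i a b = i + a * b

g :: Int -> Int -> Int
g i a = a - i

-- lhs1/lhs2 build an intermediate map with mapWithKey;
-- rhs1/rhs2 push g into the combining function instead.
lhs1, rhs1, lhs2, rhs2 :: M.IntMap Int -> M.IntMap Int -> M.IntMap Int
lhs1 m n = M.intersectionWithKey f (M.mapWithKey g m) n
rhs1 m n = M.intersectionWithKey (\i a b -> f i (g i a) b) m n
lhs2 m n = M.intersectionWithKey f m (M.mapWithKey g n)
rhs2 m n = M.intersectionWithKey (\i a b -> f i a (g i b)) m n

sampleM, sampleN :: M.IntMap Int
sampleM = M.fromList [(1, 10), (2, 20), (3, 30)]
sampleN = M.fromList [(2, 5), (3, 7), (4, 9)]

-- True when both rules preserve the result on the sample maps.
rulesAgree :: Bool
rulesAgree =
  lhs1 sampleM sampleN == rhs1 sampleM sampleN
    && lhs2 sampleM sampleN == rhs2 sampleM sampleN
```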

meooow25 commented 1 month ago

They would be beneficial, but a problem with such rules is the sheer number of them required to cover all relevant combinations of functions.

ruifengx commented 1 month ago

I see. I opened this issue because this pattern appears as part of my Generic-based DerivingVia wrapper, and I heard it is generally bad practice to define orphan rewrite rules. Since map is defined in terms of mapWith and in turn mapWithKey, and similarly intersection, intersectionWith, and intersectionWithKey, I'd expect only these two rules to be necessary. If the same rules apply to union and difference, there are still only six rules, if I am not mistaken.

treeowl commented 1 month ago

We have many similar rules for some other structures in containers. Going for consistency makes sense, I think.

treeowl commented 1 month ago

If we add rules, it's important to add the appropriate inlining guidance and to test and benchmark carefully. By the way, can you show us your generic framework? There might be a way to change it to accomplish your performance goals more reliably than with rules.
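To illustrate what "inlining guidance" can look like, here is a rough sketch using local wrappers rather than the library functions. The names `mapWK` and `interWK` are hypothetical and not part of containers. The `NOINLINE [1]` pragmas keep the wrappers from being inlined before phase 1, so the rules have a chance to match in the earlier simplifier phases. This is an assumption about one workable setup, not a statement of how containers would do it.

```haskell
import qualified Data.IntMap as M

-- Hypothetical wrapper around mapWithKey; not inlined before phase 1.
mapWK :: (Int -> a -> b) -> M.IntMap a -> M.IntMap b
mapWK f m = M.mapWithKey f m
{-# NOINLINE [1] mapWK #-}

-- Hypothetical wrapper around intersectionWithKey; not inlined before phase 1.
interWK :: (Int -> a -> b -> c) -> M.IntMap a -> M.IntMap b -> M.IntMap c
interWK f m n = M.intersectionWithKey f m n
{-# NOINLINE [1] interWK #-}

{-# RULES
"interWK/mapWK/left" forall f g m n.
    interWK f (mapWK g m) n = interWK (\i a b -> f i (g i a) b) m n
"interWK/mapWK/right" forall f g m n.
    interWK f m (mapWK g n) = interWK (\i a b -> f i a (g i b)) m n
  #-}

-- The result must be the same whether or not a rule fires.
wrappersAgree :: Bool
wrappersAgree =
  interWK (\i a b -> i + a + b) (mapWK (\i a -> a * i) m) n
    == M.intersectionWithKey (\i a b -> i + a + b) (M.mapWithKey (\i a -> a * i) m) n
  where
    m = M.fromList [(1, 2), (2, 3), (5, 8)] :: M.IntMap Int
    n = M.fromList [(2, 10), (5, 20), (7, 30)] :: M.IntMap Int
```

Rules are only applied with optimisation enabled. Compiling with `-O -ddump-rule-firings` shows whether they fire at a given call site, which is the kind of testing the comment above asks for.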

ruifengx commented 1 month ago

> By the way, can you show us your generic framework? There might be a way to change it to accomplish your performance goals more reliably than with rules.

Of course. It is just a personal project in early stages, but below is the relevant code. The idea is that each QueryItem of a gives an IntMap a, and we want to derive a query for aggregates by intersection of all the query result IntMaps.

Let's start with Query and QueryItem, my interface classes.

class Query a where
  query :: (a -> IO r) -> IO r

class QueryItem q where
  queryItem :: IO (QueryResult q)

Because it is based on intersections, the neutral element is the full set. I defined the following type with a distinguished constructor All for the full set as the neutral element (it should be significantly faster than actually generating the full set and then performing an intersection). It is an Applicative functor.

data QueryResult a
  = All (Int -> a)
  | Some (IntMap a)

instance Functor QueryResult where
  fmap f (All k)  = All (f . k)
  fmap f (Some m) = Some (M.map f m)

instance Applicative QueryResult where
  pure = All . const
  liftA2 (#) (All f)  (All g)  = All (liftA2 (#) f g)
  liftA2 (#) (All f)  (Some q) = Some (M.mapWithKey (\n b -> f n # b) q)
  liftA2 (#) (Some p) (All g)  = Some (M.mapWithKey (\n a -> a # g n) p)
  liftA2 (#) (Some p) (Some q) = Some (M.intersectionWith (#) p q)
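As a quick illustration of how All behaves as the neutral element, here is a small sketch that repeats the definitions above (with the operator argument renamed to `op`) and combines a few values. The helper `someKeys` and the sample data are hypothetical and exist only to make the results inspectable.

```haskell
import Control.Applicative (liftA2)
import Data.IntMap (IntMap)
import qualified Data.IntMap as M

data QueryResult a
  = All (Int -> a)
  | Some (IntMap a)

instance Functor QueryResult where
  fmap f (All k)  = All (f . k)
  fmap f (Some m) = Some (M.map f m)

instance Applicative QueryResult where
  pure = All . const
  liftA2 op (All f)  (All g)  = All (liftA2 op f g)
  liftA2 op (All f)  (Some q) = Some (M.mapWithKey (\n b -> f n `op` b) q)
  liftA2 op (Some p) (All g)  = Some (M.mapWithKey (\n a -> a `op` g n) p)
  liftA2 op (Some p) (Some q) = Some (M.intersectionWith op p q)

-- Hypothetical helper for inspection: Nothing stands for "every key".
someKeys :: QueryResult a -> Maybe [(Int, a)]
someKeys (All _)  = Nothing
someKeys (Some m) = Just (M.toList m)

names :: QueryResult String
names = Some (M.fromList [(1, "ann"), (2, "bob")])

ages :: QueryResult Int
ages = Some (M.fromList [(2, 41), (3, 27)])

-- Combining with All keeps the key set of the other side unchanged.
withUnit :: Maybe [(Int, ((), String))]
withUnit = someKeys (liftA2 (,) (pure ()) names)

-- Combining two Some values keeps only the shared keys.
both :: Maybe [(Int, (String, Int))]
both = someKeys (liftA2 (,) names ages)
```

Here `withUnit` keeps keys 1 and 2, while `both` keeps only key 2, the one key present in both maps.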

Below is the DerivingVia wrapper type. I only want to support product types.

newtype Bundle a = Bundle (QueryResult a)

My original naïve version involves a lot of fmap @QueryResult (and therefore IntMap.map):

instance (Generic a, GQuery (Rep a)) => Query (Bundle a) where
  query k = gQuery >>= k . Bundle . fmap to

class GQuery f where
  gQuery :: IO (QueryResult (f p))

instance GQuery U1 where
  gQuery = pure U1

instance QueryItem a => GQuery (K1 i a) where
  gQuery = fmap K1 <$> queryItem

instance (GQuery f, GQuery g) => GQuery (f :*: g) where
  gQuery = liftA2 (:*:) <$> gQuery @f <*> gQuery @g

instance GQuery f => GQuery (M1 i c f) where
  gQuery = fmap M1 <$> gQuery @f

Notice the fmap to in query: it ends up isolated inside an IntMap.map, so I expect the whole Rep a structure to survive inlining and remain in the final program.

To optimise these instances, I defined a helper class Ctor and the following instances. I think it might be overcomplicated for the purpose, and it might not work as I expected, but if my reasoning is correct, the following definition flattens the binary tree of (:*:) into a flat list, and performs a left fold, starting with the neutral element All (const ctor) (like the pure f in pure f <*> mx_1 <*> ... <*> mx_n).

instance (Ctor a, GQuery (Rep a)) => Query (Bundle a) where
  query k = gQuery @(Rep a) (pure (ctor @a)) >>= getQueryResult >>= k . Bundle

class Ctor a where
  ctor :: PCtor (Rep a) a

-- The `ctor` function in this instance contains both the construction of `Rep a p`
-- and the conversion from `Rep a p` to `a` using `GHC.Generics.to`, so I expect
-- it to be eventually optimised down to the actual data constructor.
instance (Generic a, GCtor (Rep a)) => Ctor a where
  ctor = gCtor (to @a)

-- PCtor: constructor with a polymorphic result type `r`
type family PCtor (f :: Type -> Type) (r :: Type) :: Type where
  -- the `K1` and `M1` constructors are eliminated
  -- and the `:*:` binary constructor is flattened into a list
  PCtor U1 r = r
  PCtor (K1 i a) r = a -> r
  PCtor (f :*: g) r = PCtor f (PCtor g r)
  PCtor (M1 i c f) r = PCtor f r
  -- e.g., for pair `(a, b)`, `f = K1 i a :*: K1 i b` (ignoring all the `M1`)
  --   and `PCtor f r = a -> b -> r`
  -- if we substitute `r` with the actual type `(a, b)`
  -- we get the type signature `(,) :: a -> b -> (a, b)`
  -- and the function `ctor` should be optimised to `ctor = (,)`
  -- which is the goal of all this type-level programming

class GCtor f where
  gCtor :: (f p -> r) -> PCtor f r

instance GCtor U1 where
  gCtor k = k U1

instance GCtor (K1 i a) where
  gCtor k a = k (K1 a)

instance (GCtor f, GCtor g) => GCtor (f :*: g) where
  gCtor k = gCtor \f -> gCtor \g -> k (f :*: g)

instance GCtor f => GCtor (M1 i c f) where
  gCtor k = gCtor (k . M1)

class GQuery f where
  -- we take a partial result containing the constructor function
  -- and we apply that function to the `QueryResult` for `f` using `<*>`
  gQuery :: QueryResult (PCtor f r) -> IO (QueryResult r)

instance GQuery U1 where
  gQuery = pure

instance QueryItem a => GQuery (K1 i a) where
  -- for each entry we combine it with the previous results using (<*>) and therefore `IntMap.intersectionWith`
  gQuery m = (<*>) m <$> queryItem

instance (GQuery f, GQuery g) => GQuery (f :*: g) where
  -- this is the flattening from binary tree to linear fold
  gQuery = gQuery @f >=> gQuery @g

instance GQuery f => GQuery (M1 i c f) where
  gQuery = gQuery @f
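To make the PCtor/GCtor part concrete, here is a standalone sketch showing that the flattened constructor really has the curried type of the data constructor. It reuses the type family and the GCtor instances from above, but writes `ctor` as a plain function instead of a class method and uses ordinary lambdas instead of BlockArguments. The `Person` type, the `mkPair`/`mkPerson` names, and the exact list of language extensions are assumptions made for this example.

```haskell
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

import Data.Kind (Type)
import GHC.Generics

-- Curried constructor type for a generic representation f with result r.
type family PCtor (f :: Type -> Type) (r :: Type) :: Type where
  PCtor U1 r = r
  PCtor (K1 i a) r = a -> r
  PCtor (f :*: g) r = PCtor f (PCtor g r)
  PCtor (M1 i c f) r = PCtor f r

class GCtor (f :: Type -> Type) where
  gCtor :: (f p -> r) -> PCtor f r

instance GCtor U1 where
  gCtor k = k U1

instance GCtor (K1 i a) where
  gCtor k = \a -> k (K1 a)

instance (GCtor f, GCtor g) => GCtor (f :*: g) where
  gCtor k = gCtor (\x -> gCtor (\y -> k (x :*: y)))

instance GCtor f => GCtor (M1 i c f) where
  gCtor k = gCtor (k . M1)

-- Build the representation field by field, then convert with GHC.Generics.to.
ctor :: forall a. (Generic a, GCtor (Rep a)) => PCtor (Rep a) a
ctor = gCtor @(Rep a) (to @a)

-- Hypothetical record used only for this example.
data Person = Person String Int Bool
  deriving (Eq, Show, Generic)

mkPair :: Int -> Bool -> (Int, Bool)
mkPair = ctor @(Int, Bool)

mkPerson :: String -> Int -> Bool -> Person
mkPerson = ctor @Person

-- The generic constructors agree with the real ones on sample arguments.
ctorChecks :: Bool
ctorChecks =
  mkPair 1 True == (1, True)
    && mkPerson "ada" 36 False == Person "ada" 36 False
```

The type signatures on `mkPair` and `mkPerson` are the interesting part: they only type-check if PCtor reduces to the curried constructor type. Whether GHC then optimises `ctor` down to the bare data constructor is a separate question that this sketch does not answer.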

Now with the "optimised" version with type-level programming, we get the following structure for each query:

pure ctor <*> queryItem <*> ... <*> queryItem

Which is essentially

IntMap.intersectionWith (...) (IntMap.map ctor queryItem) queryItem ...

and we see there is always an intersectionWith/map pattern when the Bundle contains at least two QueryItems.
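The claim can be checked on a small case. The sketch below repeats the non-Coyoneda QueryResult definitions (with the operator argument renamed to `op`) and compares the applicative chain for a pair against the hand-written intersectionWith-over-mapWithKey form and against a single fused intersectionWithKey. The sample maps `xs` and `ys` are made up, and the IO layer is left out.

```haskell
import Control.Applicative (liftA2)
import Data.IntMap (IntMap)
import qualified Data.IntMap as M

data QueryResult a
  = All (Int -> a)
  | Some (IntMap a)

instance Functor QueryResult where
  fmap f (All k)  = All (f . k)
  fmap f (Some m) = Some (M.map f m)

instance Applicative QueryResult where
  pure = All . const
  liftA2 op (All f)  (All g)  = All (liftA2 op f g)
  liftA2 op (All f)  (Some q) = Some (M.mapWithKey (\n b -> f n `op` b) q)
  liftA2 op (Some p) (All g)  = Some (M.mapWithKey (\n a -> a `op` g n) p)
  liftA2 op (Some p) (Some q) = Some (M.intersectionWith op p q)

xs :: IntMap Int
xs = M.fromList [(1, 10), (2, 20), (4, 40)]

ys :: IntMap Bool
ys = M.fromList [(2, True), (3, False), (4, False)]

-- What the derived query computes for a pair, minus the IO layer.
viaApplicative :: Maybe (IntMap (Int, Bool))
viaApplicative = case pure (,) <*> Some xs <*> Some ys of
  Some m -> Just m
  All _  -> Nothing

-- The same computation written out: mapWithKey builds an intermediate map
-- of partially applied constructors before the intersection.
unfused :: IntMap (Int, Bool)
unfused = M.intersectionWith ($) (M.mapWithKey (\_ x -> (,) x) xs) ys

-- What the proposed rule would rewrite it to: a single pass.
fused :: IntMap (Int, Bool)
fused = M.intersectionWithKey (\_ x y -> (x, y)) xs ys

chainAgrees :: Bool
chainAgrees = viaApplicative == Just unfused && unfused == fused
```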

meooow25 commented 1 month ago

Thanks for the explanation! I'm not too familiar with GHC Generics, but as I understand it the problem is that the final form looks like this

pure ctor <*> queryItem <*> ... <*> queryItem

where each queryItem is an All or a Some (in IO, but let's ignore that).
Combined with <*> this translates to a bunch of intersectionWith and mapWithKeys.


It is possible to explicitly fuse these functions using something like Yoneda/Coyoneda, without relying on rules.

The Coyoneda version looks like

data QueryResult a
  = All (Int -> a)
  | forall b. Some (Int -> b -> a) (IntMap b)

mkSome :: IntMap a -> QueryResult a
mkSome = Some (\_ -> id)

instance Functor QueryResult where
  fmap f (All k) = All (f . k)
  fmap f (Some k m) = Some (\n -> f . k n) m

instance Applicative QueryResult where
  pure = All . const
  liftA2 (#) (All f) (All g) = All (liftA2 (#) f g)
  liftA2 (#) (All f) (Some g q) = Some (\n b -> f n # g n b) q
  liftA2 (#) (Some f p) (All g) = Some (\n a -> f n a # g n) p
  liftA2 (#) (Some f p) (Some g q) = mkSome (M.intersectionWithKey (\n a b -> f n a # g n b) p q)
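With this encoding the pending function is only applied when a result is finally consumed. Here is a self-contained sketch of that last step, assuming a hypothetical helper `lowerSome` that is not part of the code above. It repeats the Coyoneda-style definitions (with the operator argument renamed to `op`) and runs a three-way combination on made-up data.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

import Control.Applicative (liftA2)
import Data.IntMap (IntMap)
import qualified Data.IntMap as M

data QueryResult a
  = All (Int -> a)
  | forall b. Some (Int -> b -> a) (IntMap b)

mkSome :: IntMap a -> QueryResult a
mkSome = Some (\_ -> id)

instance Functor QueryResult where
  fmap f (All k)    = All (f . k)
  fmap f (Some k m) = Some (\n -> f . k n) m

instance Applicative QueryResult where
  pure = All . const
  liftA2 op (All f)    (All g)    = All (liftA2 op f g)
  liftA2 op (All f)    (Some g q) = Some (\n b -> f n `op` g n b) q
  liftA2 op (Some f p) (All g)    = Some (\n a -> f n a `op` g n) p
  liftA2 op (Some f p) (Some g q) =
    mkSome (M.intersectionWithKey (\n a b -> f n a `op` g n b) p q)

-- Hypothetical helper: apply the deferred function in one mapWithKey pass.
-- Nothing stands for "every key".
lowerSome :: QueryResult a -> Maybe (IntMap a)
lowerSome (All _)    = Nothing
lowerSome (Some k m) = Just (M.mapWithKey k m)

-- fmap and the All case only compose functions; the single
-- intersectionWithKey is the only place a new map is built before lowering.
example :: Maybe (IntMap (Int, Char, Bool))
example =
  lowerSome ((,,) <$> mkSome xs <*> pure 'x' <*> mkSome ys)
  where
    xs = M.fromList [(1, 10), (2, 20)]
    ys = M.fromList [(2, True), (3, False)]
```

One trade-off of this sketch: `lowerSome` still performs a final `mapWithKey`, even when the pending function is the identity installed by `mkSome`.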

Does this work out for you?

ruifengx commented 1 month ago

> It is possible to explicitly fuse these functions using something like Yoneda/Coyoneda, without relying on rules.

Thanks a lot! I was actually trying the Yoneda embedding and it worked pretty well for my specific use case described here. I realised later that I also need to support other merge strategies and the Yoneda embedding became a bit more involved (so I did not feel entirely satisfied about my current implementation and hesitated to report the progress here), but I believe I should be able to work out the details as I spend some more time pondering on the problem.

I also think it is better not to rely on rewrite rules because I always cannot wrap my mind around the delicate interactions between RULES, INLINE [n], and NOINLINE [n]. I will leave the decision on these rewrite rules to you experts, and I will keep this issue open for now. Feel free to close it if you decide these rules are not worth the hassle.

meooow25 commented 1 month ago

> Thanks a lot! I was actually trying the Yoneda embedding and it worked pretty well for my specific use case described here. I realised later that I also need to support other merge strategies and the Yoneda embedding became a bit more involved (so I did not feel entirely satisfied about my current implementation and hesitated to report the progress here), but I believe I should be able to work out the details as I spend some more time pondering on the problem.

Nice, I hope it works in the end.

> I also think it is better not to rely on rewrite rules because I always cannot wrap my mind around the delicate interactions between RULES, INLINE [n], and NOINLINE [n]. I will leave the decision on these rewrite rules to you experts, and I will keep this issue open for now. Feel free to close it if you decide these rules are not worth the hassle.

I am with you on this; it would get troublesome to test and maintain. We can leave this open in case other folks have opinions on this.