NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Allow inlining past loop broadcasts #3416

Open: jacobhinkle opened this pull request 6 days ago

jacobhinkle commented 6 days ago

Stacked on #3414

This PR lets us properly inline an MmaOp whose inputs are missing broadcast dimensions. We do this by always allowing inlining past loop broadcasts and past IDs derived from them by transforms. For example:

tv0:
  logical [ iS1{i0} ]
  loop [ iS1{i0} bS5{1} ]
tv1:
  logical [ iS2{i1} ]
  loop [ bS6{1} iS2{i1} ]
tv2 = foo(tv0, tv1)
  logical [ iS3{i0} iS4{i1} ]

As long as the operation foo maps its arguments properly despite the missing logical dimensions (as MmaOp does as of #3391), we should be able to fully inline this case. The loop broadcasts bS5 and bS6 are imaginary in the sense that they do not affect indexing.
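The inlining rule above can be illustrated with a small model. This is a minimal sketch under stated assumptions, not nvFuser code: `LoopId` and `maxInlinePosition` are hypothetical names, "mapped" is modeled as sharing a group number, and every other inlining constraint (reductions, unmappable dimensions, and so on) is ignored. A loop position is accepted if the producer and consumer IDs are mapped, or if either one is a loop broadcast.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified loop ID; NOT nvFuser's IterDomain.
struct LoopId {
  std::string name;        // e.g. "iS1", "bS5"
  int mapped_group;        // IDs in the same group are treated as mapped; -1 maps to nothing
  bool is_loop_broadcast;  // present in the loop domain only, not in the logical domain
};

// Count how many leading loop positions of `producer` can share loops with
// `consumer`. A position is accepted when the two IDs are mapped, or when
// either one is a loop broadcast, since loop broadcasts do not affect indexing.
// All other inlining constraints are ignored in this sketch.
std::size_t maxInlinePosition(const std::vector<LoopId>& producer,
                              const std::vector<LoopId>& consumer) {
  const std::size_t n = std::min(producer.size(), consumer.size());
  std::size_t pos = 0;
  for (; pos < n; ++pos) {
    const LoopId& p = producer[pos];
    const LoopId& c = consumer[pos];
    const bool mapped = p.mapped_group >= 0 && p.mapped_group == c.mapped_group;
    const bool skippable = p.is_loop_broadcast || c.is_loop_broadcast;
    if (!mapped && !skippable) {
      break;  // first position that can neither be mapped nor skipped
    }
  }
  return pos;
}
```

Modeling tv0 as `{iS1 (group 0), bS5 (loop broadcast)}`, tv1 as `{bS6 (loop broadcast), iS2 (group 1)}` and tv2 as `{iS3 (group 0), iS4 (group 1)}`, both producers reach position 2, i.e. full inlining into tv2.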

jacobhinkle commented 6 days ago

After this, we can actually generate a proper kernel and run it. I will rebase #3406 onto this PR and modify the test there to compile and run the kernel, so we can inspect the generated code in that PR. Let's keep this PR for discussing the inlining changes only.

naoyam commented 6 days ago

Does this only apply to broadcast IDs added by TensorView::broadcast()?

jacobhinkle commented 6 days ago

> Does this only apply to broadcast IDs added by TensorView::broadcast()?

Yes, that's the intention. I am using tv->domain()->additionalIDs(), which I think is only those broadcasts?

naoyam commented 6 days ago

> Yes, that's the intention. I am using tv->domain()->additionalIDs(), which I think is only those broadcasts?

Yes. @zasdfgbnm, when you added this, were you thinking about having non-broadcast IDs in additional_ids_?

jacobhinkle commented 6 days ago

> Yes. @zasdfgbnm, when you added this, were you thinking about having non-broadcast IDs in additional_ids_?

To be safe I'll check the IterType when skipping.
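The safety check described above can be sketched as a predicate. This is a hypothetical model, not nvFuser's API: `Id`, `IterType` and `isSkippableLoopBroadcast` are illustrative stand-ins, and the `additional_ids` vector plays the role of what `tv->domain()->additionalIDs()` returns. An ID is skipped only when it is an additional ID and its iter type is actually broadcast.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins; NOT nvFuser's IterType/IterDomain.
enum class IterType { Iteration, Reduction, Broadcast };

struct Id {
  std::string name;
  IterType type;
};

// Skip `id` during the mapping check only if it is one of the domain's
// additional IDs and it really is a broadcast, so a non-broadcast ID that
// ends up among the additional IDs is never skipped by accident.
bool isSkippableLoopBroadcast(const Id* id,
                              const std::vector<const Id*>& additional_ids) {
  const bool is_additional =
      std::find(additional_ids.begin(), additional_ids.end(), id) !=
      additional_ids.end();
  return is_additional && id->type == IterType::Broadcast;
}
```

Checking both conditions costs little and keeps the skip correct even if additional IDs are later allowed to hold non-broadcast IDs.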

zasdfgbnm commented 3 days ago

> Yes. @zasdfgbnm, when you added this, were you thinking about having non-broadcast IDs in additional_ids_?

No, I added it primarily for storing these new broadcasts.

jacobhinkle commented 3 days ago

In the latest pushed changes, I do a BFS from the producer logical domain to the producer allocation domain, and from the consumer root domain to the consumer loop domain. This lets me collect the IDs that are used for indexing (assuming no shorter paths are discovered later). I then restrict the strictAreMapped check to positions where at least one of the producer and consumer IDs is on that path. This automatically covers loop broadcasts, since they are not used for indexing. It also lets us inline around them when they appear in the same position as another ID that is not used for indexing that particular producer, which is the case for the mma use case I have in mind.
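The path-based restriction can be sketched as follows. This is a simplified stand-in for the approach described above, not the PR's implementation: `Graph`, `reachable`, `idsOnPath` and `canInlineAt` are hypothetical names, the transform graph is a plain adjacency map rather than nvFuser's ID graphs, and "on the path" is approximated as "reachable forward from the start IDs and backward from the target IDs" (so it does not model the shortest-path caveat). The result of strictAreMapped is passed in as a bool. We also assume the consumer-side search starts only from the consumer root IDs that correspond to the producer in question, which is our reading of "not used for indexing that particular producer".

```cpp
#include <cassert>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

// Hypothetical transform graph: an edge a -> b means b is produced from a by
// some transform (split, merge, ...). Not nvFuser's real graph API.
using Graph = std::map<std::string, std::vector<std::string>>;

// Breadth-first search returning every ID reachable from `starts` in `g`.
std::set<std::string> reachable(const Graph& g,
                                const std::vector<std::string>& starts) {
  std::set<std::string> seen(starts.begin(), starts.end());
  std::queue<std::string> frontier;
  for (const auto& s : starts) frontier.push(s);
  while (!frontier.empty()) {
    const std::string id = frontier.front();
    frontier.pop();
    auto it = g.find(id);
    if (it == g.end()) continue;
    for (const auto& next : it->second) {
      if (seen.insert(next).second) frontier.push(next);
    }
  }
  return seen;
}

// IDs lying on some path from `from` to `to`: reachable forward from `from`
// and backward (over reversed edges) from `to`.
std::set<std::string> idsOnPath(const Graph& g,
                                const std::vector<std::string>& from,
                                const std::vector<std::string>& to) {
  Graph reversed;
  for (const auto& kv : g)
    for (const auto& dst : kv.second) reversed[dst].push_back(kv.first);
  const std::set<std::string> fwd = reachable(g, from);
  const std::set<std::string> bwd = reachable(reversed, to);
  std::set<std::string> out;
  for (const auto& id : fwd)
    if (bwd.count(id) > 0) out.insert(id);
  return out;
}

// Enforce the strict mapping check only if at least one of the two IDs is
// used for indexing; otherwise (e.g. a loop broadcast opposite an ID that does
// not index this producer) the position is allowed.
bool canInlineAt(const std::string& producer_id, const std::string& consumer_id,
                 bool strictly_mapped,
                 const std::set<std::string>& producer_indexing_ids,
                 const std::set<std::string>& consumer_indexing_ids) {
  const bool used_for_indexing = producer_indexing_ids.count(producer_id) > 0 ||
                                 consumer_indexing_ids.count(consumer_id) > 0;
  return !used_for_indexing || strictly_mapped;
}
```

For tv0 in the PR description, with no transforms, the producer set is `{iS1}` and the consumer set for indexing tv0 is `{iS3}`. Position 1 pairs bS5 with iS4, neither of which is in those sets, so it is allowed even though the two IDs are not mapped.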