NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")
Other
271 stars 53 forks source link

Failed to create/execute the same FusionDefinition multiple times. #3389

Open wujingyue opened 1 week ago

wujingyue commented 1 week ago

This is an incidental bug that doesn't block me, but I found it interesting and confusing. Apparently, I can't construct and execute the same FusionDefinition twice. The failure pattern looks like this:

fd = Model()
fd.execute(inputs)
fd = Model()
fd.execute(inputs)

The second `execute` call fails with an `IndexError`, as if the fusion state had not been populated at all.

To reproduce, run `git checkout bug3389` or apply the following patch:

diff --git a/tests/python/test_multidevice.py b/tests/python/test_multidevice.py
index be6fdde3..b3c42361 100644
--- a/tests/python/test_multidevice.py
+++ b/tests/python/test_multidevice.py
@@ -244,6 +244,15 @@ def test_sdpa(mpi_test):
             .contiguous()[rank : rank + 1]
         )

+    fd = Model()
+    outs = fd.execute(
+        [
+            head_parallelize(q),
+            head_parallelize(k),
+            head_parallelize(v),
+            head_parallelize(out_grad),
+        ]
+    )
     fd = Model()
     outs = fd.execute(
         [

Then run:

_bn && mpirun -np 1 pytest tests/python/test_multidevice.py -k sdpa -s --only-mpi
tests/python/test_multidevice.py:257:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
nvfuser/__init__.py:161: in execute
    self.multidevice_schedule()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self =
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 2, 12, 1024, 64], contigu... T7, S4, S5, T8, T9, None)
    fd.add_output(T6)
    fd.add_output(T10)
    fd.add_output(T11)
    fd.add_output(T12)

    def multidevice_schedule(self) -> None:
        mesh = self.sched._create_device_mesh(range(d))
        for t in [self.q, self.k, self.v, self.out_grad]:
>           self.sched._set_device_mesh(t, mesh)
E           IndexError: vector::_M_range_check: __n (which is 0) >= this->size() (which is 0)

tests/python/test_multidevice.py:217: IndexError
IndexError: vector::_M_range_check: __n (which is 0) >= this->size() (which is 0)

Note the `-np 1`: although this test lives in `test_multidevice.py`, it requires only one GPU.