NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

crash in transformOutputFromAllocationToLogical: frontier.size() != logical.size() #2760

Closed tfogal closed 1 month ago

tfogal commented 1 month ago

This error comes up in the proxy model. Note that #2759 must be applied first; otherwise #2685 might mask this error.

Encountered this on a Grace Hopper node:

Traceback (most recent call last):
  File "/home/tfogal/dev/proxy/transform-logical/xform-logical.py", line 547, in <module>
    fd.execute(inputs)
  File "/home/tfogal/env/lib/python3.10/site-packages/nvfuser/__init__.py", line 215, in execute
    result = self._execute(
RuntimeError: frontier.size() == logical.size() INTERNAL ASSERT FAILED at "/tmp/pip-req-build-uqgrzxrb/csrc/executor.cpp":901, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. 
Exception raised from transformOutputFromAllocationToLogical at /tmp/pip-req-build-uqgrzxrb/csrc/executor.cpp:901 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, char const*) + 0xe8 (0xf28760a68334 in /home/tfogal/env/lib/python3.10/site-packages/nvfuser/_C.cpython-310-aarch64-linux-gnu.so)
frame #1: <unknown function> + 0x3ce644 (0xf28760d5e644 in /home/tfogal/env/lib/python3.10/site-packages/nvfuser/_C.cpython-310-aarch64-linux-gnu.so)
frame #2: <unknown function> + 0x3d092c (0xf28760d6092c in /home/tfogal/env/lib/python3.10/site-packages/nvfuser/_C.cpython-310-aarch64-linux-gnu.so)
frame #3: nvfuser::FusionKernelRuntime::getMaybeHeuristicsFor(nvfuser::KernelArgumentHolder const&, std::optional<nvfuser::PrimDataType>) + 0x514 (0xf28760f5c2c4 in /home/tfogal/env/lib/python3.10/site-packages/nvfuser/_C.cpython-310-aarch64-linux-gnu.so)
frame #4: <unknown function> + 0x5d354c (0xf28760f6354c in /home/tfogal/env/lib/python3.10/site-packages/nvfuser/_C.cpython-310-aarch64-linux-gnu.so)
frame #5: <unknown function> + 0x5d45b4 (0xf28760f645b4 in /home/tfogal/env/lib/python3.10/site-packages/nvfuser/_C.cpython-310-aarch64-linux-gnu.so)
frame #6: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0xb8 (0xf28760f67488 in /home/tfogal/env/lib/python3.10/site-packages/nvfuser/_C.cpython-310-aarch64-linux-gnu.so)
frame #7: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool) const + 0x304 (0xf28761144ee4 in /home/tfogal/env/lib/python3.10/site-packages/nvfuser/_C.cpython-310-aarch64-linux-gnu.so)
frame #8: <unknown function> + 0xec3bc (0xf28760a7c3bc in /home/tfogal/env/lib/python3.10/site-packages/nvfuser/_C.cpython-310-aarch64-linux-gnu.so)
frame #9: <unknown function> + 0x163570 (0xf28760af3570 in /home/tfogal/env/lib/python3.10/site-packages/nvfuser/_C.cpython-310-aarch64-linux-gnu.so)
frame #10: <unknown function> + 0x1fcac0 (0xf28760b8cac0 in /home/tfogal/env/lib/python3.10/site-packages/nvfuser/_C.cpython-310-aarch64-linux-gnu.so)
<omitting python frames>
frame #26: <unknown function> + 0x273fc (0xf28a1e6d73fc in /usr/lib/aarch64-linux-gnu/libc.so.6)
frame #27: __libc_start_main + 0x98 (0xf28a1e6d74cc in /usr/lib/aarch64-linux-gnu/libc.so.6)
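
For readers unfamiliar with the `stride_order` argument used throughout the reproducer: most inputs are declared with `stride_order=[3, 2, 1, 0]`, while a few (T5, T18, T31, T33) use `stride_order=[2, 3, 1, 0]`, so their memory layout differs from their logical dimension order. The sketch below is plain Python, not nvFuser code. It assumes that `stride_order[i]` is the rank of dimension `i`'s stride, with rank 0 being the fastest-varying dimension; that reading is an assumption, as is the helper name `strides_from_order`. The shape is illustrative only (the reproducer declares symbolic `-1` extents), and nothing here is claimed to be the cause of the assert.

```python
def strides_from_order(shape, stride_order):
    """Return element strides for `shape` given a per-dimension stride rank.

    Assumption (not confirmed by the issue): stride_order[i] is the rank of
    dimension i's stride, where rank 0 is the fastest-varying dimension.
    """
    # stride_order must be a permutation of 0..ndim-1.
    assert sorted(stride_order) == list(range(len(shape)))
    strides = [0] * len(shape)
    running = 1
    # Visit dimensions from fastest-varying (rank 0) to slowest.
    for dim in sorted(range(len(shape)), key=lambda d: stride_order[d]):
        strides[dim] = running
        running *= shape[dim]
    return strides


# Illustrative shape only; it matches one of the broadcast shapes in the script.
shape = [22, 96, 120, 160]
print(strides_from_order(shape, [3, 2, 1, 0]))  # [1843200, 19200, 160, 1]
print(strides_from_order(shape, [2, 3, 1, 0]))  # [19200, 422400, 160, 1]
```

Under that assumption, the second layout places dimension 1 outermost in memory, so the allocation order of such a tensor is a permutation of its logical order.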

The reproducer is:

# torch version: 2.5.0a0+gita94e507
# nvfuser version: 0.2.8+git187b4ec
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id45(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T1 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T4 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T5 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[2, 3, 1, 0])
    T6 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T7 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T8 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T9 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T10 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[None, True, None, None], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T11 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T12 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[None, True, None, None], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T13 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T14 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T15 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T16 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T17 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T18 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[2, 3, 1, 0])
    T19 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T20 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T21 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T22 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T23 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[None, True, None, None], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T24 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T25 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[None, True, None, None], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T26 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T27 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T28 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T29 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T30 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 2, 1, 0])
    T31 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[2, 3, 1, 0])
    T32 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[None, True, None, None], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T33 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[2, 3, 1, 0])
    T34 = fd.ops.cast(T0, dtype=DataType.Float)
    S35 = fd.define_scalar(1, dtype=DataType.Int)
    S36 = fd.define_scalar(96, dtype=DataType.Int)
    S37 = fd.define_scalar(1, dtype=DataType.Int)
    S38 = fd.define_scalar(1, dtype=DataType.Int)
    V39 = fd.define_vector([S35, S36, S37, S38], dtype=DataType.Int)
    T40 = fd.ops.reshape(T1, new_shape=V39)
    S41 = fd.define_scalar(1, dtype=DataType.Int)
    S42 = fd.define_scalar(96, dtype=DataType.Int)
    S43 = fd.define_scalar(1, dtype=DataType.Int)
    S44 = fd.define_scalar(1, dtype=DataType.Int)
    V45 = fd.define_vector([S41, S42, S43, S44], dtype=DataType.Int)
    T46 = fd.ops.reshape(T2, new_shape=V45)
    S47 = fd.define_scalar(22, dtype=DataType.Int)
    S48 = fd.define_scalar(96, dtype=DataType.Int)
    S49 = fd.define_scalar(120, dtype=DataType.Int)
    S50 = fd.define_scalar(160, dtype=DataType.Int)
    V51 = fd.define_vector([S47, S48, S49, S50], dtype=DataType.Int)
    T52 = fd.ops.broadcast_in_dim(T46, shape=V51, broadcast_dims=[0, 1, 2, 3])
    T53 = fd.ops.sub(T34, T52)
    S54 = fd.define_scalar(22, dtype=DataType.Int)
    S55 = fd.define_scalar(96, dtype=DataType.Int)
    S56 = fd.define_scalar(120, dtype=DataType.Int)
    S57 = fd.define_scalar(160, dtype=DataType.Int)
    V58 = fd.define_vector([S54, S55, S56, S57], dtype=DataType.Int)
    T59 = fd.ops.broadcast_in_dim(T40, shape=V58, broadcast_dims=[0, 1, 2, 3])
    T60 = fd.ops.mul(T53, T59)
    T61 = fd.ops.ne(T3, T3)
    S62 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T63 = fd.ops.gt(T3, S62)
    S64 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T65 = fd.ops.where(T63, T3, S64)
    T66 = fd.ops.where(T61, T3, T65)
    T67 = fd.ops.ne(T66, T66)
    S68 = fd.define_scalar(6.00000, dtype=DataType.Double)
    T69 = fd.ops.lt(T66, S68)
    S70 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T71 = fd.ops.pad(T4, [-1, -1, -1, -1, 0, 0, 0, 0], S70)
    T72 = fd.ops.cast(T5, dtype=DataType.Float)
    S73 = fd.define_scalar(22, dtype=DataType.Int)
    S74 = fd.define_scalar(288, dtype=DataType.Int)
    S75 = fd.define_scalar(120, dtype=DataType.Int)
    S76 = fd.define_scalar(2, dtype=DataType.Int)
    S77 = fd.define_scalar(160, dtype=DataType.Int)
    S78 = fd.define_scalar(2, dtype=DataType.Int)
    V79 = fd.define_vector([S73, S74, S75, S76, S77, S78], dtype=DataType.Int)
    T80 = fd.ops.reshape(T71, new_shape=V79)
    T81 = fd.ops.cast(T80, dtype=DataType.Float)
    T82 = fd.ops.sum(T81, dims=[3], keepdim=False, dtype=DataType.Null)
    T83 = fd.ops.cast(T82, dtype=DataType.Half)
    S84 = fd.define_scalar(22, dtype=DataType.Int)
    S85 = fd.define_scalar(288, dtype=DataType.Int)
    S86 = fd.define_scalar(120, dtype=DataType.Int)
    S87 = fd.define_scalar(1, dtype=DataType.Int)
    S88 = fd.define_scalar(160, dtype=DataType.Int)
    S89 = fd.define_scalar(2, dtype=DataType.Int)
    V90 = fd.define_vector([S84, S85, S86, S87, S88, S89], dtype=DataType.Int)
    T91 = fd.ops.broadcast_in_dim(T83, shape=V90, broadcast_dims=[0, 1, 2, 4, 5])
    T92 = fd.ops.cast(T91, dtype=DataType.Float)
    T93 = fd.ops.sum(T92, dims=[3], keepdim=False, dtype=DataType.Null)
    T94 = fd.ops.sum(T93, dims=[4], keepdim=False, dtype=DataType.Null)
    T95 = fd.ops.cast(T94, dtype=DataType.Half)
    S96 = fd.define_scalar(22, dtype=DataType.Int)
    S97 = fd.define_scalar(288, dtype=DataType.Int)
    S98 = fd.define_scalar(120, dtype=DataType.Int)
    S99 = fd.define_scalar(160, dtype=DataType.Int)
    S100 = fd.define_scalar(1, dtype=DataType.Int)
    V101 = fd.define_vector([S96, S97, S98, S99, S100], dtype=DataType.Int)
    T102 = fd.ops.broadcast_in_dim(T95, shape=V101, broadcast_dims=[0, 1, 2, 3])
    T103 = fd.ops.cast(T102, dtype=DataType.Float)
    T104 = fd.ops.sum(T103, dims=[4], keepdim=False, dtype=DataType.Null)
    T105 = fd.ops.cast(T104, dtype=DataType.Half)
    S106 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T107 = fd.ops.where(T6, T105, S106)
    S108 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T109 = fd.ops.where(T6, S108, T105)
    S110 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T111 = fd.ops.where(T7, T109, S110)
    T112 = fd.ops.cast(T107, dtype=DataType.Float)
    T113 = fd.ops.cast(T111, dtype=DataType.Float)
    T114 = fd.ops.add(T112, T113)
    T115 = fd.ops.cast(T114, dtype=DataType.Half)
    S116 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T117 = fd.ops.where(T8, T115, S116)
    S118 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T119 = fd.ops.where(T8, S118, T115)
    S120 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T121 = fd.ops.where(T9, T119, S120)
    T122 = fd.ops.cast(T117, dtype=DataType.Float)
    T123 = fd.ops.cast(T121, dtype=DataType.Float)
    T124 = fd.ops.add(T122, T123)
    T125 = fd.ops.sum(T124, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    T126 = fd.ops.mul(T10, T124)
    T127 = fd.ops.mul(T11, T124)
    T128 = fd.ops.sum(T127, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    T129 = fd.ops.mul(T12, T126)
    T130 = fd.ops.mul(T13, T126)
    T131 = fd.ops.sum(T130, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    S132 = fd.define_scalar(1, dtype=DataType.Int)
    S133 = fd.define_scalar(288, dtype=DataType.Int)
    S134 = fd.define_scalar(1, dtype=DataType.Int)
    S135 = fd.define_scalar(1, dtype=DataType.Int)
    V136 = fd.define_vector([S132, S133, S134, S135], dtype=DataType.Int)
    T137 = fd.ops.broadcast_in_dim(T131, shape=V136, broadcast_dims=[1])
    T138 = fd.ops.neg(T129)
    T139 = fd.ops.sum(T138, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    S140 = fd.define_scalar(1, dtype=DataType.Int)
    S141 = fd.define_scalar(288, dtype=DataType.Int)
    S142 = fd.define_scalar(1, dtype=DataType.Int)
    S143 = fd.define_scalar(1, dtype=DataType.Int)
    V144 = fd.define_vector([S140, S141, S142, S143], dtype=DataType.Int)
    T145 = fd.ops.broadcast_in_dim(T139, shape=V144, broadcast_dims=[1])
    S146 = fd.define_scalar(288, dtype=DataType.Int)
    V147 = fd.define_vector([S146], dtype=DataType.Int)
    T148 = fd.ops.reshape(T145, new_shape=V147)
    S149 = fd.define_scalar(288, dtype=DataType.Int)
    V150 = fd.define_vector([S149], dtype=DataType.Int)
    T151 = fd.ops.reshape(T137, new_shape=V150)
    S152 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T153 = fd.ops.mul(S152, T151)
    S154 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T155 = fd.ops.pow(T14, S154)
    T156 = fd.ops.mul(T153, T155)
    S157 = fd.define_scalar(1, dtype=DataType.Int)
    S158 = fd.define_scalar(288, dtype=DataType.Int)
    S159 = fd.define_scalar(1, dtype=DataType.Int)
    S160 = fd.define_scalar(1, dtype=DataType.Int)
    V161 = fd.define_vector([S157, S158, S159, S160], dtype=DataType.Int)
    T162 = fd.ops.broadcast_in_dim(T148, shape=V161, broadcast_dims=[1])
    S163 = fd.define_scalar(22, dtype=DataType.Int)
    S164 = fd.define_scalar(288, dtype=DataType.Int)
    S165 = fd.define_scalar(120, dtype=DataType.Int)
    S166 = fd.define_scalar(160, dtype=DataType.Int)
    V167 = fd.define_vector([S163, S164, S165, S166], dtype=DataType.Int)
    T168 = fd.ops.broadcast_in_dim(T162, shape=V167, broadcast_dims=[0, 1, 2, 3])
    S169 = fd.define_scalar(2.36742e-06, dtype=DataType.Double)
    T170 = fd.ops.mul(S169, T168)
    S171 = fd.define_scalar(1, dtype=DataType.Int)
    S172 = fd.define_scalar(288, dtype=DataType.Int)
    S173 = fd.define_scalar(1, dtype=DataType.Int)
    S174 = fd.define_scalar(1, dtype=DataType.Int)
    V175 = fd.define_vector([S171, S172, S173, S174], dtype=DataType.Int)
    T176 = fd.ops.broadcast_in_dim(T156, shape=V175, broadcast_dims=[1])
    S177 = fd.define_scalar(22, dtype=DataType.Int)
    S178 = fd.define_scalar(288, dtype=DataType.Int)
    S179 = fd.define_scalar(120, dtype=DataType.Int)
    S180 = fd.define_scalar(160, dtype=DataType.Int)
    V181 = fd.define_vector([S177, S178, S179, S180], dtype=DataType.Int)
    T182 = fd.ops.broadcast_in_dim(T176, shape=V181, broadcast_dims=[0, 1, 2, 3])
    S183 = fd.define_scalar(1, dtype=DataType.Int)
    S184 = fd.define_scalar(288, dtype=DataType.Int)
    S185 = fd.define_scalar(1, dtype=DataType.Int)
    S186 = fd.define_scalar(1, dtype=DataType.Int)
    V187 = fd.define_vector([S183, S184, S185, S186], dtype=DataType.Int)
    T188 = fd.ops.broadcast_in_dim(T15, shape=V187, broadcast_dims=[1])
    S189 = fd.define_scalar(22, dtype=DataType.Int)
    S190 = fd.define_scalar(288, dtype=DataType.Int)
    S191 = fd.define_scalar(120, dtype=DataType.Int)
    S192 = fd.define_scalar(160, dtype=DataType.Int)
    V193 = fd.define_vector([S189, S190, S191, S192], dtype=DataType.Int)
    T194 = fd.ops.broadcast_in_dim(T188, shape=V193, broadcast_dims=[0, 1, 2, 3])
    S195 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T196 = fd.ops.mul(S195, T182)
    T197 = fd.ops.sub(T16, T194)
    T198 = fd.ops.mul(T196, T197)
    S199 = fd.define_scalar(422400., dtype=DataType.Double)
    S200 = fd.ops.reciprocal(S199)
    T201 = fd.ops.mul(T198, S200)
    T202 = fd.ops.add(T170, T201)
    T203 = fd.ops.add(T129, T202)
    T204 = fd.ops.cast(T203, dtype=DataType.Half)
    S205 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T206 = fd.ops.pad(T204, [0, 0, 0, 0, 0, 0, 0, 0], S205)
    S207 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T208 = fd.ops.pad(T17, [-1, -1, -1, -1, 0, 0, 0, 0], S207)
    T209 = fd.ops.cast(T18, dtype=DataType.Float)
    S210 = fd.define_scalar(22, dtype=DataType.Int)
    S211 = fd.define_scalar(192, dtype=DataType.Int)
    S212 = fd.define_scalar(120, dtype=DataType.Int)
    S213 = fd.define_scalar(2, dtype=DataType.Int)
    S214 = fd.define_scalar(160, dtype=DataType.Int)
    S215 = fd.define_scalar(2, dtype=DataType.Int)
    V216 = fd.define_vector([S210, S211, S212, S213, S214, S215], dtype=DataType.Int)
    T217 = fd.ops.reshape(T208, new_shape=V216)
    T218 = fd.ops.cast(T217, dtype=DataType.Float)
    T219 = fd.ops.sum(T218, dims=[3], keepdim=False, dtype=DataType.Null)
    T220 = fd.ops.cast(T219, dtype=DataType.Half)
    S221 = fd.define_scalar(22, dtype=DataType.Int)
    S222 = fd.define_scalar(192, dtype=DataType.Int)
    S223 = fd.define_scalar(120, dtype=DataType.Int)
    S224 = fd.define_scalar(1, dtype=DataType.Int)
    S225 = fd.define_scalar(160, dtype=DataType.Int)
    S226 = fd.define_scalar(2, dtype=DataType.Int)
    V227 = fd.define_vector([S221, S222, S223, S224, S225, S226], dtype=DataType.Int)
    T228 = fd.ops.broadcast_in_dim(T220, shape=V227, broadcast_dims=[0, 1, 2, 4, 5])
    T229 = fd.ops.cast(T228, dtype=DataType.Float)
    T230 = fd.ops.sum(T229, dims=[3], keepdim=False, dtype=DataType.Null)
    T231 = fd.ops.sum(T230, dims=[4], keepdim=False, dtype=DataType.Null)
    T232 = fd.ops.cast(T231, dtype=DataType.Half)
    S233 = fd.define_scalar(22, dtype=DataType.Int)
    S234 = fd.define_scalar(192, dtype=DataType.Int)
    S235 = fd.define_scalar(120, dtype=DataType.Int)
    S236 = fd.define_scalar(160, dtype=DataType.Int)
    S237 = fd.define_scalar(1, dtype=DataType.Int)
    V238 = fd.define_vector([S233, S234, S235, S236, S237], dtype=DataType.Int)
    T239 = fd.ops.broadcast_in_dim(T232, shape=V238, broadcast_dims=[0, 1, 2, 3])
    T240 = fd.ops.cast(T239, dtype=DataType.Float)
    T241 = fd.ops.sum(T240, dims=[4], keepdim=False, dtype=DataType.Null)
    T242 = fd.ops.cast(T241, dtype=DataType.Half)
    S243 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T244 = fd.ops.where(T19, T242, S243)
    S245 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T246 = fd.ops.where(T19, S245, T242)
    S247 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T248 = fd.ops.where(T20, T246, S247)
    T249 = fd.ops.cast(T244, dtype=DataType.Float)
    T250 = fd.ops.cast(T248, dtype=DataType.Float)
    T251 = fd.ops.add(T249, T250)
    T252 = fd.ops.cast(T251, dtype=DataType.Half)
    S253 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T254 = fd.ops.where(T21, T252, S253)
    S255 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T256 = fd.ops.where(T21, S255, T252)
    S257 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T258 = fd.ops.where(T22, T256, S257)
    T259 = fd.ops.cast(T254, dtype=DataType.Float)
    T260 = fd.ops.cast(T258, dtype=DataType.Float)
    T261 = fd.ops.add(T259, T260)
    T262 = fd.ops.sum(T261, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    T263 = fd.ops.mul(T23, T261)
    T264 = fd.ops.mul(T24, T261)
    T265 = fd.ops.sum(T264, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    T266 = fd.ops.mul(T25, T263)
    T267 = fd.ops.mul(T26, T263)
    T268 = fd.ops.sum(T267, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    S269 = fd.define_scalar(1, dtype=DataType.Int)
    S270 = fd.define_scalar(192, dtype=DataType.Int)
    S271 = fd.define_scalar(1, dtype=DataType.Int)
    S272 = fd.define_scalar(1, dtype=DataType.Int)
    V273 = fd.define_vector([S269, S270, S271, S272], dtype=DataType.Int)
    T274 = fd.ops.broadcast_in_dim(T268, shape=V273, broadcast_dims=[1])
    T275 = fd.ops.neg(T266)
    T276 = fd.ops.sum(T275, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    S277 = fd.define_scalar(1, dtype=DataType.Int)
    S278 = fd.define_scalar(192, dtype=DataType.Int)
    S279 = fd.define_scalar(1, dtype=DataType.Int)
    S280 = fd.define_scalar(1, dtype=DataType.Int)
    V281 = fd.define_vector([S277, S278, S279, S280], dtype=DataType.Int)
    T282 = fd.ops.broadcast_in_dim(T276, shape=V281, broadcast_dims=[1])
    S283 = fd.define_scalar(192, dtype=DataType.Int)
    V284 = fd.define_vector([S283], dtype=DataType.Int)
    T285 = fd.ops.reshape(T282, new_shape=V284)
    S286 = fd.define_scalar(192, dtype=DataType.Int)
    V287 = fd.define_vector([S286], dtype=DataType.Int)
    T288 = fd.ops.reshape(T274, new_shape=V287)
    S289 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T290 = fd.ops.mul(S289, T288)
    S291 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T292 = fd.ops.pow(T27, S291)
    T293 = fd.ops.mul(T290, T292)
    S294 = fd.define_scalar(1, dtype=DataType.Int)
    S295 = fd.define_scalar(192, dtype=DataType.Int)
    S296 = fd.define_scalar(1, dtype=DataType.Int)
    S297 = fd.define_scalar(1, dtype=DataType.Int)
    V298 = fd.define_vector([S294, S295, S296, S297], dtype=DataType.Int)
    T299 = fd.ops.broadcast_in_dim(T285, shape=V298, broadcast_dims=[1])
    S300 = fd.define_scalar(22, dtype=DataType.Int)
    S301 = fd.define_scalar(192, dtype=DataType.Int)
    S302 = fd.define_scalar(120, dtype=DataType.Int)
    S303 = fd.define_scalar(160, dtype=DataType.Int)
    V304 = fd.define_vector([S300, S301, S302, S303], dtype=DataType.Int)
    T305 = fd.ops.broadcast_in_dim(T299, shape=V304, broadcast_dims=[0, 1, 2, 3])
    S306 = fd.define_scalar(2.36742e-06, dtype=DataType.Double)
    T307 = fd.ops.mul(S306, T305)
    S308 = fd.define_scalar(1, dtype=DataType.Int)
    S309 = fd.define_scalar(192, dtype=DataType.Int)
    S310 = fd.define_scalar(1, dtype=DataType.Int)
    S311 = fd.define_scalar(1, dtype=DataType.Int)
    V312 = fd.define_vector([S308, S309, S310, S311], dtype=DataType.Int)
    T313 = fd.ops.broadcast_in_dim(T293, shape=V312, broadcast_dims=[1])
    S314 = fd.define_scalar(22, dtype=DataType.Int)
    S315 = fd.define_scalar(192, dtype=DataType.Int)
    S316 = fd.define_scalar(120, dtype=DataType.Int)
    S317 = fd.define_scalar(160, dtype=DataType.Int)
    V318 = fd.define_vector([S314, S315, S316, S317], dtype=DataType.Int)
    T319 = fd.ops.broadcast_in_dim(T313, shape=V318, broadcast_dims=[0, 1, 2, 3])
    S320 = fd.define_scalar(1, dtype=DataType.Int)
    S321 = fd.define_scalar(192, dtype=DataType.Int)
    S322 = fd.define_scalar(1, dtype=DataType.Int)
    S323 = fd.define_scalar(1, dtype=DataType.Int)
    V324 = fd.define_vector([S320, S321, S322, S323], dtype=DataType.Int)
    T325 = fd.ops.broadcast_in_dim(T28, shape=V324, broadcast_dims=[1])
    S326 = fd.define_scalar(22, dtype=DataType.Int)
    S327 = fd.define_scalar(192, dtype=DataType.Int)
    S328 = fd.define_scalar(120, dtype=DataType.Int)
    S329 = fd.define_scalar(160, dtype=DataType.Int)
    V330 = fd.define_vector([S326, S327, S328, S329], dtype=DataType.Int)
    T331 = fd.ops.broadcast_in_dim(T325, shape=V330, broadcast_dims=[0, 1, 2, 3])
    S332 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T333 = fd.ops.mul(S332, T319)
    T334 = fd.ops.sub(T29, T331)
    T335 = fd.ops.mul(T333, T334)
    S336 = fd.define_scalar(422400., dtype=DataType.Double)
    S337 = fd.ops.reciprocal(S336)
    T338 = fd.ops.mul(T335, S337)
    T339 = fd.ops.add(T307, T338)
    T340 = fd.ops.add(T266, T339)
    T341 = fd.ops.cast(T340, dtype=DataType.Half)
    S342 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T343 = fd.ops.pad(T341, [0, 0, 0, 0, 0, 0, 0, 0], S342)
    S344 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T345 = fd.ops.pad(T30, [-1, -1, -1, -1, 0, 0, 0, 0], S344)
    T346 = fd.ops.cast(T31, dtype=DataType.Float)
    S347 = fd.define_scalar(22, dtype=DataType.Int)
    S348 = fd.define_scalar(96, dtype=DataType.Int)
    S349 = fd.define_scalar(120, dtype=DataType.Int)
    S350 = fd.define_scalar(2, dtype=DataType.Int)
    S351 = fd.define_scalar(160, dtype=DataType.Int)
    S352 = fd.define_scalar(2, dtype=DataType.Int)
    V353 = fd.define_vector([S347, S348, S349, S350, S351, S352], dtype=DataType.Int)
    T354 = fd.ops.reshape(T345, new_shape=V353)
    T355 = fd.ops.cast(T354, dtype=DataType.Float)
    T356 = fd.ops.sum(T355, dims=[3], keepdim=False, dtype=DataType.Null)
    T357 = fd.ops.cast(T356, dtype=DataType.Half)
    S358 = fd.define_scalar(22, dtype=DataType.Int)
    S359 = fd.define_scalar(96, dtype=DataType.Int)
    S360 = fd.define_scalar(120, dtype=DataType.Int)
    S361 = fd.define_scalar(1, dtype=DataType.Int)
    S362 = fd.define_scalar(160, dtype=DataType.Int)
    S363 = fd.define_scalar(2, dtype=DataType.Int)
    V364 = fd.define_vector([S358, S359, S360, S361, S362, S363], dtype=DataType.Int)
    T365 = fd.ops.broadcast_in_dim(T357, shape=V364, broadcast_dims=[0, 1, 2, 4, 5])
    T366 = fd.ops.cast(T365, dtype=DataType.Float)
    T367 = fd.ops.sum(T366, dims=[3], keepdim=False, dtype=DataType.Null)
    T368 = fd.ops.sum(T367, dims=[4], keepdim=False, dtype=DataType.Null)
    T369 = fd.ops.cast(T368, dtype=DataType.Half)
    S370 = fd.define_scalar(22, dtype=DataType.Int)
    S371 = fd.define_scalar(96, dtype=DataType.Int)
    S372 = fd.define_scalar(120, dtype=DataType.Int)
    S373 = fd.define_scalar(160, dtype=DataType.Int)
    S374 = fd.define_scalar(1, dtype=DataType.Int)
    V375 = fd.define_vector([S370, S371, S372, S373, S374], dtype=DataType.Int)
    T376 = fd.ops.broadcast_in_dim(T369, shape=V375, broadcast_dims=[0, 1, 2, 3])
    T377 = fd.ops.cast(T376, dtype=DataType.Float)
    T378 = fd.ops.sum(T377, dims=[4], keepdim=False, dtype=DataType.Null)
    T379 = fd.ops.cast(T378, dtype=DataType.Half)
    S380 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T381 = fd.ops.where(T67, T379, S380)
    S382 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T383 = fd.ops.where(T67, S382, T379)
    S384 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T385 = fd.ops.where(T69, T383, S384)
    T386 = fd.ops.cast(T381, dtype=DataType.Float)
    T387 = fd.ops.cast(T385, dtype=DataType.Float)
    T388 = fd.ops.add(T386, T387)
    T389 = fd.ops.cast(T388, dtype=DataType.Half)
    S390 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T391 = fd.ops.where(T61, T389, S390)
    S392 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T393 = fd.ops.where(T61, S392, T389)
    S394 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T395 = fd.ops.where(T63, T393, S394)
    T396 = fd.ops.cast(T391, dtype=DataType.Float)
    T397 = fd.ops.cast(T395, dtype=DataType.Float)
    T398 = fd.ops.add(T396, T397)
    T399 = fd.ops.sum(T398, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    T400 = fd.ops.mul(T32, T398)
    T401 = fd.ops.mul(T60, T398)
    T402 = fd.ops.sum(T401, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    T403 = fd.ops.mul(T59, T400)
    T404 = fd.ops.mul(T53, T400)
    T405 = fd.ops.sum(T404, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    S406 = fd.define_scalar(1, dtype=DataType.Int)
    S407 = fd.define_scalar(96, dtype=DataType.Int)
    S408 = fd.define_scalar(1, dtype=DataType.Int)
    S409 = fd.define_scalar(1, dtype=DataType.Int)
    V410 = fd.define_vector([S406, S407, S408, S409], dtype=DataType.Int)
    T411 = fd.ops.broadcast_in_dim(T405, shape=V410, broadcast_dims=[1])
    T412 = fd.ops.neg(T403)
    T413 = fd.ops.sum(T412, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    S414 = fd.define_scalar(1, dtype=DataType.Int)
    S415 = fd.define_scalar(96, dtype=DataType.Int)
    S416 = fd.define_scalar(1, dtype=DataType.Int)
    S417 = fd.define_scalar(1, dtype=DataType.Int)
    V418 = fd.define_vector([S414, S415, S416, S417], dtype=DataType.Int)
    T419 = fd.ops.broadcast_in_dim(T413, shape=V418, broadcast_dims=[1])
    S420 = fd.define_scalar(96, dtype=DataType.Int)
    V421 = fd.define_vector([S420], dtype=DataType.Int)
    T422 = fd.ops.reshape(T419, new_shape=V421)
    S423 = fd.define_scalar(96, dtype=DataType.Int)
    V424 = fd.define_vector([S423], dtype=DataType.Int)
    T425 = fd.ops.reshape(T411, new_shape=V424)
    S426 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T427 = fd.ops.mul(S426, T425)
    S428 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T429 = fd.ops.pow(T1, S428)
    T430 = fd.ops.mul(T427, T429)
    S431 = fd.define_scalar(1, dtype=DataType.Int)
    S432 = fd.define_scalar(96, dtype=DataType.Int)
    S433 = fd.define_scalar(1, dtype=DataType.Int)
    S434 = fd.define_scalar(1, dtype=DataType.Int)
    V435 = fd.define_vector([S431, S432, S433, S434], dtype=DataType.Int)
    T436 = fd.ops.broadcast_in_dim(T422, shape=V435, broadcast_dims=[1])
    S437 = fd.define_scalar(22, dtype=DataType.Int)
    S438 = fd.define_scalar(96, dtype=DataType.Int)
    S439 = fd.define_scalar(120, dtype=DataType.Int)
    S440 = fd.define_scalar(160, dtype=DataType.Int)
    V441 = fd.define_vector([S437, S438, S439, S440], dtype=DataType.Int)
    T442 = fd.ops.broadcast_in_dim(T436, shape=V441, broadcast_dims=[0, 1, 2, 3])
    S443 = fd.define_scalar(2.36742e-06, dtype=DataType.Double)
    T444 = fd.ops.mul(S443, T442)
    S445 = fd.define_scalar(1, dtype=DataType.Int)
    S446 = fd.define_scalar(96, dtype=DataType.Int)
    S447 = fd.define_scalar(1, dtype=DataType.Int)
    S448 = fd.define_scalar(1, dtype=DataType.Int)
    V449 = fd.define_vector([S445, S446, S447, S448], dtype=DataType.Int)
    T450 = fd.ops.broadcast_in_dim(T430, shape=V449, broadcast_dims=[1])
    S451 = fd.define_scalar(22, dtype=DataType.Int)
    S452 = fd.define_scalar(96, dtype=DataType.Int)
    S453 = fd.define_scalar(120, dtype=DataType.Int)
    S454 = fd.define_scalar(160, dtype=DataType.Int)
    V455 = fd.define_vector([S451, S452, S453, S454], dtype=DataType.Int)
    T456 = fd.ops.broadcast_in_dim(T450, shape=V455, broadcast_dims=[0, 1, 2, 3])
    S457 = fd.define_scalar(1, dtype=DataType.Int)
    S458 = fd.define_scalar(96, dtype=DataType.Int)
    S459 = fd.define_scalar(1, dtype=DataType.Int)
    S460 = fd.define_scalar(1, dtype=DataType.Int)
    V461 = fd.define_vector([S457, S458, S459, S460], dtype=DataType.Int)
    T462 = fd.ops.broadcast_in_dim(T2, shape=V461, broadcast_dims=[1])
    S463 = fd.define_scalar(22, dtype=DataType.Int)
    S464 = fd.define_scalar(96, dtype=DataType.Int)
    S465 = fd.define_scalar(120, dtype=DataType.Int)
    S466 = fd.define_scalar(160, dtype=DataType.Int)
    V467 = fd.define_vector([S463, S464, S465, S466], dtype=DataType.Int)
    T468 = fd.ops.broadcast_in_dim(T462, shape=V467, broadcast_dims=[0, 1, 2, 3])
    S469 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T470 = fd.ops.mul(S469, T456)
    T471 = fd.ops.sub(T34, T468)
    T472 = fd.ops.mul(T470, T471)
    S473 = fd.define_scalar(422400., dtype=DataType.Double)
    S474 = fd.ops.reciprocal(S473)
    T475 = fd.ops.mul(T472, S474)
    T476 = fd.ops.add(T444, T475)
    T477 = fd.ops.add(T403, T476)
    T478 = fd.ops.cast(T477, dtype=DataType.Half)
    S479 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T480 = fd.ops.pad(T478, [0, 0, 0, 0, 0, 0, 0, 0], S479)
    T481 = fd.ops.cast(T33, dtype=DataType.Float)
    fd.add_output(T72)
    fd.add_output(T125)
    fd.add_output(T128)
    fd.add_output(T204)
    fd.add_output(T206)
    fd.add_output(T209)
    fd.add_output(T262)
    fd.add_output(T265)
    fd.add_output(T341)
    fd.add_output(T343)
    fd.add_output(T346)
    fd.add_output(T399)
    fd.add_output(T402)
    fd.add_output(T478)
    fd.add_output(T480)
    fd.add_output(T481)

with FusionDefinition() as fd:
    nvfuser_fusion_id45(fd)

inputs = [
    torch.randn((40550400,), dtype=torch.float16, device='cuda:0').as_strided((22, 96, 120, 160), (1843200, 19200, 160, 1)),
    torch.randn((96,), dtype=torch.float32, device='cuda:0').as_strided((96,), (1,)),
    torch.randn((96,), dtype=torch.float32, device='cuda:0').as_strided((96,), (1,)),
    torch.randn((40550400,), dtype=torch.float16, device='cuda:0').as_strided((22, 96, 120, 160), (1843200, 19200, 160, 1)),
    torch.randn((493726464,), dtype=torch.float16, device='cuda:0').as_strided((22, 288, 242, 322), (22442112, 77924, 322, 1)),
    torch.randn((746496,), dtype=torch.float16, device='cuda:0').as_strided((288, 288, 3, 3), (9, 2592, 3, 1)),
    torch.randint(0, 2, (121651200,), dtype=torch.bool, device='cuda:0').as_strided((22, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randint(0, 2, (121651200,), dtype=torch.bool, device='cuda:0').as_strided((22, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randint(0, 2, (121651200,), dtype=torch.bool, device='cuda:0').as_strided((22, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randint(0, 2, (121651200,), dtype=torch.bool, device='cuda:0').as_strided((22, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randn((288,), dtype=torch.float32, device='cuda:0').as_strided((22, 288, 120, 160), (0, 1, 0, 0)),
    torch.randn((121651200,), dtype=torch.float32, device='cuda:0').as_strided((22, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randn((288,), dtype=torch.float32, device='cuda:0').as_strided((22, 288, 120, 160), (0, 1, 0, 0)),
    torch.randn((121651200,), dtype=torch.float32, device='cuda:0').as_strided((22, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randn((288,), dtype=torch.float32, device='cuda:0').as_strided((288,), (1,)),
    torch.randn((288,), dtype=torch.float32, device='cuda:0').as_strided((288,), (1,)),
    torch.randn((121651200,), dtype=torch.float32, device='cuda:0').as_strided((22, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randn((329150976,), dtype=torch.float16, device='cuda:0').as_strided((22, 192, 242, 322), (14961408, 77924, 322, 1)),
    torch.randn((331776,), dtype=torch.float16, device='cuda:0').as_strided((192, 192, 3, 3), (9, 1728, 3, 1)),
    torch.randint(0, 2, (81100800,), dtype=torch.bool, device='cuda:0').as_strided((22, 192, 120, 160), (3686400, 19200, 160, 1)),
    torch.randint(0, 2, (81100800,), dtype=torch.bool, device='cuda:0').as_strided((22, 192, 120, 160), (3686400, 19200, 160, 1)),
    torch.randint(0, 2, (81100800,), dtype=torch.bool, device='cuda:0').as_strided((22, 192, 120, 160), (3686400, 19200, 160, 1)),
    torch.randint(0, 2, (81100800,), dtype=torch.bool, device='cuda:0').as_strided((22, 192, 120, 160), (3686400, 19200, 160, 1)),
    torch.randn((192,), dtype=torch.float32, device='cuda:0').as_strided((22, 192, 120, 160), (0, 1, 0, 0)),
    torch.randn((81100800,), dtype=torch.float32, device='cuda:0').as_strided((22, 192, 120, 160), (3686400, 19200, 160, 1)),
    torch.randn((192,), dtype=torch.float32, device='cuda:0').as_strided((22, 192, 120, 160), (0, 1, 0, 0)),
    torch.randn((81100800,), dtype=torch.float32, device='cuda:0').as_strided((22, 192, 120, 160), (3686400, 19200, 160, 1)),
    torch.randn((192,), dtype=torch.float32, device='cuda:0').as_strided((192,), (1,)),
    torch.randn((192,), dtype=torch.float32, device='cuda:0').as_strided((192,), (1,)),
    torch.randn((81100800,), dtype=torch.float32, device='cuda:0').as_strided((22, 192, 120, 160), (3686400, 19200, 160, 1)),
    torch.randn((164575488,), dtype=torch.float16, device='cuda:0').as_strided((22, 96, 242, 322), (7480704, 77924, 322, 1)),
    torch.randn((82944,), dtype=torch.float16, device='cuda:0').as_strided((96, 96, 3, 3), (9, 864, 3, 1)),
    torch.randn((96,), dtype=torch.float32, device='cuda:0').as_strided((22, 96, 120, 160), (0, 1, 0, 0)),
    torch.randn((82944,), dtype=torch.float16, device='cuda:0').as_strided((96, 96, 3, 3), (9, 864, 3, 1)),
]
fd.execute(inputs)

liqiangxl commented 1 month ago

This seems to happen because AllocationDomainPass sets an allocation domain that contains a bcast (broadcast) ID, and that bcast ID is then removed in RemoveBcastSqueeze. Altering the order of these two passes resolves the error. I will check other cases in CI.
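
To make the suspected ordering problem concrete, here is a toy sketch in plain Python. It is not nvFuser code: the dict-based "tensor", the string IDs, and the helper names (`set_allocation_domain`, `remove_broadcast_ids`, `sizes_match`) are all hypothetical stand-ins, and the size comparison only loosely mirrors the `frontier.size() == logical.size()` assert from the traceback.

```python
# Toy model only: plain Python lists stand in for iteration domains.
# None of these names are nvFuser APIs; they are hypothetical illustrations.

def set_allocation_domain(tensor):
    """Stand-in for an allocation-domain pass: record an allocation domain
    that copies the current logical domain, broadcast IDs included."""
    tensor["allocation"] = list(tensor["logical"])

def remove_broadcast_ids(tensor):
    """Stand-in for a broadcast-removal pass: drop broadcast IDs (names
    starting with 'b') from the logical domain only, leaving any previously
    recorded allocation domain untouched."""
    tensor["logical"] = [d for d in tensor["logical"] if not d.startswith("b")]

def sizes_match(tensor):
    """Loosely mimic the failing check: compare allocation and logical sizes."""
    alloc = tensor["allocation"] if tensor["allocation"] is not None else tensor["logical"]
    return len(alloc) == len(tensor["logical"])

def run(passes):
    """Apply the passes in order to a fresh tensor and report the check."""
    tensor = {"logical": ["i0", "b1", "i2"], "allocation": None}
    for p in passes:
        p(tensor)
    return sizes_match(tensor)

# Allocation domain recorded first, broadcast ID removed afterwards:
# the allocation domain still refers to 'b1', so the sizes disagree.
stale_order = run([set_allocation_domain, remove_broadcast_ids])

# Broadcast ID removed first, allocation domain recorded afterwards:
# both domains agree.
swapped_order = run([remove_broadcast_ids, set_allocation_domain])

print(stale_order, swapped_order)
```

In this toy model the first ordering leaves a stale broadcast ID in the allocation domain, which is the kind of mismatch described above; swapping the two passes avoids it. Whether the same holds for every real fusion is what the CI check is meant to confirm.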

tfogal commented 1 month ago

Closing; this is confusingly working just fine with new containers / new builds. We may have inadvertently addressed this one.