Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.

Timeout for Platypus-30B and Thunder compile #294

Open mpatel31415 opened 7 months ago

mpatel31415 commented 7 months ago

🐛 Bug

With the newest version of the Docker image (tested on 2024-04-28), training with thunder.jit on 8xA100 fails for the Platypus-30B and vicuna-33b-v1.3 models. This is the error:

Time to instantiate model: 0.05 seconds.
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
Time to instantiate model: 0.05 seconds.
Time to instantiate model: 0.05 seconds.
Time to instantiate model: 0.05 seconds.
Time to instantiate model: 0.05 seconds.
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
[rank6]:[W429 12:49:14.458854201 ProcessGroupNCCL.cpp:1113] WARNING: process group has NOT been destroyed before it is being destructed. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL data transfers have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4
[rank3]:[E429 12:59:12.174771056 ProcessGroupNCCL.cpp:568] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=721, OpType=_ALLGATHER_BASE, NumelIn=14909440, NumelOut=119275520, Timeout(ms)=600000) ran for 600059 milliseconds before timing out.
[rank5]:[E429 12:59:12.230443027 ProcessGroupNCCL.cpp:568] [Rank 5] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=721, OpType=_ALLGATHER_BASE, NumelIn=14909440, NumelOut=119275520, Timeout(ms)=600000) ran for 600085 milliseconds before timing out.
[rank1]:[E429 12:59:12.305260726 ProcessGroupNCCL.cpp:568] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=721, OpType=_ALLGATHER_BASE, NumelIn=14909440, NumelOut=119275520, Timeout(ms)=600000) ran for 600071 milliseconds before timing out.
[rank4]:[E429 12:59:12.307015699 ProcessGroupNCCL.cpp:568] [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=721, OpType=_ALLGATHER_BASE, NumelIn=14909440, NumelOut=119275520, Timeout(ms)=600000) ran for 600032 milliseconds before timing out.
[rank7]:[E429 12:59:12.327170052 ProcessGroupNCCL.cpp:568] [Rank 7] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=721, OpType=_ALLGATHER_BASE, NumelIn=14909440, NumelOut=119275520, Timeout(ms)=600000) ran for 600097 milliseconds before timing out.
[rank2]:[E429 12:59:12.356285062 ProcessGroupNCCL.cpp:568] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=721, OpType=_ALLGATHER_BASE, NumelIn=14909440, NumelOut=119275520, Timeout(ms)=600000) ran for 600025 milliseconds before timing out.
[rank0]:[E429 12:59:12.369438320 ProcessGroupNCCL.cpp:568] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=721, OpType=_ALLGATHER_BASE, NumelIn=14909440, NumelOut=119275520, Timeout(ms)=600000) ran for 600021 milliseconds before timing out.
[rank4]:[E429 12:59:13.082814283 ProcessGroupNCCL.cpp:1602] [PG 0 (default_pg) Rank 4] Timeout at NCCL work: 721, last enqueued NCCL work: 722, last completed NCCL work: 720.
[rank4]:[E429 12:59:13.082874197 ProcessGroupNCCL.cpp:582] [Rank 4] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank4]:[E429 12:59:13.082891961 ProcessGroupNCCL.cpp:588] [Rank 4] To avoid data inconsistency, we are taking the entire process down.
[rank4]:[E429 12:59:13.082929112 ProcessGroupNCCL.cpp:1432] [PG 0 (default_pg) Rank 4] Process group watchdog thread terminated with exception: [Rank 4] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=721, OpType=_ALLGATHER_BASE, NumelIn=14909440, NumelOut=119275520, Timeout(ms)=600000) ran for 600032 milliseconds before timing out.
Exception raised from checkTimeout at /opt/pytorch/pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:570 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits, std::allocator >) + 0xae (0x7f5c2b47d84e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
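
The collectives above hit the default 10-minute NCCL watchdog timeout (Timeout(ms)=600000). To check whether the allgather is merely slow (for example, blocked behind a long first-iteration compile) rather than truly hung, the timeout can be raised where the process group is created. A minimal sketch, assuming direct control over the init_process_group call; how benchmark_litgpt.py sets up distributed is not shown here:

from datetime import timedelta

import torch.distributed as dist

# Sketch only: raise the NCCL watchdog timeout from the 600 s default so a long
# stall (e.g. first-iteration compilation) is not reported as a hang. The
# timeout argument of init_process_group is standard PyTorch; the value is a guess.
dist.init_process_group(backend="nccl", timeout=timedelta(minutes=60))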

To Reproduce

Before testing each compilation method I restarted the container:

mkdir -p output
docker run --pull=always --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864  -v $PWD/output:/output -it INTERNAL_IMAGE:pjnl-20240427

**Thunder**
* Fails with the NCCL collective timeout shown above

**Inductor**
* Works fine
* Command:

torchrun --nproc-per-node=8 /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py --model_name Platypus-30B --compile inductor --distributed_mode fsdp --shard_mode zero3



### Expected behavior

If we can run the model using Torch Inductor, we should be able to run it with Thunder as well.

### Environment

As in the Docker image. These results come from a single A100 node.
nvidia-smi output:
![image](https://github.com/Lightning-AI/lightning-thunder/assets/149149379/9d9d3972-02a4-4991-9a88-43a92c9a7aa1)

cc @carmocca @awaelchli @crcrpar
mruberry commented 6 months ago

triage review — @crcrpar would you take a look at this?

IvanYashchuk commented 5 months ago

@crcrpar, could you please take a look and tell us your findings on what's happening?

mruberry commented 4 months ago

ping @crcrpar

crcrpar commented 4 months ago

Setting: 20240729 nightly image and 8x A100-SXM4-80GB devices.

Platypus-30B

If I tweak the number of layers to 36, it works (a sketch of this tweak follows the traceback below). Without the tweak, it fails with out of memory even with TORCH_NCCL_AVOID_RECORD_STREAMS=1:

[ERROR    | nvfuser            ]: An error occurred while executing nvFuser FusionDefinition 22.
If you believe this is a bug or need assistance, please file an issue at https://github.com/NVIDIA/Fuser/issues/new
Here's a script to reproduce the error:
# CUDA devices:
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
# torch version: 2.5.0a0+git8927fc2
# cuda version: 12.6
# nvfuser version: 0.2.8+gitdd6886f
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id22(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T3 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T4 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T5 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T6 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T7 = fd.define_tensor(shape=[1, -1, 1], contiguity=[None, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T8 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, None, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T9 = fd.ops.cast(T6, dtype=DataType.Float)
    T10 = fd.ops.cast(T5, dtype=DataType.Float)
    T11 = fd.ops.add(T9, T10)
    S12 = fd.define_scalar(1, dtype=DataType.Int)
    S13 = fd.define_scalar(2048, dtype=DataType.Int)
    S14 = fd.define_scalar(6656, dtype=DataType.Int)
    V15 = fd.define_vector([S12, S13, S14], dtype=DataType.Int)
    T16 = fd.ops.broadcast_in_dim(T7, shape=V15, broadcast_dims=[0, 1, 2])
    T17 = fd.ops.mul(T11, T16)
    T18 = fd.ops.cast(T8, dtype=DataType.Float)
    T19 = fd.ops.cast(T0, dtype=DataType.Float)
    T20 = fd.ops.cast(T1, dtype=DataType.Float)
    T21 = fd.ops.cast(T3, dtype=DataType.Float)
    T22 = fd.ops.add(T20, T21)
    T23 = fd.ops.mul(T18, T22)
    T24 = fd.ops.mul(T17, T22)
    T25 = fd.ops.sum(T24, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T26 = fd.ops.mul(T16, T23)
    T27 = fd.ops.mul(T11, T23)
    T28 = fd.ops.sum(T27, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    S29 = fd.define_scalar(1, dtype=DataType.Int)
    S30 = fd.define_scalar(2048, dtype=DataType.Int)
    S31 = fd.define_scalar(1, dtype=DataType.Int)
    V32 = fd.define_vector([S29, S30, S31], dtype=DataType.Int)
    T33 = fd.ops.broadcast_in_dim(T28, shape=V32, broadcast_dims=[1])
    S34 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T35 = fd.ops.mul(S34, T33)
    S36 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T37 = fd.ops.pow(T7, S36)
    T38 = fd.ops.mul(T35, T37)
    S39 = fd.define_scalar(6656.00, dtype=DataType.Double)
    S40 = fd.ops.reciprocal(S39)
    T41 = fd.ops.mul(T38, S40)
    T42 = fd.ops.sum(T41, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    S43 = fd.define_scalar(1, dtype=DataType.Int)
    S44 = fd.define_scalar(2048, dtype=DataType.Int)
    V45 = fd.define_vector([S43, S44], dtype=DataType.Int)
    T46 = fd.ops.broadcast_in_dim(T42, shape=V45, broadcast_dims=[1])
    S47 = fd.define_scalar(1, dtype=DataType.Int)
    S48 = fd.define_scalar(2048, dtype=DataType.Int)
    S49 = fd.define_scalar(1, dtype=DataType.Int)
    V50 = fd.define_vector([S47, S48, S49], dtype=DataType.Int)
    T51 = fd.ops.broadcast_in_dim(T46, shape=V50, broadcast_dims=[0, 1])
    S52 = fd.define_scalar(1, dtype=DataType.Int)
    S53 = fd.define_scalar(2048, dtype=DataType.Int)
    S54 = fd.define_scalar(6656, dtype=DataType.Int)
    V55 = fd.define_vector([S52, S53, S54], dtype=DataType.Int)
    T56 = fd.ops.broadcast_in_dim(T51, shape=V55, broadcast_dims=[0, 1, 2])
    T57 = fd.ops.mul(T11, T56)
    T58 = fd.ops.add(T26, T57)
    T59 = fd.ops.add(T58, T57)
    T60 = fd.ops.add(T19, T59)
    T61 = fd.ops.cast(T60, dtype=DataType.BFloat16)
    S62 = fd.define_scalar(8.00000, dtype=DataType.Double)
    S63 = fd.ops.reciprocal(S62)
    T64 = fd.ops.mul(T25, S63)
    T65 = fd.ops.cast(T64, dtype=DataType.BFloat16)
    T66 = fd.ops.cast(T2, dtype=DataType.Float)
    S67 = fd.define_scalar(8.00000, dtype=DataType.Double)
    S68 = fd.ops.reciprocal(S67)
    T69 = fd.ops.mul(T66, S68)
    T70 = fd.ops.cast(T69, dtype=DataType.BFloat16)
    T71 = fd.ops.cast(T4, dtype=DataType.Float)
    S72 = fd.define_scalar(8.00000, dtype=DataType.Double)
    S73 = fd.ops.reciprocal(S72)
    T74 = fd.ops.mul(T71, S73)
    T75 = fd.ops.cast(T74, dtype=DataType.BFloat16)
    fd.add_output(T61)
    fd.add_output(T65)
    fd.add_output(T70)
    fd.add_output(T75)

with FusionDefinition() as fd:
    nvfuser_fusion_id22(fd)

inputs = [
    torch.randn((13631488,), dtype=torch.bfloat16, device='cuda:6').as_strided((1, 2048, 6656), (13631488, 6656, 1)),
    torch.randn((13631488,), dtype=torch.bfloat16, device='cuda:6').as_strided((1, 2048, 6656), (13631488, 6656, 1)),
    torch.randn((119275520,), dtype=torch.bfloat16, device='cuda:6').as_strided((17920, 6656), (6656, 1)),
    torch.randn((13631488,), dtype=torch.bfloat16, device='cuda:6').as_strided((1, 2048, 6656), (13631488, 6656, 1)),
    torch.randn((119275520,), dtype=torch.bfloat16, device='cuda:6').as_strided((17920, 6656), (6656, 1)),
    torch.randn((13631488,), dtype=torch.bfloat16, device='cuda:6').as_strided((1, 2048, 6656), (13631488, 6656, 1)),
    torch.randn((13631488,), dtype=torch.bfloat16, device='cuda:6').as_strided((1, 2048, 6656), (13631488, 6656, 1)),
    torch.randn((2048,), dtype=torch.float32, device='cuda:6').as_strided((1, 2048, 1), (2048, 1, 1)),
    torch.randn((6656,), dtype=torch.bfloat16, device='cuda:6').as_strided((1, 2048, 6656), (6656, 0, 1)),
]
fd.execute(inputs)
Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 210, in execute
    result = self._execute(
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 228.00 MiB. GPU 6 has a total capacity of 79.14 GiB of which 192.75 MiB is free. Process 458726 has 78.94 GiB memory in use. Of the allocated memory 76.01 GiB is allocated by PyTorch, and 1.83 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
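
For reference, a minimal sketch of the 36-layer tweak mentioned above, assuming it is applied by overriding the litgpt config before the model is instantiated; the exact mechanism used with benchmark_litgpt.py may differ:

from litgpt import GPT, Config

# Sketch only: build Platypus-30B with a reduced depth of 36 transformer blocks
# so the sharded run fits in 80 GB per GPU; the stock config uses more layers.
config = Config.from_name("Platypus-30B")
config.n_layer = 36
model = GPT(config)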

vicuna-33b-v1.3

Ditto: with 36 layers it works; otherwise it leads to OOM.

[ERROR    | nvfuser            ]: An error occurred while executing nvFuser FusionDefinition 19.
If you believe this is a bug or need assistance, please file an issue at https://github.com/NVIDIA/Fuser/issues/new
Here's a script to reproduce the error:
# CUDA devices:
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
#  0: NVIDIA A100-SXM4-80GB
# torch version: 2.5.0a0+git8927fc2
# cuda version: 12.6
# nvfuser version: 0.2.8+gitdd6886f
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id19(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[1, -1, 1], contiguity=[None, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, None, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T4 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T6 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T7 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T8 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T9 = fd.ops.cast(T0, dtype=DataType.Float)
    T10 = fd.ops.cast(T1, dtype=DataType.Float)
    T11 = fd.ops.add(T10, T9)
    S12 = fd.define_scalar(1, dtype=DataType.Int)
    S13 = fd.define_scalar(2048, dtype=DataType.Int)
    S14 = fd.define_scalar(6656, dtype=DataType.Int)
    V15 = fd.define_vector([S12, S13, S14], dtype=DataType.Int)
    T16 = fd.ops.broadcast_in_dim(T2, shape=V15, broadcast_dims=[0, 1, 2])
    T17 = fd.ops.mul(T11, T16)
    T18 = fd.ops.cast(T3, dtype=DataType.Float)
    T19 = fd.ops.cast(T4, dtype=DataType.Float)
    T20 = fd.ops.cast(T5, dtype=DataType.Float)
    T21 = fd.ops.cast(T7, dtype=DataType.Float)
    T22 = fd.ops.add(T20, T21)
    T23 = fd.ops.mul(T18, T22)
    T24 = fd.ops.mul(T17, T22)
    T25 = fd.ops.sum(T24, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T26 = fd.ops.mul(T16, T23)
    T27 = fd.ops.mul(T11, T23)
    T28 = fd.ops.sum(T27, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    S29 = fd.define_scalar(1, dtype=DataType.Int)
    S30 = fd.define_scalar(2048, dtype=DataType.Int)
    S31 = fd.define_scalar(1, dtype=DataType.Int)
    V32 = fd.define_vector([S29, S30, S31], dtype=DataType.Int)
    T33 = fd.ops.broadcast_in_dim(T28, shape=V32, broadcast_dims=[1])
    S34 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T35 = fd.ops.mul(S34, T33)
    S36 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T37 = fd.ops.pow(T2, S36)
    T38 = fd.ops.mul(T35, T37)
    S39 = fd.define_scalar(6656.00, dtype=DataType.Double)
    S40 = fd.ops.reciprocal(S39)
    T41 = fd.ops.mul(T38, S40)
    T42 = fd.ops.sum(T41, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    S43 = fd.define_scalar(1, dtype=DataType.Int)
    S44 = fd.define_scalar(2048, dtype=DataType.Int)
    V45 = fd.define_vector([S43, S44], dtype=DataType.Int)
    T46 = fd.ops.broadcast_in_dim(T42, shape=V45, broadcast_dims=[1])
    S47 = fd.define_scalar(1, dtype=DataType.Int)
    S48 = fd.define_scalar(2048, dtype=DataType.Int)
    S49 = fd.define_scalar(1, dtype=DataType.Int)
    V50 = fd.define_vector([S47, S48, S49], dtype=DataType.Int)
    T51 = fd.ops.broadcast_in_dim(T46, shape=V50, broadcast_dims=[0, 1])
    S52 = fd.define_scalar(1, dtype=DataType.Int)
    S53 = fd.define_scalar(2048, dtype=DataType.Int)
    S54 = fd.define_scalar(6656, dtype=DataType.Int)
    V55 = fd.define_vector([S52, S53, S54], dtype=DataType.Int)
    T56 = fd.ops.broadcast_in_dim(T51, shape=V55, broadcast_dims=[0, 1, 2])
    T57 = fd.ops.mul(T11, T56)
    T58 = fd.ops.add(T26, T57)
    T59 = fd.ops.add(T58, T57)
    T60 = fd.ops.add(T19, T59)
    T61 = fd.ops.cast(T60, dtype=DataType.BFloat16)
    S62 = fd.define_scalar(8.00000, dtype=DataType.Double)
    S63 = fd.ops.reciprocal(S62)
    T64 = fd.ops.mul(T25, S63)
    T65 = fd.ops.cast(T64, dtype=DataType.BFloat16)
    T66 = fd.ops.cast(T6, dtype=DataType.Float)
    S67 = fd.define_scalar(8.00000, dtype=DataType.Double)
    S68 = fd.ops.reciprocal(S67)
    T69 = fd.ops.mul(T66, S68)
    T70 = fd.ops.cast(T69, dtype=DataType.BFloat16)
    T71 = fd.ops.cast(T8, dtype=DataType.Float)
    S72 = fd.define_scalar(8.00000, dtype=DataType.Double)
    S73 = fd.ops.reciprocal(S72)
    T74 = fd.ops.mul(T71, S73)
    T75 = fd.ops.cast(T74, dtype=DataType.BFloat16)
    fd.add_output(T61)
    fd.add_output(T65)
    fd.add_output(T70)
    fd.add_output(T75)

with FusionDefinition() as fd:
    nvfuser_fusion_id19(fd)

inputs = [
    torch.randn((13631488,), dtype=torch.bfloat16, device='cuda:2').as_strided((1, 2048, 6656), (13631488, 6656, 1)),
    torch.randn((13631488,), dtype=torch.bfloat16, device='cuda:2').as_strided((1, 2048, 6656), (13631488, 6656, 1)),
    torch.randn((2048,), dtype=torch.float32, device='cuda:2').as_strided((1, 2048, 1), (2048, 1, 1)),
    torch.randn((6656,), dtype=torch.bfloat16, device='cuda:2').as_strided((1, 2048, 6656), (6656, 0, 1)),
    torch.randn((13631488,), dtype=torch.bfloat16, device='cuda:2').as_strided((1, 2048, 6656), (13631488, 6656, 1)),
    torch.randn((13631488,), dtype=torch.bfloat16, device='cuda:2').as_strided((1, 2048, 6656), (13631488, 6656, 1)),
    torch.randn((119275520,), dtype=torch.bfloat16, device='cuda:2').as_strided((17920, 6656), (6656, 1)),
    torch.randn((13631488,), dtype=torch.bfloat16, device='cuda:2').as_strided((1, 2048, 6656), (13631488, 6656, 1)),
    torch.randn((119275520,), dtype=torch.bfloat16, device='cuda:2').as_strided((17920, 6656), (6656, 1)),
]
fd.execute(inputs)
Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 210, in execute
    result = self._execute(
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 228.00 MiB. GPU 2 has a total capacity of 79.14 GiB of which 8.75 MiB is free. Process 425238 has 79.12 GiB memory in use. Of the allocated memory 70.27 GiB is allocated by PyTorch, and 7.75 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
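
Both OOM messages above report memory that is reserved by the caching allocator but unallocated, and they suggest PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True. A minimal sketch of applying that hint, assuming the variable is set before the first CUDA allocation in every rank (whether it is enough to avoid the OOM here is untested):

import os

# Sketch only: enable expandable segments, as suggested by the OOM message, to
# reduce fragmentation of the reserved-but-unallocated memory. The variable must
# be set before the first CUDA allocation, e.g. at the very top of the script
# or exported in the shell before launching torchrun.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import torch  # imported after setting the variable so it is in place before CUDA initializes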