Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0
1.18k stars 77 forks source link

nvfuser failure #514

Closed mruberry closed 4 months ago

mruberry commented 5 months ago
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id2(fd : FusionDefinition) -> None :
    """Repro fusion: causally masked, numerically stabilized softmax.

    Takes a 4-D BFloat16 tensor (the repro feeds (2, 16, 4096, 4096)),
    builds a 4096x4096 lower-triangular (row >= col) mask from two iota
    sequences, replaces masked-out entries with -inf, then computes a
    max-subtracted softmax along dim 3 and casts back to BFloat16.

    NOTE(review): this is an auto-generated nvFuser repro — the exact order
    of define_* / ops.* calls (and the T/S/V numbering) is part of the
    repro, so the statement sequence is intentionally left untouched.
    """
    # 4-D BFloat16 input; fully contiguous with innermost-dim-fastest strides.
    T0 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    S1 = fd.define_scalar(4096, dtype=DataType.Int)
    S2 = fd.define_scalar(0, dtype=DataType.Int)
    S3 = fd.define_scalar(1, dtype=DataType.Int)
    # T4 = [0, 1, ..., 4095] — row indices.
    T4 = fd.ops.iota(S1, S2, S3, dtype=DataType.Int)
    S5 = fd.define_scalar(4096, dtype=DataType.Int)
    S6 = fd.define_scalar(1, dtype=DataType.Int)
    V7 = fd.define_vector([S5, S6], dtype=DataType.Int)
    # T8: row indices reshaped to a (4096, 1) column vector.
    T8 = fd.ops.broadcast_in_dim(T4, shape=V7, broadcast_dims=[0])
    S9 = fd.define_scalar(4096, dtype=DataType.Int)
    S10 = fd.define_scalar(0, dtype=DataType.Int)
    S11 = fd.define_scalar(1, dtype=DataType.Int)
    # T12 = [0, 1, ..., 4095] — column indices.
    T12 = fd.ops.iota(S9, S10, S11, dtype=DataType.Int)
    S13 = fd.define_scalar(1, dtype=DataType.Int)
    S14 = fd.define_scalar(4096, dtype=DataType.Int)
    V15 = fd.define_vector([S13, S14], dtype=DataType.Int)
    # T16: column indices reshaped to a (1, 4096) row vector.
    T16 = fd.ops.broadcast_in_dim(T12, shape=V15, broadcast_dims=[1])
    S17 = fd.define_scalar(0, dtype=DataType.Int)
    # Add 0 — a no-op arithmetically; kept because the generated trace has it.
    T18 = fd.ops.add(T8, S17)
    S19 = fd.define_scalar(4096, dtype=DataType.Int)
    S20 = fd.define_scalar(4096, dtype=DataType.Int)
    V21 = fd.define_vector([S19, S20], dtype=DataType.Int)
    # Broadcast row indices to the full (4096, 4096) grid.
    T22 = fd.ops.broadcast_in_dim(T18, shape=V21, broadcast_dims=[0, 1])
    S23 = fd.define_scalar(4096, dtype=DataType.Int)
    S24 = fd.define_scalar(4096, dtype=DataType.Int)
    V25 = fd.define_vector([S23, S24], dtype=DataType.Int)
    # Broadcast column indices to the full (4096, 4096) grid.
    T26 = fd.ops.broadcast_in_dim(T16, shape=V25, broadcast_dims=[0, 1])
    # T27[r, c] = (r >= c): lower-triangular (causal) boolean mask.
    T27 = fd.ops.ge(T22, T26)
    S28 = fd.define_scalar(2, dtype=DataType.Int)
    S29 = fd.define_scalar(16, dtype=DataType.Int)
    S30 = fd.define_scalar(4096, dtype=DataType.Int)
    S31 = fd.define_scalar(4096, dtype=DataType.Int)
    V32 = fd.define_vector([S28, S29, S30, S31], dtype=DataType.Int)
    # Expand the 2-D mask over the leading (batch, head) dims: (2, 16, 4096, 4096).
    T33 = fd.ops.broadcast_in_dim(T27, shape=V32, broadcast_dims=[2, 3])
    S34 = fd.define_scalar(float("-inf"), dtype=DataType.Double)
    # Keep input where the mask is True; -inf elsewhere so exp() -> 0 later.
    T35 = fd.ops.where(T33, T0, S34)
    # Promote to Float for the softmax reductions.
    T36 = fd.ops.cast(T35, dtype=DataType.Float)
    # Row-wise max along dim 3 for the max-subtract stabilization.
    T37 = fd.ops.max(T36, dims=[3], keepdim=False, dtype=DataType.Null)
    S38 = fd.define_scalar(2, dtype=DataType.Int)
    S39 = fd.define_scalar(16, dtype=DataType.Int)
    S40 = fd.define_scalar(4096, dtype=DataType.Int)
    S41 = fd.define_scalar(1, dtype=DataType.Int)
    V42 = fd.define_vector([S38, S39, S40, S41], dtype=DataType.Int)
    # Reinstate the reduced dim as size 1 ...
    T43 = fd.ops.broadcast_in_dim(T37, shape=V42, broadcast_dims=[0, 1, 2])
    S44 = fd.define_scalar(2, dtype=DataType.Int)
    S45 = fd.define_scalar(16, dtype=DataType.Int)
    S46 = fd.define_scalar(4096, dtype=DataType.Int)
    S47 = fd.define_scalar(4096, dtype=DataType.Int)
    V48 = fd.define_vector([S44, S45, S46, S47], dtype=DataType.Int)
    # ... then broadcast the max back to the full shape.
    T49 = fd.ops.broadcast_in_dim(T43, shape=V48, broadcast_dims=[0, 1, 2, 3])
    # exp(x - max): the stabilized softmax numerator.
    T50 = fd.ops.sub(T36, T49)
    T51 = fd.ops.exp(T50)
    # Softmax denominator: sum of exponentials along dim 3.
    T52 = fd.ops.sum(T51, dims=[3], keepdim=False, dtype=DataType.Null)
    S53 = fd.define_scalar(2, dtype=DataType.Int)
    S54 = fd.define_scalar(16, dtype=DataType.Int)
    S55 = fd.define_scalar(4096, dtype=DataType.Int)
    S56 = fd.define_scalar(1, dtype=DataType.Int)
    V57 = fd.define_vector([S53, S54, S55, S56], dtype=DataType.Int)
    T58 = fd.ops.broadcast_in_dim(T52, shape=V57, broadcast_dims=[0, 1, 2])
    S59 = fd.define_scalar(2, dtype=DataType.Int)
    S60 = fd.define_scalar(16, dtype=DataType.Int)
    S61 = fd.define_scalar(4096, dtype=DataType.Int)
    S62 = fd.define_scalar(4096, dtype=DataType.Int)
    V63 = fd.define_vector([S59, S60, S61, S62], dtype=DataType.Int)
    # Broadcast the sum back to the full shape, then divide via reciprocal * mul.
    T64 = fd.ops.broadcast_in_dim(T58, shape=V63, broadcast_dims=[0, 1, 2, 3])
    T65 = fd.ops.reciprocal(T64)
    T66 = fd.ops.mul(T51, T65)
    # Cast the softmax result back to the input's BFloat16 dtype.
    T67 = fd.ops.cast(T66, dtype=DataType.BFloat16)
    fd.add_output(T67)

# Record the fusion definition, then run it on one concrete input.
with FusionDefinition() as fd:
    nvfuser_fusion_id2(fd)

# A (2, 16, 4096, 4096) BFloat16 view carved out of one flat 2*16*4096*4096
# (= 536870912) element allocation via as_strided (fully contiguous strides).
# NOTE(review): the input is placed on 'cuda:1'; per the thread, the failure
# did not reproduce on cuda:0, so the device choice may be significant
# (triage suspects an unrelated OOM on that device).
inputs = [
    torch.randn((536870912,), dtype=torch.bfloat16, device='cuda:1').as_strided((2, 16, 4096, 4096), (268435456, 16777216, 4096, 1)),
]
fd.execute(inputs)

from https://github.com/Lightning-AI/lightning-thunder/issues/474

cc @tfogal

kevinstephano commented 5 months ago

I reran the code separately on nvFuser with cuda:0 and I am not seeing a failure.

mruberry commented 5 months ago

@mpatel31415 Is it possible that we need a different version of nvFuser?

mruberry commented 4 months ago

triage review — likely this failure is due to an unrelated OOM