csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

`views` in HF-Bart Self-Attention #2456

Open kevinstephano opened 1 year ago

kevinstephano commented 1 year ago

🐛 Describe the bug

I have a horizontal fusion situation with reshapes, and I would like to understand whether it can be fused. I think we have a knob to turn this on, or a place to switch it; Jie might know. It would be good if this could be 1 kernel.

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id27(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(symbolic_sizes=[-1, -1, -1, -1], contiguous=[True, True, True, True], dtype=DataType.Float, is_cpu=False)
    T2 = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float, is_cpu=False)
    T3 = fd.ops.reshape(T2, original_shape=[8, 1024, 1024], new_shape=[8, 1024, 16, 64])
    T4 = fd.ops.permute(T3, dims=[0, 2, 1, 3])
    T5 = fd.ops.reshape(T0, original_shape=[8, 1024, 1024], new_shape=[8, 1024, 16, 64])
    T6 = fd.ops.permute(T5, dims=[0, 2, 1, 3])
    T7 = fd.ops.reshape(T6, original_shape=[8, 16, 1024, 64], new_shape=[128, 1024, 64])
    T8 = fd.ops.reshape(T1, original_shape=[8, 16, 1024, 64], new_shape=[128, 1024, 64])
    T9 = fd.ops.reshape(T4, original_shape=[8, 16, 1024, 64], new_shape=[128, 1024, 64])
    T10 = fd.ops.permute(T8, dims=[0, 2, 1])
    fd.add_output(T9)
    fd.add_output(T7)
    fd.add_output(T10)

inputs = [
    torch.randn(8, 1024, 1024, device='cuda'),
    torch.randn(8, 1024, 16, 64, device='cuda'),
    torch.randn(8, 1024, 1024, device='cuda'),
]

with FusionDefinition() as fd:
    nvfuser_fusion_id27(fd)

for _ in range(5):
    out = fd.execute(inputs)
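
For context, the three chains above look like the query/key/value head-shaping in HF Bart self-attention. Below is a minimal eager-mode sketch of that pattern; the q/k/v roles and the bsz/seq/heads/head_dim names are assumptions inferred from the shapes in the definition.

import torch

# Sizes taken from the fusion definition above.
bsz, seq, heads, head_dim = 8, 1024, 16, 64

def shape_heads(x):
    # [bsz, seq, heads*head_dim] -> [bsz, heads, seq, head_dim] -> [bsz*heads, seq, head_dim]
    return x.view(bsz, seq, heads, head_dim).permute(0, 2, 1, 3).reshape(bsz * heads, seq, head_dim)

query = torch.randn(bsz, seq, heads * head_dim, device='cuda')    # plays the role of T2
value = torch.randn(bsz, seq, heads * head_dim, device='cuda')    # plays the role of T0
key = torch.randn(bsz, heads, seq, head_dim, device='cuda')       # plays the role of T1 (assumed already head-shaped)

q_out = shape_heads(query)                                        # cf. T9
v_out = shape_heads(value)                                        # cf. T7
k_out = key.reshape(bsz * heads, seq, head_dim).permute(0, 2, 1)  # cf. T10: keys transposed for the bmm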

The second case looks okay. Could you just double-check that it is also okay with FP16 inputs?

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id21(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(symbolic_sizes=[-1, 1, 1, -1], contiguous=[True, True, True, True], dtype=DataType.Float, is_cpu=False)
    T2 = fd.ops.reshape(T0, original_shape=[128, 1024, 1024], new_shape=[8, 16, 1024, 1024])
    T3 = fd.ops.broadcast_in_dim(T1, output_shape=[8, 16, 1024, 1024], broadcast_dims=[0, 1, 2, 3])
    T4 = fd.ops.add(T2, T3)
    T5 = fd.ops.reshape(T4, original_shape=[8, 16, 1024, 1024], new_shape=[128, 1024, 1024])
    T6 = fd.ops.max(T5, axes=[2], keepdim=False, dtype=DataType.Null)
    T7 = fd.ops.broadcast_in_dim(T6, output_shape=[128, 1024, 1], broadcast_dims=[0, 1])
    T8 = fd.ops.broadcast_in_dim(T7, output_shape=[128, 1024, 1024], broadcast_dims=[0, 1, 2])
    T9 = fd.ops.sub(T5, T8)
    T10 = fd.ops.exp(T9)
    T11 = fd.ops.sum(T10, axes=[2], keepdim=False, dtype=DataType.Null)
    T12 = fd.ops.broadcast_in_dim(T11, output_shape=[128, 1024, 1], broadcast_dims=[0, 1])
    T13 = fd.ops.broadcast_in_dim(T12, output_shape=[128, 1024, 1024], broadcast_dims=[0, 1, 2])
    T14 = fd.ops.div(T10, T13)
    fd.add_output(T14)

inputs = [
    torch.randn(128, 1024, 1024, device='cuda'),
    torch.randn(8, 1, 1, 1024, device='cuda'),
]

with FusionDefinition() as fd:
    nvfuser_fusion_id21(fd)

for _ in range(5):
    out = fd.execute(inputs)
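
On the FP16 question: below is a minimal sketch of how that check might be run, reusing the same ops but parameterizing the input dtype. The assumption is that "FP16 inputs" just means swapping DataType.Float for DataType.Half in define_tensor and feeding half tensors; the make_softmax_fusion wrapper name is purely illustrative.

import torch
from nvfuser import FusionDefinition, DataType

def make_softmax_fusion(dtype):
    # Same ops as nvfuser_fusion_id21 above, with the input dtype made a parameter.
    def fusion_func(fd : FusionDefinition) -> None :
        T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=dtype, is_cpu=False)
        T1 = fd.define_tensor(symbolic_sizes=[-1, 1, 1, -1], contiguous=[True, True, True, True], dtype=dtype, is_cpu=False)
        T2 = fd.ops.reshape(T0, original_shape=[128, 1024, 1024], new_shape=[8, 16, 1024, 1024])
        T3 = fd.ops.broadcast_in_dim(T1, output_shape=[8, 16, 1024, 1024], broadcast_dims=[0, 1, 2, 3])
        T4 = fd.ops.add(T2, T3)
        T5 = fd.ops.reshape(T4, original_shape=[8, 16, 1024, 1024], new_shape=[128, 1024, 1024])
        T6 = fd.ops.max(T5, axes=[2], keepdim=False, dtype=DataType.Null)
        T7 = fd.ops.broadcast_in_dim(T6, output_shape=[128, 1024, 1], broadcast_dims=[0, 1])
        T8 = fd.ops.broadcast_in_dim(T7, output_shape=[128, 1024, 1024], broadcast_dims=[0, 1, 2])
        T9 = fd.ops.sub(T5, T8)
        T10 = fd.ops.exp(T9)
        T11 = fd.ops.sum(T10, axes=[2], keepdim=False, dtype=DataType.Null)
        T12 = fd.ops.broadcast_in_dim(T11, output_shape=[128, 1024, 1], broadcast_dims=[0, 1])
        T13 = fd.ops.broadcast_in_dim(T12, output_shape=[128, 1024, 1024], broadcast_dims=[0, 1, 2])
        T14 = fd.ops.div(T10, T13)
        fd.add_output(T14)
    return fusion_func

for frontend_dtype, torch_dtype in [(DataType.Float, torch.float32), (DataType.Half, torch.float16)]:
    inputs = [
        torch.randn(128, 1024, 1024, device='cuda', dtype=torch_dtype),
        torch.randn(8, 1, 1, 1024, device='cuda', dtype=torch_dtype),
    ]
    with FusionDefinition() as fd:
        make_softmax_fusion(frontend_dtype)(fd)
    out = fd.execute(inputs)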

Versions

TOT

jacobhinkle commented 1 year ago

This is a good test case. I think I know where the heuristic fails. This is probably related to https://github.com/csarofeen/pytorch/pull/2455

jacobhinkle commented 1 year ago

The second case produces 1 kernel with either float or half inputs. I'm not sure how that is happening, since it has two reshapes and so matches the "comment out C" pattern from https://github.com/csarofeen/pytorch/issues/2090#issuecomment-1398665847.

In the first case the segmenter is refusing to merge across the three connected components. I'm not sure this is due to reshapes; I don't think merging across disconnected components is ever done (see this comment: https://github.com/csarofeen/pytorch/blob/devel/third_party/nvfuser/csrc/fusion_segmenter.cpp#L3275-L3279). For this particular case, since the three groups are independent, wouldn't three kernels actually be preferable?
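
As a side note, one way to check the kernel count from Python is to profile a single execute call and count the CUDA kernel rows. This is a sketch using torch.profiler rather than any nvFuser-specific tooling, and it assumes fd and inputs are the first definition and inputs from the repro above.

import torch
from torch.profiler import profile, ProfilerActivity

fd.execute(inputs)  # warm up so compilation does not show up in the profile

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    fd.execute(inputs)

# Each launched CUDA kernel appears as a row; three generated kernels would
# mean the segmenter split the definition into three groups.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))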