
Represent sparse tensors (layout) in traced graph #117188

Closed: aartbik closed this issue 2 months ago

aartbik commented 8 months ago

šŸš€ The feature, motivation and pitch

For background discussion, please see:

https://discuss.pytorch.org/t/connecting-pytorch-sparse-tensors-with-mlir/195145

Given the following code:

import torch
import torch.export
import torch.sparse

class BikNet(torch.nn.Module):

  def __init__(self):
    super().__init__()

  def forward(self, x):
    return x.sum()

biknet = BikNet()
biknet.eval()

dense_input = torch.ones(64, 64)
sparse_input = dense_input.to_sparse_csr()

prog1 = torch.export.export(biknet, args=(dense_input,))
prog2 = torch.export.export(biknet, args=(sparse_input,))   # fails

Then prog1 yields the expected traced graph IR as follows:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "f32[64, 64]"):
            # File: biknet.py:17, code: return x.sum()
            sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
            return (sum_1,)

Graph signature: ExportGraphSignature(
  input_specs=[
    InputSpec(
    kind=<InputKind.USER_INPUT: 1>,
    arg=TensorArgument(name='l_x_'),
    target=None)
  ],
  output_specs=[
    OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>,
    arg=TensorArgument(name='sum_1'),
    target=None)
    ])
Range constraints: {}

However, computing prog2 throws an exception when the tracer tries to compute the tensor's size: the CSR format carries no strides, so multiplying shapes by strides to obtain the size fails:

torch._dynamo.exc.InternalTorchDynamoError: Sparse CSR tensors do not have strides
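
For reference, the underlying limitation is easy to reproduce outside of export. The sketch below (my own illustration, inferred from the error text rather than taken from this issue) simply asks a CSR tensor for its strides:

import torch

x = torch.ones(64, 64).to_sparse_csr()
print(x.layout)    # torch.sparse_csr
try:
    x.stride()     # CSR tensors carry no strides
except RuntimeError as e:
    print(e)       # "Sparse CSR tensors do not have strides"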

A much more desirable outcome would be a traced graph similar to the dense case, but with the argument l_x_ annotated with the torch.sparse_csr layout. That would enable compilers (in particular "sparse compilers") to build an IR in which sparsity is reflected in the types (and not yet in the code).
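
For context (not part of the original report): the layout and the CSR component tensors are already queryable on the eager tensor, so the request is essentially that this same information survive tracing. A minimal illustration using the standard sparse CSR API:

import torch

a = torch.ones(64, 64).to_sparse_csr()
print(a.layout)           # torch.sparse_csr
print(a.crow_indices())   # compressed row pointers
print(a.col_indices())    # column indices of the stored elements
print(a.values())         # the stored (nonzero) values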

Alternatives

No response

Additional context

No response

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @wconstab

aartbik commented 8 months ago

I hacked around a bit and, with a few local changes, was able to get a traced graph with the information I am (for now) looking for.

With that, I am able to generate the following traced graph. This will enable me to prototype something in torch-mlir and see if this would be a viable solution.

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "f32[64, 64]:torch.sparse_csr"):
            # File: biknet.py:27, code: return x.sum()
            sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
            return (sum_1,)

Graph signature: ExportGraphSignature(
  input_specs=[
      InputSpec(
           kind=<InputKind.USER_INPUT: 1>,
           arg=TensorArgument(name='l_x_'),
           target=None,
           layout=torch.sparse_csr)
  ],
  output_specs=[
     OutputSpec(
         kind=<OutputKind.USER_OUTPUT: 1>,
         arg=TensorArgument(name='sum_1'),
         target=None)
 ])
Range constraints: {}
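
As a rough idea of how a consumer such as torch-mlir could pick this up, here is a hypothetical sketch (my own assumption, not code from this issue) that reads per-input layouts from an exported program, assuming the layout survives onto the fake tensors stored in each placeholder's meta["val"]:

# Hypothetical: collect the layout of every user input of an ExportedProgram.
# Assumes the sparse layout is propagated onto the placeholders' fake tensors.
def input_layouts(prog):
    layouts = {}
    for node in prog.graph_module.graph.nodes:
        if node.op == "placeholder":
            fake = node.meta.get("val")
            if fake is not None:
                layouts[node.name] = fake.layout
    return layouts

# e.g. input_layouts(prog2) would then yield {'l_x_': torch.sparse_csr}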

aartbik commented 8 months ago

I was also pointed to the right forum for this discussion, so I "migrated" the discussion from the user forum to this posting in the developer forum.

ezyang commented 8 months ago

I think we would probably accept these PRs. Send us the draft so we can take a quick look.

bdhirsh commented 8 months ago

guideline: nobody is signed up to work on this (so marking it low priority), but we are accepting PRs

aartbik commented 8 months ago

I don't seem to be able to change the assignee in the PyTorch issues repo, but please feel free to assign this to me for now!

aartbik commented 8 months ago

To get the ball rolling, I sent out a quick-and-dirty hack that propagates this into the traced graph as follows. I would love to hear how we can lift this to production quality.

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "f32[2]:torch.sparse_csr"):    <-- note the layout annotation on the type of l_x_
            # File: biknet.py:100 in forward, code: return x.sum()
            sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
            return (sum_1,)

masnesral commented 3 months ago

@aartbik I'm helping the team groom "old" issues. Based on recent activity, it looks like you're still actively working on this?

Also, totally unrelated: don't I know you from way back from your work on vectorization and alignment analysis?!

aartbik commented 2 months ago

Yeah, I am just a "few" PRs away from getting this working. Currently stuck on https://github.com/pytorch/pytorch/pull/128549, but once that goes in, I suspect all the remaining cases will go much, much faster, since by then the "hot fix" will at least be history ;-)

And yes, we know each other really well from our SIMD days and writing lattices for alignment ;-)