I hacked around a bit and was able to get a traced graph with the information I am (for now) looking for. Getting there required a few local changes.
With that, I am able to generate the following traced graph. This will enable me to prototype something in torch-mlir and see if this would be a viable solution.
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "f32[64, 64]:torch.sparse_csr"):
            # File: biknet.py:27, code: return x.sum()
            sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
            return (sum_1,)

Graph signature: ExportGraphSignature(
    input_specs=[
        InputSpec(
            kind=<InputKind.USER_INPUT: 1>,
            arg=TensorArgument(name='l_x_'),
            target=None,
            layout=torch.sparse_csr)
    ],
    output_specs=[
        OutputSpec(
            kind=<OutputKind.USER_OUTPUT: 1>,
            arg=TensorArgument(name='sum_1'),
            target=None)
    ])
Range constraints: {}
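For reference, constructing a CSR input like the f32[64, 64]:torch.sparse_csr placeholder above takes one call (the shape mirrors the trace; the values are illustrative):

```python
import torch

# Build a 64x64 CSR tensor matching the placeholder in the trace above.
dense = torch.randn(64, 64)
x = dense.to_sparse_csr()

assert x.layout == torch.sparse_csr
print(x.crow_indices().shape)  # row pointers: torch.Size([65]), i.e. rows + 1
print(x.col_indices().shape)   # one column index per stored element
```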
I was also pointed to the right forum for this discussion, so I "migrated" the discussion from the user forum to this posting in the developer forum.
I think we would probably accept these PRs. Send us the draft so we can take a quick look.
Guideline: nobody is signed up to work on this (so marking it low-priority), but we are accepting PRs.
I don't seem to be able to change the assignment in the PyTorch issues repo, but please feel free to assign it to me for now!
To get the ball rolling, I sent out a quick-and-dirty hack that gets this propagated into the traced graph as follows. I would love to hear how we can lift this to production quality.
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "f32[2]:torch.sparse_csr"):  # <-- note the layout in the l_x_ type
            # File: biknet.py:100 in forward, code: return x.sum()
            sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
            return (sum_1,)
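Once the layout survives into the trace like this, downstream tooling (e.g. a torch-mlir importer) can branch on it. A hypothetical sketch, assuming the fake tensors recorded in node.meta["val"] carry the sparse layout just as they carry dtype and shape:

```python
import torch

def sparse_inputs(prog):
    # Walk the exported graph and yield the names of placeholders whose
    # (fake) tensor value uses the CSR layout. Assumes node.meta["val"]
    # is populated for placeholders, as it is for dense inputs today.
    for node in prog.graph.nodes:
        if node.op == "placeholder":
            val = node.meta.get("val")
            if val is not None and val.layout == torch.sparse_csr:
                yield node.name

# e.g. list(sparse_inputs(prog)) would yield ["l_x_"] for the trace above.
```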
@aartbik I'm helping the team groom "old" issues. Based on recent activity, it looks like you're still actively working on this?
Also, totally unrelated: don't I know you from way back from your work on vectorization and alignment analysis?!
Yeah, I am just a "few" PRs away from getting this working. Currently stuck on https://github.com/pytorch/pytorch/pull/128549, but once that goes in, I suspect all the remaining cases will go much, much faster (since by then, the "hot fix" will at least be history ;-)
And yes, we know each other really well from our SIMD days and writing lattices for alignment ;-)
🚀 The feature, motivation and pitch
For background discussion, please see:
https://discuss.pytorch.org/t/connecting-pytorch-sparse-tensors-with-mlir/195145
Given the following code:
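A minimal reconstruction consistent with the biknet.py traces above (the class name BikNet, the 64x64 shape, and the use of torch.export.export are assumptions, not verbatim from the original):

```python
import torch

class BikNet(torch.nn.Module):
    def forward(self, x):
        return x.sum()

net = BikNet()
dense_input = torch.ones(64, 64)
sparse_input = dense_input.to_sparse_csr()

prog1 = torch.export.export(net, (dense_input,))   # dense: traces fine
prog2 = torch.export.export(net, (sparse_input,))  # sparse CSR: raises below
```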
Then prog1 yields the expected traced graph IR as follows:
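(Reconstructed from the sparse traces shown earlier; the dense dump would look essentially like this, minus any sparsity annotation:)

```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_x_: "f32[64, 64]"):
            # code: return x.sum()
            sum_1: "f32[]" = torch.ops.aten.sum.default(l_x_);  l_x_ = None
            return (sum_1,)
```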
However, computing prog2 throws an exception when trying to compute the tensor's size (the size computation multiplies shapes by strides, and the CSR format has no strides):
torch._dynamo.exc.InternalTorchDynamoError: Sparse CSR tensors do not have strides
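The underlying failure is easy to reproduce outside of Dynamo; a CSR tensor simply refuses to report strides (a minimal repro, with an arbitrary small matrix):

```python
import torch

x = torch.eye(3).to_sparse_csr()
x.stride()  # raises: "Sparse CSR tensors do not have strides"
```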
A much more desirable outcome would be to yield a traced graph similar to the dense case, but with the argument l_x annotated with the torch.sparse_csr layout. That would enable compilers (in particular "sparse compilers") to build an IR where sparsity is reflected in the types (and not yet in the code).

Alternatives
No response
Additional context
No response
cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @wconstab