PeaBrane closed this 9 months ago.
Correct, we currently don't support multiple inputs. We don't have the bandwidth to add support for this at the moment. I can fix the documentation to make this clearer. Contributions are welcome if anybody is interested.
@ThomasRaoux I see, thank you.
Digging a bit deeper into the code, it seems like the associative scan is meant to be applied independently to all the tensors in the input tuple, and the tensors in the tuple do not "interact" with each other.
I am interested in helping add support for multiple inputs to associative scan, but I need to read a bit more of the relevant code to see whether I can contribute anything useful.
I don't think the code makes any assumptions about what is done within the scan function. Basically, we apply the scan region multiple times, so what happens inside the region doesn't really matter. What needs to be done is passing multiple inputs and handling multiple accumulators in case the function has multiple results.
That being said, I expect the change to be non-trivial, as there are quite a few places to update.
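To make "multiple inputs and multiple accumulators" concrete, here is a minimal pure-Python reference of the semantics, assuming an inclusive scan where the combine function receives the whole carry tuple and the whole current-element tuple. It is a sequential sketch, not Triton's parallel implementation, and the names `tuple_inclusive_scan` and `segmented_sum_combine` are hypothetical. The segmented sum shows that the inputs can interact: the flag input decides whether the value input restarts.

```python
from typing import Callable, List, Sequence, Tuple


def tuple_inclusive_scan(
    inputs: Tuple[Sequence, ...],
    combine_fn: Callable[[tuple, tuple], tuple],
) -> Tuple[List, ...]:
    """Sequential reference for an inclusive scan over several equal-length inputs.

    The carry holds one accumulator per input. combine_fn sees the whole carry
    tuple and the whole current-element tuple, so it may mix the inputs freely.
    """
    length = len(inputs[0])
    if any(len(seq) != length for seq in inputs):
        raise ValueError("all inputs must have the same length")
    outputs: Tuple[List, ...] = tuple([] for _ in inputs)
    carry = None
    for i in range(length):
        current = tuple(seq[i] for seq in inputs)
        carry = current if carry is None else tuple(combine_fn(carry, current))
        if len(carry) != len(inputs):
            raise ValueError("combine_fn must return one value per input")
        for out, value in zip(outputs, carry):
            out.append(value)
    return outputs


def segmented_sum_combine(left, right):
    """Hypothetical example: running sum that restarts wherever flag == 1."""
    left_flag, left_value = left
    right_flag, right_value = right
    # The flag input decides what happens to the value input: the two interact.
    value = right_value if right_flag else left_value + right_value
    return (left_flag | right_flag, value)


flags, sums = tuple_inclusive_scan(([1, 0, 0, 1, 0], [1, 2, 3, 4, 5]), segmented_sum_combine)
# sums == [1, 3, 6, 4, 9]
```

A sequential loop gives the right answer for any combine function, but a parallel scan like Triton's only does so when the combine function is associative, which the segmented-sum operator is.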
Of course, no pressure.
@PeaBrane have you already started working on this? Would you mind me having a stab at it?
@lezcano No, I have not. Please feel free to take this on! (Out of curiosity, would this be used as the backend for torch.associative_scan?)
Yep! You can find the PR that adds the initial plumbing within torch.compile at https://github.com/pytorch/pytorch/pull/106581.
https://github.com/openai/triton/pull/2947 added support for a tuple of input tensors. Thanks @lezcano.
It only supports tensors of the same element type. When the tensors have different types, Triton fails with:
error: 'tt.scan' op requires the same element type for all operands and results
Script:
Traceback:
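A running argmax is a typical case that needs mixed element types: it scans a float value together with an integer index. The sketch below is only a pure-Python reference of the semantics, not Triton code; in a kernel, the combine function would presumably be a `@triton.jit` function passed to `tl.associative_scan` with the `(values, indices)` tuple. The name `argmax_combine` is hypothetical.

```python
from itertools import accumulate


def argmax_combine(left, right):
    """Combine two (value, index) pairs, keeping the pair with the larger value.

    On ties the left (earlier) pair wins, which keeps the operator associative.
    """
    left_value, left_index = left
    right_value, right_index = right
    if right_value > left_value:
        return (right_value, right_index)
    return (left_value, left_index)


values = [0.5, 2.0, 1.0, 2.0, 3.5]  # float operand
running = list(accumulate(zip(values, range(len(values))), argmax_combine))  # int operand
# running == [(0.5, 0), (2.0, 1), (2.0, 1), (2.0, 1), (3.5, 4)]
```

Nothing in the combine logic requires the two operands to share an element type, which is why the same-element-type check looks like a leftover restriction rather than a real requirement.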
This is probably just a restriction that was missed. It should be easy to make a patch that removes that restriction and adds a simple test to make sure it works.
Will submit a fix either today or early next week.
I am trying to use associative_scan to compute the EMA of a signal. My example is adapted from this issue, except that I am passing multiple tensors (as a tuple) to the associative_scan function. According to the documentation, this function should support a tuple of tensors, but I keep getting the error
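An EMA is a natural fit for a tuple scan because it is a linear recurrence y_t = a_t * y_{t-1} + b_t, and each step can be represented as a pair (a, b). Composing two steps gives another pair, so the decay and value inputs must be scanned together. The following is a minimal pure-Python sketch of that idea, assuming the convention y_t = (1 - alpha) * y_{t-1} + alpha * x_t with y_{-1} = 0; the function names are hypothetical and this is not the code from the original report.

```python
from itertools import accumulate


def linear_recurrence_combine(left, right):
    """Compose two affine maps y -> a*y + b (left applied first, then right)."""
    a_left, b_left = left
    a_right, b_right = right
    return (a_left * a_right, a_right * b_left + b_right)


def ema_by_scan(xs, alpha):
    """EMA via an inclusive scan over (decay, input) pairs, assuming y_{-1} = 0."""
    decays = [1.0 - alpha] * len(xs)
    inputs = [alpha * x for x in xs]
    scanned = accumulate(zip(decays, inputs), linear_recurrence_combine)
    # With y_{-1} = 0, the composed map applied to 0 is just its offset b.
    return [b for _, b in scanned]


def ema_by_loop(xs, alpha):
    """Plain sequential EMA, used as a reference."""
    y = 0.0
    out = []
    for x in xs:
        y = (1.0 - alpha) * y + alpha * x
        out.append(y)
    return out


print(ema_by_scan([1.0, 2.0, 3.0, 4.0], 0.5))  # [0.5, 1.25, 2.125, 3.0625]
```

In a Triton kernel, the decay and input tensors would be the two members of the tuple, and the combine function would return two results, which is exactly the multiple-input, multiple-accumulator case discussed above.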