Open 0xc1c4da opened 2 months ago
what do you think about:
with hax.auto_broadcast():
    named1_diff = named1 - named1.rename({"T": "T2"})
Basically the only thing stopping this from working is an explicit check I do to avoid accidentally combining arrays where one isn't a subset of the other.
The other thing I could do is relax the check to "at least one overlapping axis".
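To make the two rules concrete, here is a minimal sketch in plain Python that compares them on axis names only. The helper names `is_subset_compatible` and `has_overlap` are hypothetical and are not Haliax functions; the real check may well differ in detail.

```python
# Hypothetical helpers, not Haliax's API: they only compare axis names.

def is_subset_compatible(a_axes, b_axes):
    """Strict rule: one operand's axes must be a subset of the other's."""
    a, b = set(a_axes), set(b_axes)
    return a <= b or b <= a


def has_overlap(a_axes, b_axes):
    """Relaxed rule: the operands share at least one axis name."""
    return bool(set(a_axes) & set(b_axes))


lhs = ("B", "T")    # axes of named1 (assumed to carry a batch axis "B")
rhs = ("B", "T2")   # axes after named1.rename({"T": "T2"})

strict_ok = is_subset_compatible(lhs, rhs)   # neither set contains the other
relaxed_ok = has_overlap(lhs, rhs)           # both share "B"
```

Note that if `named1` had only the axis `T`, the renamed operand would share no axis with it, so the relaxed rule would reject that case as well.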
I think it is certainly cleaner, but I wouldn't remove the explicit check. Wouldn't it be better to have the user explicitly disable it?
Meaning you like with hax.auto_broadcast?
In torch and jax it is possible to compute an all-pairs difference with a bit of one-liner black magic:
dt_segment_sum_jax = dA_cumsum_jax[:, :, :, :, None] - dA_cumsum_jax[:, :, :, None, :]
This is done in the reference implementation of Mamba 2.
The code above is neither readable nor obvious in intent, and it was not clear how to express the equivalent in Haliax because of the subset constraint. A potential solution is below:
This issue exists to provide better support for this kind of operation.
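For readers unfamiliar with the indexing trick, the sketch below unpacks the all-pairs difference on a toy array. It uses NumPy so it runs without an accelerator stack; indexing with `None` behaves the same way in jax and torch. The array sizes and the variable names are arbitrary choices for illustration.

```python
import numpy as np

# Toy 4-D stand-in for dA_cumsum; the last axis plays the role of T.
x = np.arange(2 * 1 * 1 * 3, dtype=np.float32).reshape(2, 1, 1, 3)

# Indexing with None inserts a length-1 axis, so the operands have shapes
# (..., T, 1) and (..., 1, T) and broadcast to (..., T, T).
diff = x[:, :, :, :, None] - x[:, :, :, None, :]

# The same result with explicit loops over the last axis, for comparison.
T = x.shape[-1]
expected = np.empty(x.shape + (T,), dtype=x.dtype)
for i in range(T):
    for j in range(T):
        expected[..., i, j] = x[..., i] - x[..., j]
```

In named-axis terms, the second operand is the same array with `T` renamed to `T2`, which is why the two operands are no longer subsets of one another.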