dlwh opened this issue 4 months ago
The backend kernel doesn't support rank > 4 for the reduce op. Is it possible for the app to work around the issue by reshaping the tensor, e.g.,

```python
a = jnp.zeros((2, 3, 4, 5, 6)).reshape(-1, 4, 5, 6)
```
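For context, a minimal sketch of that workaround (the shapes here are illustrative): collapsing the leading axes keeps the array the reduce kernel sees at rank 4 or below.

```python
import jax.numpy as jnp

# Hypothetical rank-5 input; the Metal backend reportedly rejects
# reductions on arrays of rank > 4.
a = jnp.zeros((2, 3, 4, 5, 6))

# Collapse the two leading axes into one so the array is rank 4
# before reducing; this gives the same values as a.sum(axis=(0, 1)).
s = a.reshape(-1, 4, 5, 6).sum(axis=0)   # shape (4, 5, 6)
```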
Thanks! I can work around it, particularly for these to-scalar reductions. (But for this case it also seems like a straightforward thing to do on the plugin end?) I could be wrong, but I think any reduction over either one axis or all axes can be written as a reshape -> reduce -> reshape.
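A minimal sketch of that claim (these helpers are mine, not jax-metal's): an all-axes reduction flattens to rank 1 first; a single-axis reduction collapses everything before and after the reduced axis, so the reduce op only ever sees a rank-3 (or lower) array.

```python
import math
import jax.numpy as jnp

def sum_all_axes(x):
    # All-axes reduction: flatten to rank 1, then one reduce to a scalar.
    return x.reshape(-1).sum()

def sum_single_axis(x, k):
    # Single-axis reduction: merge the axes before and after axis k,
    # reduce the rank-3 result, then restore the remaining shape.
    lead = math.prod(x.shape[:k])
    trail = math.prod(x.shape[k + 1:])
    y = x.reshape(lead, x.shape[k], trail).sum(axis=1)
    return y.reshape(x.shape[:k] + x.shape[k + 1:])
```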
Can we leave this open as a signpost?
Is there a guide on Metal perf yet (presumably not JAX-focused, but something close)?
Per the StableHLO spec, the reduction dimensions are not limited to one axis or all axes, so the reshape -> reduce -> reshape pattern will not cover all the cases. We will look into whether a more general conversion pattern could be added to jax-metal.
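To make the gap concrete (my example, not from the thread): with non-adjacent reduction axes, no reshape can merge the reduced dimensions without transposing first.

```python
import jax.numpy as jnp

x = jnp.ones((2, 3, 4, 5, 6))

# Non-adjacent reduction axes, e.g. (0, 2): no single reshape can merge
# dimensions 0 and 2 while leaving 1, 3, and 4 in place, so
# reshape -> reduce -> reshape alone cannot express this reduction.
y = x.sum(axis=(0, 2))   # result shape (3, 5, 6)
```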
Sure, I meant those are the easy cases and felt like a minimum. I think you can define something that's correct modulo floating point (and definitely not optimal) with transpose -> reshape -> reduce -> reshape.
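A sketch of that general pattern, using sum as a stand-in reduce op (the helper is hypothetical): transpose the reduced axes to the front, merge them with one reshape, reduce once over a rank-2 array, and reshape back to the kept shape. As noted, this is correct only up to floating-point summation order and makes no performance claims.

```python
import math
import jax.numpy as jnp

def sum_over_axes(x, axes):
    # General reduction over an arbitrary axis subset via
    # transpose -> reshape -> reduce -> reshape.
    axes = tuple(sorted(a % x.ndim for a in axes))
    kept = tuple(a for a in range(x.ndim) if a not in axes)
    kept_shape = tuple(x.shape[a] for a in kept)
    # Move the reduced axes to the front, then merge them into one axis.
    t = x.transpose(axes + kept)
    merged = t.reshape(math.prod(x.shape[a] for a in axes), -1)
    # A single rank-2 reduce, then restore the kept shape.
    return merged.sum(axis=0).reshape(kept_shape)
```

On the rank-5 example above, `sum_over_axes(x, (0, 2))` matches `x.sum(axis=(0, 2))` up to floating-point reordering.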