jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Allow mask of strided ops #3893

Closed: juliuskunze closed this 1 month ago

juliuskunze commented 4 years ago

Currently, mask only supports ops whose output sizes are expressible as polynomials of the input sizes. This excludes:

This is a hassle. @j-towns @mattjj @dougalm What are your thoughts on the following plan to allow these cases?
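Strided ops, the subject of this issue, fall outside that polynomial class. Slicing an axis of length n with stride s yields ceil(n / s) elements, which is a floor-division expression that no polynomial in n matches. Below is a minimal pure-Python check. The helper `strided_len` is hypothetical, and it relies on NumPy and `jax.numpy` basic slicing following the same length rule as Python slicing.

```python
def strided_len(n: int, stride: int) -> int:
    """Length of x[::stride] for an axis of length n, i.e. ceil(n / stride)."""
    return -(-n // stride)


# Cross-check against Python's own slicing semantics.
for n in range(20):
    for stride in range(1, 5):
        assert strided_len(n, stride) == len(range(n)[::stride])

# For stride 2 the first differences alternate between 1 and 0. A polynomial
# of degree <= 1 has constant first differences, and one of degree >= 2 grows
# faster than n / 2, so no polynomial in n reproduces these lengths.
lengths = [strided_len(n, 2) for n in range(8)]
diffs = [b - a for a, b in zip(lengths, lengths[1:])]
print(lengths)  # [0, 1, 1, 2, 2, 3, 3, 4]
print(diffs)    # [1, 0, 1, 0, 1, 0, 1]
```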

j-towns commented 4 years ago

  • We raise inconsistencies in dynamic shapes using the validation rules.

We might be able to use the host_callback stuff to do this from inside compiled code.
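The following sketch shows one way to surface a dynamic-shape inconsistency from inside compiled code. It uses `jax.experimental.checkify` in place of `host_callback`, which is a different mechanism: the error is threaded out of the jitted function as a value and raised on the host afterwards. The function `masked_add` and its explicit length arguments are hypothetical, and the sketch assumes a JAX version that ships `checkify`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def masked_add(x, y, x_len, y_len):
    # Hypothetical: x_len / y_len stand in for dynamic (runtime) lengths,
    # while x and y are padded buffers of a common static size.
    checkify.check(x_len == y_len, "dynamic lengths of x and y must agree")
    valid = jnp.arange(x.shape[0]) < x_len
    return jnp.where(valid, x + y, 0.0)


# checkify functionalizes the check so it survives jit; the error value is
# returned next to the result instead of being raised inside XLA.
checked_add = jax.jit(checkify.checkify(masked_add))

err, out = checked_add(jnp.ones(4), jnp.ones(4), 3, 3)
err.throw()  # lengths agree, so nothing is raised
print(out)   # [2. 2. 2. 0.]

err, _ = checked_add(jnp.ones(4), jnp.ones(4), 3, 2)
print(err.get())  # error message mentioning the failed length check
```

This avoids raising inside the compiled computation. The trade-off is that callers must unpack and inspect the returned error.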

shoyer commented 4 years ago
  • We decorate shape rules with @numpy_eval() from #3923. This ensures XLA dispatch is never wasted on constant shape calculations, even during jit compilation (and no staging occurs). If the NumPy backend turns out to be too slow for some rules, we keep additional optimized NumPy rules for those.

I'm not sure I understand this. How does XLA dispatch get invoked on shape calculations at present? I would guess that most cases where this is currently happening are bugs that would be surfaced by turning on omnistaging.

juliuskunze commented 4 years ago

@shoyer Currently there is no XLA dispatch for shapes, since shape rules are written directly in NumPy. To allow dynamic/batched shape propagation for masking, we would need traceable shape rules. @numpy_eval() would let us calculate static, single shapes (i.e. in standard use cases, outside of masking) as fast as the current implementation does, without XLA dispatch and using the same code (i.e. without keeping two versions of each shape rule). We should be able to get very close to the original performance with #4117.
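The "one shape rule, two evaluators" idea can be sketched by writing the rule once against an array-module argument `xp`. With NumPy it produces a concrete shape eagerly, with no XLA dispatch. With `jax.numpy` it becomes traceable, so `jax.vmap` can propagate a whole batch of lengths. The name `strided_slice_len_rule` and the `xp` convention are hypothetical. This is not the `@numpy_eval()` decorator from #3923, whose API is not shown here.

```python
import numpy as np
import jax
import jax.numpy as jnp


def strided_slice_len_rule(xp, n, start, stride):
    """Hypothetical shape rule: length of x[start::stride] for an axis of length n.

    Written once; `xp` is numpy (static, eager) or jax.numpy (traceable).
    Assumes start >= 0 and stride >= 1.
    """
    remaining = xp.maximum(n - start, 0)
    return (remaining + stride - 1) // stride


# Static, single shape: plain NumPy, no tracing or XLA dispatch.
static_len = strided_slice_len_rule(np, 7, 0, 2)

# Batched shapes: the same rule, traced by JAX and vectorized over lengths.
batched_rule = jax.vmap(lambda n: strided_slice_len_rule(jnp, n, 1, 2))
batched_len = batched_rule(jnp.array([4, 5, 6, 7]))

print(int(static_len))       # 4
print(batched_len.tolist())  # [2, 2, 3, 3]
```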

j-towns commented 4 years ago

I think the main problem with this proposal is that we can't do shape checking and gracefully raise a shape error for incompatible dynamic shapes inside a jit. As I mentioned above, it is technically possible to raise an exception inside jit but I don't think you'd get a nice traceback from user code and the functionality is experimental.

juliuskunze commented 4 years ago

@j-towns I agree, this is probably the main hurdle. How about we record the stack trace of potential shape errors during compilation and attach it to errors raised via host callback?
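The proposal is to record the stack trace when a shape check is staged and attach it to the error raised later. That pattern can be sketched in plain Python without any JAX machinery. `stage_shape_check` and `user_model` are hypothetical. The sketch assumes the returned `run` function would eventually be invoked from a host callback with the runtime result of the check.

```python
import traceback


def stage_shape_check(message):
    """Hypothetical: stage a shape check at trace/compile time.

    The user's stack is captured now, while the Python frames still exist,
    and attached to the error if the check later fails at run time.
    """
    staged_at = "".join(traceback.format_stack()[:-1])  # drop this helper's frame

    def run(ok):
        if not ok:
            raise ValueError(f"{message}\nCheck was staged at:\n{staged_at}")

    return run


def user_model():
    # Stands in for user code whose tracing registers a shape check.
    return stage_shape_check("incompatible dynamic lengths: 3 vs 2")


check = user_model()  # "compile time": the stack is recorded here
try:
    check(ok=False)   # "run time": e.g. triggered from a host callback
except ValueError as e:
    print(e)  # the message, followed by a stack that includes user_model
```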

mattjj commented 1 month ago

We killed mask, and the discussion here is pretty old, so let's close the issue. Thanks to all for the discussion and contributions :)