Closed: juliuskunze closed this issue 1 month ago
> We raise inconsistencies in dynamic shapes using the validation rules.

We might be able to use the host_callback stuff to do this from inside compiled code.
> We decorate shape rules with `@numpy_eval()` from #3923. This ensures XLA dispatch is never wasted on constant shape calculations, even during `jit` compilation (+ no staging occurs). If the NumPy backend turns out to be too slow for some rules, we keep additional optimized NumPy rules for those.

I'm not sure I understand this. How does XLA dispatch get invoked on shape calculations at present? I would guess that most cases where this is currently happening are bugs that would be surfaced by turning on omnistaging.
@shoyer Currently there is no XLA dispatch for shapes, since shape rules are written directly in NumPy. In order to allow dynamic/batched shape propagation for masking, we would need traceable shape rules. `@numpy_eval()` would be useful to calculate static, single shapes (i.e. in standard use cases, outside of masking) as fast as in the current implementation, without XLA dispatch, using the same code (i.e. not having two versions of each shape rule). We should be able to get very close to the original performance with #4117.
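To make this concrete, here is a minimal, hypothetical sketch of what a `numpy_eval`-style decorator could look like. This is not the actual implementation from #3923; the decorator name, its behavior, and the `concat_shape_rule` example are assumptions for illustration only. The point is that the rule runs eagerly on plain NumPy values, so no XLA dispatch or staging is involved:

```python
import numpy as np

# Hypothetical sketch (NOT the actual #3923 implementation): evaluate a
# shape rule eagerly with plain NumPy, so constant shape calculations
# never go through XLA dispatch or staging.
def numpy_eval(fun):
    def wrapped(*args, **kwargs):
        # Coerce inputs to plain NumPy values and run the rule eagerly.
        np_args = [np.asarray(a) for a in args]
        return fun(*np_args, **kwargs)
    return wrapped

@numpy_eval
def concat_shape_rule(shape_a, shape_b, axis=0):
    # Output shape of concatenation: sizes add along `axis`,
    # all other dimensions must match.
    assert all(a == b for i, (a, b) in enumerate(zip(shape_a, shape_b))
               if i != axis)
    out = list(shape_a)
    out[axis] = shape_a[axis] + shape_b[axis]
    return tuple(out)
```

If the rule itself is written against a traceable API, the same code could also run under tracing for dynamic shapes, avoiding two versions of each rule.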
I think the main problem with this proposal is that we can't do shape checking and gracefully raise a shape error for incompatible dynamic shapes inside a `jit`. As I mentioned above, it is technically possible to raise an exception inside `jit`, but I don't think you'd get a nice traceback from user code, and the functionality is experimental.
@j-towns I agree, this is probably the main hurdle. How about we record the stack trace of potential shape errors during compilation and attach it to errors raised via host callback?
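As a hedged sketch of that idea (all names here are hypothetical, and the host callback mechanism is stubbed out as a plain Python call): capture the user-level stack trace when the check is staged out at trace time, and attach it to the error raised when the check fails later at run time.

```python
import traceback

# Hypothetical sketch: `record_shape_check` runs at trace/compile time and
# captures the current (user) stack trace; the returned `check` would be
# invoked from inside compiled code via a host callback, raising an error
# with the recorded traceback attached.
def record_shape_check(expected_shape):
    recorded = "".join(traceback.format_stack())  # captured at trace time

    def check(actual_shape):
        if tuple(actual_shape) != tuple(expected_shape):
            raise ValueError(
                f"Shape mismatch: expected {expected_shape}, "
                f"got {actual_shape}.\n"
                f"Check was staged out at:\n{recorded}"
            )
    return check
```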
We killed `mask`, and the discussion here is pretty old, so let's close the issue. Thanks to all for the discussion and contributions :)
Currently, `mask` can only support ops with output sizes expressible as polynomials of the input sizes. This excludes `np.where(x)`.
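A plain NumPy illustration of why: the output size of `np.where(x)` depends on the values in `x`, not just its shape, so it cannot be written as a polynomial of the input sizes.

```python
import numpy as np

# The number of nonzero elements, and hence the output size of
# single-argument np.where, depends on the data, not only on the
# input shape.
a = np.array([0, 1, 0, 2])
b = np.array([1, 1, 1, 1])
print(np.where(a)[0].shape)  # (2,): two nonzero entries
print(np.where(b)[0].shape)  # (4,): same input shape, larger output
```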
This is a hassle. @j-towns @mattjj @dougalm What are your thoughts on the following plan to allow these cases?
- `mask` propagates the logical shapes using the corresponding shape rules. `vmap(mask)` will then propagate batches of logical shapes using batched shape rules. We raise inconsistencies in dynamic shapes using the validation rules.
- `mask` will be parameter-free. A (batched) masked function will take padded inputs and (batches of) logical input shapes, and return padded outputs and (batches of) logical output shapes.
- We decorate shape rules with `@numpy_eval()` from https://github.com/google/jax/pull/3923. This ensures XLA dispatch is never wasted on constant shape calculations, even during `jit` compilation (+ no staging occurs). If the NumPy backend turns out to be too slow for some rules, we keep additional optimized NumPy rules for those.
- We remove the `Poly` class and all other things polymorphic. This will remove the need for polymorphic special cases in shape rules. Instead of `jnp.sum(x) / shape_as_value(x.shape)[0]`, users can write `jnp.sum(x) / logical_shape(x)[0]`. `logical_shape` will retrieve the logical shape from the underlying `MaskTrace`. We remove the `masking.shape_envs` variable.
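To illustrate the proposed parameter-free calling convention, here is a minimal sketch in plain NumPy. `masked_mean` and the explicit shape argument are hypothetical stand-ins: under the proposal, `mask` would thread the logical shapes through automatically and `logical_shape(x)` would read them off the underlying `MaskTrace`.

```python
import numpy as np

# Hypothetical sketch of the convention: padded inputs plus logical input
# shapes go in; padded outputs plus logical output shapes come out. The
# logical shape is passed explicitly here only for illustration.
def masked_mean(padded_x, logical_shape):
    (n,) = logical_shape
    valid = np.arange(padded_x.shape[0]) < n      # mask off the padding
    total = np.sum(np.where(valid, padded_x, 0))  # padding contributes 0
    return total / n, ()  # scalar output has logical shape ()
```

Dividing by `n` rather than `padded_x.shape[0]` is exactly the `jnp.sum(x) / logical_shape(x)[0]` pattern from the last bullet.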