patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Allow two variadic shapes when it makes sense #184

Open kho opened 4 months ago

kho commented 4 months ago

Currently if I write a function like the following:

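import jax.numpy as jnp
from jaxtyping import Array, Bool, Shaped

def mask_invalid(
    x: Shaped[Array, '*B *C'], mask: Bool[Array, '*B']
) -> Shaped[Array, '*B *C']:
  # Give mask trailing singleton axes so that it broadcasts against x.
  return jnp.where(jnp.expand_dims(mask, range(mask.ndim, x.ndim)), x, 0)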

I get "ValueError: Cannot use variadic specifiers (*name or ...) more than once." That seems overly restrictive: *B can be read off the shape of mask, and *C is then whatever remains of the shape of x. Could the requirement be relaxed to something like "at most one non-determinable variadic specifier in each shape", i.e. the following algorithm?

from collections.abc import Sequence

# Each shape is a sequence of (dimension name, is_variadic) pairs.
# Returns the names of the variadic dimensions that cannot be determined.
def eliminate_determinable_variadic_shapes(
    *shapes: Sequence[tuple[str, bool]]
) -> set[str]:
  remaining = set(range(len(shapes)))
  # Variadic dimension names whose sizes can be determined.
  determinable = set()
  while True:
    new_remaining = set()
    for i in remaining:
      # Distinct variadic names in this shape that are still unknown.
      variadics = {
          name
          for name, variadic in shapes[i]
          if variadic and name not in determinable
      }
      if len(variadics) > 1:
        # Too many unknowns for now; revisit this shape on the next pass.
        new_remaining.add(i)
      elif len(variadics) == 1:
        # Exactly one unknown: its length follows from the array's ndim.
        (name,) = variadics
        determinable.add(name)
        print(name, 'becomes determinable because of', shapes[i])
    if len(remaining) == len(new_remaining):
      # No progress in this pass (or nothing left): report the leftovers.
      return {
          name
          for shape in shapes
          for name, variadic in shape
          if variadic and name not in determinable
      }
    remaining = new_remaining
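The algorithm above only decides which variadic names are determinable. A minimal sketch of the follow-up step, binding each name to concrete sizes from actual array shapes, might look like the following. `bind_shapes` is a hypothetical helper (not part of jaxtyping) that reuses the (name, is_variadic) spec format from above; because it loops to a fixed point, the result should not depend on the order in which the arrays are listed.

```python
from collections.abc import Sequence

# Same spec format as above: a sequence of (name, is_variadic) pairs.
Spec = Sequence[tuple[str, bool]]


def bind_shapes(
    pairs: Sequence[tuple[tuple[int, ...], Spec]],
) -> dict[str, tuple[int, ...]]:
  """Binds each dimension name to its sizes, or raises ValueError.

  Each pair is (concrete array shape, spec). A fixed name binds to a 1-tuple;
  a variadic name binds to a tuple of any length.
  """
  bound: dict[str, tuple[int, ...]] = {}
  pending = list(pairs)
  while pending:
    still_pending = []
    for shape, spec in pending:
      unknown = {n for n, v in spec if v and n not in bound}
      if len(unknown) > 1:
        # Ambiguous for now; retry once other arrays have bound more names.
        still_pending.append((shape, spec))
        continue
      # Every other dim has a known length, so the (at most one) unknown
      # variadic name absorbs whatever is left of this array's rank.
      known_len = sum(
          (len(bound[n]) if v else 1)
          for n, v in spec
          if not (v and n in unknown)
      )
      n_free = sum(1 for n, v in spec if v and n in unknown)
      rest = len(shape) - known_len
      if rest < 0 or (n_free == 0 and rest != 0) or (n_free and rest % n_free):
        raise ValueError(f"shape {shape} does not fit spec {spec}")
      free_len = rest // n_free if n_free else 0
      # Walk the spec left to right, binding new names and checking old ones.
      pos = 0
      for name, variadic in spec:
        if variadic:
          size = free_len if name in unknown else len(bound[name])
        else:
          size = 1
        dims = tuple(shape[pos:pos + size])
        pos += size
        if name in bound and bound[name] != dims:
          raise ValueError(f"{name!r} is {dims} here but {bound[name]} elsewhere")
        bound[name] = dims
    if len(still_pending) == len(pending):
      # No array could be resolved in this pass: genuinely ambiguous.
      names = sorted({n for _, s in still_pending for n, v in s if v and n not in bound})
      raise ValueError(f"cannot determine variadic dims {names}")
    pending = still_pending
  return bound
```

For mask_invalid with x.shape == (2, 3, 4, 5) and mask.shape == (2, 3), listing x first or mask first both give {'B': (2, 3), 'C': (4, 5)}: on the first pass x is skipped as ambiguous, mask binds *B, and the second pass binds *C from x.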
patrick-kidger commented 4 months ago

IIUC, you're basically trying to resolve the variadic shapes one at a time? That would indeed be possible in principle, but it might be fairly tricky to implement: right now we leave this checking to the runtime type checker, and it is the runtime type checker that determines the order in which arguments are checked.

We don't have to do it that way. We could use the runtime type checker just to traverse all the annotations (i.e. to handle nested annotations like tuple[dict[str, Float[Array, ...]]]), record every annotation it comes across as one we want to check against, and then do the shape checks ourselves using an algorithm like the one you describe.
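A minimal sketch of the "traverse and record, check later" half of that idea follows. It is illustrative only: ShapeSpec is a hypothetical marker standing in for real jaxtyping annotations, the traversal is done by hand with typing.get_origin/typing.get_args rather than by a runtime type checker, only tuple, list and dict containers are handled, and any object with a .shape attribute plays the role of an array.

```python
import typing
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any


@dataclass(frozen=True)
class ShapeSpec:
  """Hypothetical stand-in for an annotation such as Float[Array, '*B']."""
  dims: str


def collect(value: Any, annotation: Any, out: list) -> list:
  """Records (shape, dims) for every array-annotated leaf; checks nothing."""
  if isinstance(annotation, ShapeSpec):
    out.append((tuple(value.shape), annotation.dims))
    return out
  origin, args = typing.get_origin(annotation), typing.get_args(annotation)
  if origin is tuple:
    if len(args) == 2 and args[1] is Ellipsis:
      # tuple[X, ...]: the same annotation applies to every element.
      args = (args[0],) * len(value)
    for item, item_ann in zip(value, args):
      collect(item, item_ann, out)
  elif origin is list:
    for item in value:
      collect(item, args[0], out)
  elif origin is dict:
    for item in value.values():
      collect(item, args[1], out)
  # Any other annotation carries no array shapes in this sketch; skip it.
  return out


# Usage: a nested annotation in the spirit of the one above, with
# shape-only stand-in arrays.
annotation = tuple[ShapeSpec("*B *C"), dict[str, ShapeSpec("*B")]]
value = (SimpleNamespace(shape=(2, 3, 4, 5)), {"mask": SimpleNamespace(shape=(2, 3))})
recorded = collect(value, annotation, [])
# recorded == [((2, 3, 4, 5), '*B *C'), ((2, 3), '*B')]
```

Since nothing is checked during the walk, the recorded list can afterwards be handed to an order-independent solver such as the fixed-point procedure the issue describes.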

In practice I'm afraid that would be a fair amount of work, and it might end up being fairly fragile (it'd involve using a dynamic context, I think), all for quite a niche feature.