tensorflow / lingvo

Lingvo
Apache License 2.0

Feature request: lingvo.jax.asserts.HasShape #332

Open drpngx opened 1 year ago

drpngx commented 1 year ago

I tried

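import jax.numpy as jnp

def AssertShape(x: jnp.ndarray, shape) -> None:
  # When traced (e.g. under jax.jit), jnp.array_equal returns a traced bool[]
  # array, so the `if` below cannot turn it into a Python bool.
  if not jnp.array_equal(x.shape, shape):
    raise ValueError(f'Shape mismatch: found {x.shape}, expected: {shape}')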

and got

jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..

(BTW, note the double period at the end of the error message.)
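A minimal reproduction, assuming AssertShape is called on a traced value (for example inside jax.jit, which the TracerBoolConversionError implies). jnp.array_equal turns both shapes into JAX arrays, so during tracing its result is a tracer, and `if not ...` needs a concrete Python bool. x.shape, by contrast, is an ordinary Python tuple known at trace time, so comparing it with plain Python operators never touches a tracer. The two function names below are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def check_with_array_equal(x):
  # jnp.array_equal returns a JAX array; under jit that array is a tracer,
  # so `not ...` forces a concrete-bool conversion and raises.
  if not jnp.array_equal(x.shape, (2, 3)):
    raise ValueError(f'Shape mismatch: found {x.shape}, expected: (2, 3)')
  return x


@jax.jit
def check_with_tuple_compare(x):
  # x.shape is a static tuple of Python ints, so this comparison yields a
  # plain bool, evaluated once when the function is traced.
  if tuple(x.shape) != (2, 3):
    raise ValueError(f'Shape mismatch: found {x.shape}, expected: (2, 3)')
  return x


x = jnp.zeros((2, 3))
try:
  check_with_array_equal(x)
except jax.errors.ConcretizationTypeError as e:
  # TracerBoolConversionError subclasses ConcretizationTypeError.
  print(type(e).__name__)
print(check_with_tuple_compare(x).shape)  # (2, 3)
```

The static comparison is free at run time: it executes in Python during tracing and leaves nothing in the compiled computation.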

drpngx commented 1 year ago

To answer my own question:

import jax.numpy as jnp

def AssertShape(x: jnp.ndarray, expected_shape) -> None:
  # x.shape is a static tuple of Python ints even under jit, so compare it with
  # plain Python; expected_shape must also be a static sequence (e.g. a tuple).
  xs = x.shape
  if len(xs) != len(expected_shape):
    raise ValueError(
        f'Wrong rank: got [{len(xs)}] {xs}, '
        f'expected [{len(expected_shape)}] {expected_shape}')
  for k, d in enumerate(expected_shape):
    if xs[k] != d:
      raise ValueError(
          f'Wrong shape at dim {k}: got {xs}, expected {expected_shape}')
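A usage sketch of the self-answer inside jax.jit and jax.vmap, restating the function so the snippet runs on its own. Because the check reads only x.shape, it runs at trace time and raises an ordinary ValueError on mismatch; under vmap it sees the per-example shape, without the batch dimension. The helper name per_example is hypothetical.

```python
import jax
import jax.numpy as jnp


def AssertShape(x: jnp.ndarray, expected_shape) -> None:
  # Same logic as the self-answer: only static Python shapes are compared.
  xs = tuple(x.shape)
  if len(xs) != len(expected_shape):
    raise ValueError(f'Wrong rank: got [{len(xs)}] {xs}, '
                     f'expected [{len(expected_shape)}] {expected_shape}')
  for k, d in enumerate(expected_shape):
    if xs[k] != d:
      raise ValueError(
          f'Wrong shape at dim {k}: got {xs}, expected {expected_shape}')


def per_example(x):
  # Under vmap, x is a single example, so the batch dim is not in x.shape.
  AssertShape(x, (3,))
  return x * 2.0


batched = jax.jit(jax.vmap(per_example))
print(batched(jnp.ones((4, 3))).shape)  # (4, 3)
```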