drpngx opened 1 year ago
To answer my own question:
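```python
from typing import Sequence

import jax


def AssertShape(x: jax.Array, expected_shape: Sequence[int]) -> None:
    # expected_shape must be a static Python sequence of ints (not traced
    # values), so the comparisons below run in Python at trace time.
    xs = x.shape
    # Check the rank first so the per-dimension loop can index xs safely.
    if len(xs) != len(expected_shape):
        raise ValueError(
            f'Wrong rank: got [{len(xs)}] {x.shape}, expected [{len(expected_shape)}] {expected_shape}')
    # Then compare each dimension and report the first mismatch.
    for k, d in enumerate(expected_shape):
        if xs[k] != d:
            raise ValueError(f'Wrong shape at dim {k}: got {x.shape}, expected {expected_shape}')
```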
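Since `x.shape` is a plain tuple of Python ints even for tracers inside `jax.jit`, the rank and per-dimension checks can also be collapsed into a single tuple comparison that fires while the function is being traced. Below is a minimal sketch of that variant; `assert_shape_static` and `double` are hypothetical names used only for illustration, and unlike `AssertShape` it does not distinguish a rank mismatch from a dimension mismatch.

```python
import jax
import jax.numpy as jnp


def assert_shape_static(x, expected_shape) -> None:
    """Hypothetical compact check: raise ValueError unless x.shape == expected_shape.

    x.shape is a static tuple even under jax.jit, so this runs in Python
    during tracing rather than inside the compiled computation.
    """
    if tuple(x.shape) != tuple(expected_shape):
        raise ValueError(
            f'Wrong shape: got {tuple(x.shape)}, expected {tuple(expected_shape)}')


@jax.jit
def double(x):
    # Executed only while tracing (once per new input shape/dtype),
    # so the check adds no cost to cached compiled calls.
    assert_shape_static(x, (2, 3))
    return x * 2


print(double(jnp.ones((2, 3))).shape)  # (2, 3)

try:
    double(jnp.ones((3, 2)))  # wrong shape: raises during tracing
except ValueError as e:
    print(e)
```

Because the error is raised at trace time, a bad shape is reported before any XLA compilation happens, which is the same behavior `AssertShape` has when called inside a jitted function.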
I tried
and got
(BTW, note the double period)