[JAX] Fix code that defined an Initializer type as:
Initializer = Callable[[jnp.ndarray, Iterable[int], jnp.dtype], jnp.ndarray]
JAX's own initializers take their shape argument as a Sequence[int], not an Iterable[int], so a future version of JAX will cause pytype to report this mismatch as an error.
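To see why the narrower type matters, here is a minimal pure-Python sketch. `fan_in_out` is a hypothetical helper, not a JAX function. It stands in for the kind of shape arithmetic an initializer performs, which relies on `len()` and indexing. Those operations are part of `Sequence` but not of `Iterable`, so a generator satisfies `Iterable[int]` and still fails here.

```python
from typing import Iterable, Sequence


def fan_in_out(shape: Sequence[int]) -> tuple[int, int]:
    """Hypothetical helper: (fan_in, fan_out) for a 2-D dense kernel shape.

    len() and indexing are guaranteed by Sequence but not by Iterable,
    which is why the parameter is annotated Sequence[int].
    """
    if len(shape) != 2:
        raise ValueError("expected a 2-D kernel shape")
    return shape[0], shape[1]


print(fan_in_out((784, 128)))  # tuples are Sequences -> (784, 128)
print(fan_in_out([784, 128]))  # so are lists -> (784, 128)

# A generator is an Iterable[int] but not a Sequence[int]: it has no len()
# and cannot be indexed, so the call below raises TypeError at runtime
# (and a type checker such as pytype would flag it statically).
gen: Iterable[int] = (d for d in (784, 128))
try:
    fan_in_out(gen)  # type: ignore[arg-type]
except TypeError as err:
    print("rejected:", err)
```

Tuples and lists both work, which is why tightening the alias to `Sequence[int]` costs ordinary callers nothing.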
A corrected definition is:
Initializer = Callable[[jnp.ndarray, Sequence[int], jnp.dtype], jnp.ndarray]
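The corrected alias can be exercised against initializers from `jax.nn.initializers`. This is a minimal sketch that assumes a recent JAX release. `init_dense_kernel` is a hypothetical helper for illustration, not part of JAX. It passes the shape as a tuple, which satisfies `Sequence[int]`.

```python
from typing import Callable, Sequence

import jax
import jax.numpy as jnp

# The corrected alias: (PRNG key, shape, dtype) -> initialized array.
Initializer = Callable[[jnp.ndarray, Sequence[int], jnp.dtype], jnp.ndarray]


def init_dense_kernel(init: Initializer, key: jnp.ndarray,
                      in_dim: int, out_dim: int) -> jnp.ndarray:
    # Hypothetical helper: builds the shape as a tuple (a Sequence[int])
    # and requests float32 output via an explicit dtype object.
    return init(key, (in_dim, out_dim), jnp.dtype(jnp.float32))


key = jax.random.PRNGKey(0)

# lecun_normal() returns an initializer with the (key, shape, dtype) signature.
w = init_dense_kernel(jax.nn.initializers.lecun_normal(), key, 784, 128)
print(w.shape, w.dtype)  # (784, 128) float32

# zeros is itself an initializer function and fits the same alias.
z = init_dense_kernel(jax.nn.initializers.zeros, key, 4, 3)
print(z.shape)  # (4, 3)
```

Because both a factory-produced initializer (`lecun_normal()`) and a plain function (`zeros`) fit the alias, code typed this way stays compatible with JAX's built-in initializers once pytype enforces the `Sequence[int]` shape type.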