google / edward2

A simple probabilistic programming language.
Apache License 2.0

[JAX] Fix code that defined an Initializer type as: #558

Closed: copybara-service[bot] closed this 1 year ago.

copybara-service[bot] commented 1 year ago

[JAX] Fix code that defined an Initializer type as:

Initializer = Callable[[jnp.ndarray, Iterable[int], jnp.dtype], jnp.ndarray]

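JAX's own initializers take their shape argument as a Sequence[int], not an Iterable[int]. The alias above describes callables that must accept any Iterable[int]. Callable parameter types are contravariant, so a function that accepts only a Sequence[int] does not match it. In a future version of JAX, pytype will report this mismatch as an error.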
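The sketch below shows why the mismatch matters. It does not use JAX. The helper `fan_in` is hypothetical and comes from neither edward2 nor JAX; it stands in for an initializer's shape handling. A Callable type with an Iterable[int] parameter tells callers they may pass any iterable, including a generator. A function that relies on len() and indexing, which Sequence provides, cannot honor that promise.

```python
from typing import Callable, Iterable, Sequence


# Hypothetical stand-in for an initializer's shape handling: like a
# fan-in computation, it needs len() and indexing, which Sequence provides.
def fan_in(shape: Sequence[int]) -> int:
    if len(shape) == 0:
        return 1
    if len(shape) == 1:
        return shape[0]
    return shape[-2]


# Matches: fan_in accepts exactly the Sequence[int] this type promises.
good: Callable[[Sequence[int]], int] = fan_in

# Runs, but a static type checker should reject this assignment: the type
# promises that any Iterable[int] (e.g. a generator) may be passed in.
bad: Callable[[Iterable[int]], int] = fan_in

ok = good((3, 4))  # a tuple is a Sequence, so this works

try:
    bad(d for d in (3, 4))  # a generator is Iterable but has no len()
    failed = False
except TypeError:
    failed = True
```

Python does not enforce the annotations at runtime. The `TypeError` is the concrete failure that the checker is meant to catch ahead of time.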

A corrected definition is:

Initializer = Callable[[jnp.ndarray, Sequence[int], jnp.dtype], jnp.ndarray]
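The sketch below shows the corrected alias annotating initializer arguments. `init_dense` is a hypothetical helper, not part of edward2. The sketch assumes JAX's `jax.nn.initializers.lecun_normal` and `jax.nn.initializers.zeros`, which follow the (key, shape, dtype) calling convention the alias describes.

```python
from typing import Callable, Sequence

import jax
import jax.numpy as jnp

# The corrected alias: the shape argument is a Sequence[int].
Initializer = Callable[[jnp.ndarray, Sequence[int], jnp.dtype], jnp.ndarray]


def init_dense(
    key: jnp.ndarray,
    in_dim: int,
    out_dim: int,
    kernel_init: Initializer = jax.nn.initializers.lecun_normal(),
    bias_init: Initializer = jax.nn.initializers.zeros,
) -> dict:
    """Hypothetical helper: builds weight and bias arrays for a dense layer."""
    w_key, b_key = jax.random.split(key)
    # Shapes are passed as tuples, which satisfy Sequence[int].
    w = kernel_init(w_key, (in_dim, out_dim), jnp.float32)
    b = bias_init(b_key, (out_dim,), jnp.float32)
    return {"w": w, "b": b}


params = init_dense(jax.random.PRNGKey(0), 3, 4)
```

With the old Iterable[int] alias, a checker could reject passing JAX's initializers as `kernel_init` or `bias_init`. With Sequence[int], their signatures line up with the annotation.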