tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

MultivariateNormal* constructor crashes with Numpy 2.0 #1814

Status: Open
slinderman opened this issue 2 weeks ago:

I updated to NumPy 2.0 and found that the MultivariateNormalDiag and MultivariateNormalFullCovariance constructors crash because np.issctype has been removed. Is NumPy 2.0 supported, or will it be soon?
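A quick way to check whether an environment is affected is sketched below. It assumes only that NumPy 2.0 answers lookups of removed names such as np.issctype with an AttributeError, which is what the end of the traceback shows. The helper name numpy_has_issctype is hypothetical and not part of NumPy or TFP.

```python
import numpy as np


def numpy_has_issctype() -> bool:
    """Return True if the installed NumPy still provides np.issctype."""
    # NumPy 2.0 routes removed names through the module-level __getattr__,
    # which raises AttributeError; hasattr() turns that into False.
    return hasattr(np, "issctype")


numpy_major = int(np.__version__.split(".")[0])
print(f"numpy {np.__version__}: np.issctype available = {numpy_has_issctype()}")
```

With NumPy 2.0.0, as in this report, the check returns False, so any code path that reaches np.issctype will fail.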

Here is a simple repro:

import numpy as np
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

print(f"numpy version: {np.__version__}")
print(f"jax version: {jax.__version__}")
print(f"tfp version: {tfp.__version__}")

# works fine
nml = tfd.Normal(jnp.zeros(3), jnp.ones(3))

# fails under NumPy 2.0
mvn = tfd.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3))

# also fails with the same error under NumPy 2.0
# mvn = tfd.MultivariateNormalFullCovariance(jnp.zeros(3), jnp.eye(3))

On my machine with Python 3.10, it produces the following output:

numpy version: 2.0.0
jax version: 0.4.30
tfp version: 0.24.0
Traceback (most recent call last):
  File "/Users/scott/Projects/dynamax/tfp_debug_20240618.py", line 15, in <module>
    mvn = tfd.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3))
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_diag.py", line 209, in __init__
    super(MultivariateNormalDiag, self).__init__(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_linear_operator.py", line 205, in __init__
    super(MultivariateNormalLinearOperator, self).__init__(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 244, in __init__
    dtype = self.bijector.forward_dtype(self.distribution.dtype)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1705, in forward_dtype
    input_dtype = nest.map_structure_up_to(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py", line 324, in map_structure_up_to
    return map_structure_with_tuple_paths_up_to(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py", line 353, in map_structure_with_tuple_paths_up_to
    return dm_tree.map_structure_with_path_up_to(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tree/__init__.py", line 778, in map_structure_with_path_up_to
    results.append(func(*path_and_values))
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py", line 326, in <lambda>
    lambda _, *args: func(*args),  # Discards path.
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1707, in <lambda>
    lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/dtype_util.py", line 247, in convert_to_dtype
    elif np.issctype(tensor_or_dtype):
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/numpy/__init__.py", line 397, in __getattr__
    raise AttributeError(
AttributeError: `np.issctype` was removed in the NumPy 2.0 release. Use `issubclass(rep, np.generic)` instead.. Did you mean: 'isdtype'?
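The error message names the replacement, issubclass(rep, np.generic). Below is a minimal sketch of a stand-in check built on it, assuming the caller (here TFP's dtype_util.convert_to_dtype) may pass either NumPy scalar types or np.dtype instances. The name is_scalar_type is hypothetical, and the function only approximates the old behaviour: like NumPy 1.x's issctype it accepts dtype instances and rejects np.object_, but unlike it, it returns False for Python builtins such as float.

```python
import numpy as np


def is_scalar_type(rep) -> bool:
    """Approximate stand-in for the removed np.issctype (hypothetical helper)."""
    if isinstance(rep, np.dtype):
        rep = rep.type  # e.g. np.dtype("float32") -> np.float32
    return (
        isinstance(rep, type)  # issubclass() raises TypeError on non-classes
        and issubclass(rep, np.generic)  # the check suggested by NumPy 2.0
        and rep is not np.object_  # NumPy 1.x's issctype rejected object_
    )


# Hypothetical stopgap: restore the missing name on the numpy module before
# constructing the distributions, so the np.issctype lookup in the traceback
# finds a real attribute instead of hitting numpy's __getattr__.
if not hasattr(np, "issctype"):
    np.issctype = is_scalar_type
```

Patching the numpy module only covers the single call shown in the traceback; other code paths in TFP 0.24.0 may still hit further NumPy 2.0 removals. Pinning NumPy below 2.0 avoids the problem without touching library internals.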