scikit-hep / pyhf

pure-Python HistFactory implementation with tensors and autodiff
https://pyhf.readthedocs.io/
Apache License 2.0

jax v0.4.1's jax.Array breaks schema #2078

Status: closed by matthewfeickert 1 year ago

matthewfeickert commented 1 year ago

From the JAX v0.4.1 release notes:

We introduce jax.Array which is a unified array type that subsumes DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit and pjit. jax.Array has been enabled by default in JAX 0.4 and makes some breaking change to the pjit API. The jax.Array migration guide can help you migrate your codebase to jax.Array. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.

This causes test_schema_tensor_type_allowed to fail for the JAX backend:

$ pytest tests/test_schema.py -k 'test_schema_tensor_type_allowed[jax]'
...
>           raise pyhf.exceptions.InvalidSpecification(err, schema_name)
E           pyhf.exceptions.InvalidSpecification: Array([10.], dtype=float64) is not of type 'array'.
E               Path: channels[0].samples[0].data
E               Instance: [10.] Schema: model.json

As this fails the schema tests, it might require a patch release.
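The traceback suggests that the schema's "array" type is checked against a registered array class: jsonschema's stock "array" check only accepts Python lists, so a tensor passes only if the validator is told which tensor classes also count as arrays. If that registered class is stale (as jnp.DeviceArray became once JAX started returning jax.Array), validation fails with exactly this kind of "is not of type 'array'" error. A minimal sketch of the mechanism follows, assuming the jsonschema and NumPy packages are installed; NumPy's ndarray stands in for a backend array type, and make_validator, StaleArray, and the simplified schema are hypothetical illustrations, not pyhf's actual validation code.

```python
import jsonschema
import numpy as np

# Simplified schema in the spirit of a sample's "data" field (assumed, not pyhf's model.json).
SCHEMA = {"type": "array", "items": {"type": "number"}}


def make_validator(array_types):
    """Hypothetical helper: build a Draft 7 validator class whose "array"
    type also accepts instances of the given tensor classes."""
    base = jsonschema.Draft7Validator
    type_checker = base.TYPE_CHECKER.redefine(
        "array",
        lambda checker, instance: isinstance(instance, (list, *array_types)),
    )
    return jsonschema.validators.extend(base, type_checker=type_checker)


class StaleArray:
    """Hypothetical stand-in for an array class the backend no longer returns."""


data = np.array([10.0])

# The stock validator only treats Python lists as JSON "array"s.
stock = jsonschema.Draft7Validator(SCHEMA)
print(stock.is_valid(data))  # False

# Registering a class the backend does not actually return still fails.
stale = make_validator((StaleArray,))(SCHEMA)
print(stale.is_valid(data))  # False

# Registering the class the backend really returns makes the tensor acceptable.
extended = make_validator((np.ndarray,))(SCHEMA)
print(extended.is_valid(data))  # True
```

Plain lists remain valid with the extended validator, since list is always kept in the accepted tuple.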

matthewfeickert commented 1 year ago

This particular issue can be fixed with just

$ git diff
diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py
index 2e05039f..9cd46a58 100644
--- a/src/pyhf/tensor/jax_backend.py
+++ b/src/pyhf/tensor/jax_backend.py
@@ -2,6 +2,7 @@

 config.update('jax_enable_x64', True)

+from jax import Array
 import jax.numpy as jnp
 from jax.scipy.special import gammaln, xlogy
 from jax.scipy import special
@@ -54,10 +55,10 @@ class jax_backend:
     __slots__ = ['name', 'precision', 'dtypemap', 'default_do_grad']

     #: The array type for jax
-    array_type = jnp.DeviceArray
+    array_type = Array

     #: The array content type for jax
-    array_subtype = jnp.DeviceArray
+    array_subtype = Array

     def __init__(self, **kwargs):
         self.name = 'jax'

but depending on how much else needs to be updated, I'm not sure whether this will be just a patch fix or whether it should be treated as part of the work for the next minor release, given that the only things failing are the schema validation checks, which would pass on an older release.

Edit: At runtime this is basically fine, so I think this can be moved to the next minor release.
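The patch above hard-codes from jax import Array. If a backend ever had to tolerate both older JAX releases (where jnp.DeviceArray was the array type) and newer ones (where arrays are jax.Array), one option is to resolve the class with a fallback. The sketch below assumes that the presence of jax.Array means it is the type JAX operations return, which should be checked against the JAX versions actually supported; resolve_array_type is a hypothetical helper, not part of pyhf or JAX, and it is exercised with stand-in namespaces so it runs without JAX installed.

```python
from types import SimpleNamespace


def resolve_array_type(jax_module, jnp_module):
    """Hypothetical helper: prefer the unified jax.Array class and fall back
    to jax.numpy.DeviceArray when jax.Array is absent."""
    array_type = getattr(jax_module, "Array", None)
    if array_type is None:
        # Raises AttributeError if neither class is available.
        array_type = getattr(jnp_module, "DeviceArray")
    return array_type


# Stand-in classes and modules (assumptions for illustration, not real JAX objects).
class Array:
    pass


class DeviceArray:
    pass


new_jax = SimpleNamespace(Array=Array)  # mimics a release exposing jax.Array
old_jax = SimpleNamespace()  # mimics a release without jax.Array
jnp_stub = SimpleNamespace(DeviceArray=DeviceArray)

print(resolve_array_type(new_jax, jnp_stub).__name__)  # Array
print(resolve_array_type(old_jax, jnp_stub).__name__)  # DeviceArray
```

With JAX actually installed, the call would be resolve_array_type(jax, jax.numpy), and the result could be assigned to array_type and array_subtype as in the diff.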

kratsg commented 1 year ago

Honestly, this is going to impact one person for right now @phinate but that's probably ok.

phinate commented 1 year ago

> Honestly, this is going to impact one person for right now @phinate but that's probably ok.

goes to cry in corner :'(

But in all seriousness, this is unlikely to affect me, as I'm still bypassing validation in my experiments right now (something about my fork seemed to necessitate it), and I'm not actively working on anything at the moment either. So go for it :)

matthewfeickert commented 1 year ago

Thanks for the feedback, @phinate. :+1: I'll go ahead and merge PR #2079 then. After your thesis defense, if you need any new pyhf releases to make things easier on the neos or relaxed side, just let us know and we can work with you on that. I want to make it easier for us to cut releases as needed.