Support for NumPy 1.18 has been dropped, per the [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). Please upgrade to a supported NumPy version.
The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS environment variable or the --jax_host_callback_ad_transforms flag. Documentation has also been added on how to implement the old behavior using JAX custom AD APIs ({jax-issue}[#8678](https://github.com/google/jax/issues/8678)).
Sorting now matches the behavior of NumPy for 0.0 and NaN regardless of the bit representation. In particular, 0.0 and -0.0 are now treated as equivalent, where previously -0.0 was treated as less than 0.0. Additionally, all NaN representations are now treated as equivalent and sorted to the end of the array. Previously, negative NaN values were sorted to the front of the array, and NaN values with different internal bit representations were not treated as equivalent but were sorted according to those bit patterns ({jax-issue}[#9178](https://github.com/google/jax/issues/9178)).
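A minimal sketch of the new ordering, assuming a JAX version that includes this change. The input values are illustrative only; since -0.0 compares equal to 0.0, the zeros are checked by value rather than by bit pattern.

```python
import jax.numpy as jnp

# Signed zeros plus two NaNs, one of them with its sign bit set.
x = jnp.array([jnp.nan, 1.0, -0.0, 0.0, -jnp.nan])

s = jnp.sort(x)
# Both zeros come first (treated as equal), then 1.0, then every NaN at the end.
print(s)
```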
{func}`jax.numpy.unique` now treats NaN values in the same way as np.unique in NumPy versions 1.21 and newer: at most one NaN value will appear in the uniquified output ({jax-issue}[#9184](https://github.com/google/jax/issues/9184)).
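A short sketch of the NaN handling described above, assuming a JAX version with this change (in recent releases this corresponds to the default `equal_nan=True` of `jnp.unique`). The input array is a made-up example.

```python
import jax.numpy as jnp

x = jnp.array([2.0, jnp.nan, 1.0, jnp.nan, 2.0])

# Duplicates collapse, and the two NaNs collapse into a single NaN entry.
u = jnp.unique(x)
print(u)  # sorted unique values, with one NaN at the end
```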
Bug fixes:
host_callback now supports ad_checkpoint.checkpoint ({jax-issue}[#8907](https://github.com/google/jax/issues/8907)).
Added a new debugging flag/environment variable JAX_DUMP_IR_TO=/path. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
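A minimal sketch of using the flag from Python, assuming your JAX version still honors JAX_DUMP_IR_TO and reads it when `jax` is first imported (so it must be set before the import). The temporary directory and the function `f` are hypothetical; the exact file names JAX writes are not specified here.

```python
import os
import tempfile

# Assumption: JAX reads JAX_DUMP_IR_TO at import time, so set it before importing jax.
dump_dir = tempfile.mkdtemp(prefix="jax_ir_")
os.environ["JAX_DUMP_IR_TO"] = dump_dir

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return jnp.tanh(x) * 2.0


y = f(jnp.arange(4.0))
# Any IR files JAX dumped for this computation are written under dump_dir.
print(sorted(os.listdir(dump_dir)))
```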
Added jax.ensure_compile_time_eval to the public API ({jax-issue}[#7987](https://github.com/google/jax/issues/7987)).
jax2tf now supports a flag, jax2tf_associative_scan_reductions, that changes the lowering of associative reductions (e.g., jnp.cumsum) to behave like JAX on CPU and GPU, i.e., to use an associative scan. See the jax2tf README for more details ({jax-issue}[#9189](https://github.com/google/jax/issues/9189)).
JAX release v0.2.26
Bug fixes:
Out-of-bounds indices to jax.ops.segment_sum are now handled with FILL_OR_DROP semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices are now returned as 0 (#8634).
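A minimal sketch of the FILL_OR_DROP behavior described above. The data and segment IDs are made up; segment ID 5 is out of bounds for `num_segments=2`, so its value is dropped from the sum and its gradient is 0.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0])
segment_ids = jnp.array([0, 1, 5])  # 5 is out of bounds for num_segments=2


def total(d):
    # The contribution at the out-of-bounds index is dropped.
    return jax.ops.segment_sum(d, segment_ids, num_segments=2).sum()


print(jax.ops.segment_sum(data, segment_ids, num_segments=2))  # [1. 2.]
print(jax.grad(total)(data))  # [1. 1. 0.]
```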
jax2tf now forces the converted code to use XLA for code fragments under jax.jit, e.g., most jax.numpy functions (#7839).
jax.jit(f).lower(...).compiler_ir() now defaults to the MHLO dialect if no dialect= is passed.
jax.jit(f).lower(...).compiler_ir(dialect='mhlo') now returns an MLIR ir.Module object instead of its string representation.
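A sketch of inspecting the lowered IR through the ahead-of-time lowering API. It deliberately omits `dialect=`: in newer JAX releases the default is StableHLO rather than MHLO, and `dialect='mhlo'` may no longer be accepted, so the printed op names depend on your version. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * 2.0


lowered = jax.jit(f).lower(1.0)
# With no dialect= argument, compiler_ir() returns an MLIR module object, not a str.
module = lowered.compiler_ir()
print(type(module).__name__)
print(module)  # str(module) renders the module as textual IR
```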
jaxlib 0.1.77 (Unreleased)
Changes
Bazel 5.0.0 is now required to build jaxlib.
jaxlib 0.1.76 (Jan 27, 2022)
New features
Includes precompiled SASS for NVIDIA compute capability 8.0 GPUs (e.g., A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.
With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
Breaking changes
Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
Bug fixes
Fixed a bug where apparently identical pytreedef objects constructed by different routes did not compare as equal (#9066).
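An illustration of the kind of comparison this fix concerns: the same tuple structure built once from a value and once from leaf treedefs. These two construction routes are chosen for illustration and are not necessarily the ones reported in #9066.

```python
import jax

# Route 1: infer the structure directly from a value.
direct = jax.tree_util.tree_structure((1, 2))

# Route 2: assemble the same structure from two leaf treedefs.
leaf = jax.tree_util.tree_structure(0)
assembled = jax.tree_util.treedef_tuple([leaf, leaf])

print(direct, assembled)
print(direct == assembled)
```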
The JAX jit cache now requires two static arguments to have identical types for a cache hit (#9311).
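A sketch showing the effect of the stricter cache key, assuming a JAX version that includes this fix. A Python side effect in the traced function counts how often tracing happens: `2` and `2.0` compare equal but have different types, so they should produce separate cache entries. The function `scale` and the counter are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp

trace_count = 0


@functools.partial(jax.jit, static_argnums=1)
def scale(x, k):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return x * k


x = jnp.ones(3)
scale(x, 2)    # traces with an int static argument
scale(x, 2.0)  # 2.0 == 2, but float is a different type, so it traces again
scale(x, 2)    # same value and type as the first call: cache hit, no new trace
print(trace_count)
```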
Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Bumps jax from 0.2.8 to 0.2.28.
Release notes
Sourced from jax's releases.
... (truncated)
Changelog
Sourced from jax's changelog.
... (truncated)
Commits
- 3acbd44 Remove isinstance checks
- dcca99b Remove path from the serde API as tspec encompasses those things.
- 4e47de6 Add the cache back now that Mesh's hash is also being hashed on `self.dev...
- 5efa285 Switch Cuda 11.4 cudnn 8.0.5 to build with cuda 11.1 cudnn 8.0.5 instead.
- 39786c6 Merge pull request #9394 from jakevdp:pre-commit-versions
- fe14530 Merge pull request #9391 from jakevdp:fix-constant-handler
- e3fe4a2 Merge pull request #9316 from mattjj:djax-now-5
- d9dcd13 djax: let make_jaxpr build dyn shape jaxprs
- f80887e Couple of changes because of the serialization inconsistencies being observed.
- 63bac94 Merge pull request #9393 from jakevdp:delete-code