Support for Python 3.7 has been dropped, in accordance with JAX's
version support policy.
We introduce jax.Array, a unified array type that subsumes the
DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX.
The jax.Array type helps make parallelism a core feature of JAX,
simplifies and unifies JAX internals, and allows us to unify jit and
pjit. jax.Array is enabled by default in JAX 0.4 and makes some
breaking changes to the pjit API. The jax.Array migration guide can
help you migrate your codebase to jax.Array. You can also look at the
Distributed arrays and automatic parallelization tutorial to understand
the new concepts.
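As a minimal sketch (assuming JAX 0.4.1 or later, where jax.Array is exposed as a public type), arrays created by jax.numpy, returned from jitted functions, or placed with jax.device_put can all be checked against the single jax.Array type, and each one carries its own placement information:

```python
import jax
import jax.numpy as jnp

# Three different ways of producing an array; all yield the same unified type.
x = jnp.arange(8.0)                      # created by jax.numpy
y = jax.jit(lambda a: a * 2)(x)          # returned from a jitted function
z = jax.device_put(x, jax.devices()[0])  # explicitly placed on one device

for arr in (x, y, z):
    assert isinstance(arr, jax.Array)

# Placement information lives on the array itself via its `sharding` attribute.
print(y.sharding)
```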
PartitionSpec and Mesh are now out of experimental. The new API endpoints
are jax.sharding.PartitionSpec and jax.sharding.Mesh.
jax.experimental.maps.Mesh and jax.experimental.PartitionSpec are
deprecated and will be removed in 3 months.
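A minimal sketch using the new endpoints, assuming a recent JAX release in which jax.sharding.NamedSharding can be passed to jax.device_put. The mesh axis name "data" is an arbitrary, illustrative choice. On a single-device machine the mesh contains one device, so the array is effectively unsharded, but the same code runs unchanged on several devices:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over all available devices; "data" is an illustrative axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading dimension across the "data" mesh axis and replicate
# the second dimension (None means "not partitioned").
sharding = NamedSharding(mesh, PartitionSpec("data", None))

# The leading dimension is chosen to be divisible by the number of devices.
n = len(devices)
x = jnp.arange(n * 4 * 2, dtype=jnp.float32).reshape(n * 4, 2)
x_sharded = jax.device_put(x, sharding)
```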
The new public endpoint for with_sharding_constraint is
jax.lax.with_sharding_constraint.
If using ABSL flags together with jax.config, the ABSL flag values are no
longer read or written after the JAX configuration options are initially
populated from the ABSL flags. This change improves performance of reading
jax.config options, which are used pervasively in JAX.
The jax2tf.call_tf function now uses, for TF lowering, the first TF device
of the same platform as the embedding JAX computation. Previously, it used
the 0th device of the JAX-default backend.
A number of jax.numpy functions now have their arguments marked as
positional-only, matching NumPy.
jnp.msort is now deprecated, following the deprecation of np.msort in NumPy 1.24.
It will be removed in a future release, in accordance with the API compatibility
policy. It can be replaced with jnp.sort(a, axis=0).
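The replacement is mechanical: jnp.msort sorted along the first axis, which is what jnp.sort does with axis=0. A small illustrative example:

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([[3, 1], [1, 2], [2, 0]])

# Equivalent of the deprecated jnp.msort(a): sort each column independently.
sorted_a = jnp.sort(a, axis=0)
print(sorted_a)  # → [[1, 0], [2, 1], [3, 2]]
```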
The implementation of jit and pjit has been merged. Merging jit and pjit
changes the internals of JAX without affecting the public API of JAX.
Previously, jit was a final-style primitive. Final style means that the creation
of the jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the jit-pjit implementation merge, jit
becomes an initial-style primitive, which means that we trace to a jaxpr
as early as possible. For more information, see
this section in autodidax.
Moving to initial style should simplify JAX's internals and make it
easier to develop features such as dynamic shapes.
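One way to inspect the staged-out form is jax.make_jaxpr. In the sketch below (the functions inner and outer are illustrative), the jitted inner function is expected to appear in the caller's jaxpr as a single call equation whose body is a nested, already-traced jaxpr, rather than as inlined primitives. The exact name printed for that equation (for example pjit) depends on the JAX version, so the code only checks where the sin primitive ends up:

```python
import jax
import jax.numpy as jnp

@jax.jit
def inner(x):
    return jnp.sin(x) * 2.0

def outer(x):
    # Calling a jitted function while tracing `outer` stages it out as one equation.
    return inner(x) + 1.0

closed_jaxpr = jax.make_jaxpr(outer)(0.5)
print(closed_jaxpr)

# Top-level primitive names in the outer jaxpr; `sin` lives only in the nested jaxpr.
outer_prims = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(outer_prims)
```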
The merge can be disabled only via an environment variable, e.g.
os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'.
It must be disabled this way because the merge affects JAX at import time,
so the variable needs to be set before jax is imported.
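A minimal sketch of opting out, assuming a JAX release from the 0.4.4 era that still reads JAX_JIT_PJIT_API_MERGE (later releases are not guaranteed to honor this variable). The assignment goes at the very top of the entry-point script, before any module in the process imports jax:

```python
import os

# Must run before `jax` is imported anywhere in this process, because the
# jit/pjit merge is decided at import time.
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"

import jax  # noqa: E402  (deliberately imported after the environment setup)

print(jax.__version__)
```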
The axis_resources argument of with_sharding_constraint is deprecated.
Please use shardings instead. No change is needed if you were passing
axis_resources positionally. If you were passing it as a keyword argument,
use shardings instead. axis_resources will be removed 3 months
after Feb 13, 2023.
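A sketch of both call styles, assuming a recent JAX in which jax.lax.with_sharding_constraint accepts a NamedSharding inside jit. The mesh axis name "data" and the function f are illustrative:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative 1-D mesh over all local devices; rows are split along "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
row_sharding = NamedSharding(mesh, PartitionSpec("data"))

@jax.jit
def f(x):
    y = x * 2.0
    # Positional use: unaffected by the axis_resources deprecation.
    y = jax.lax.with_sharding_constraint(y, row_sharding)
    # Keyword use: pass `shardings=` instead of the deprecated `axis_resources=`.
    return jax.lax.with_sharding_constraint(y + 1.0, shardings=row_sharding)

n = len(jax.devices())
out = f(jnp.ones((n * 2, 3)))
```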
Added the jax.typing module, with tools for type annotations of JAX
functions.
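A small sketch of annotating a function with jax.typing: ArrayLike is meant for inputs (it admits jax.Array, NumPy arrays, and Python scalars), DTypeLike for dtype arguments, and jax.Array for return values. The function scale is hypothetical:

```python
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike, DTypeLike

def scale(x: ArrayLike, factor: float = 2.0, dtype: DTypeLike = jnp.float32) -> Array:
    """Hypothetical helper: converts any array-like input and multiplies it."""
    return jnp.asarray(x, dtype=dtype) * factor

# Python lists, NumPy arrays, and scalars are all valid ArrayLike inputs.
print(scale([1, 2]))
print(scale(np.ones(3), 3.0))
```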
The following names have been deprecated:
- jax.xla.Device and jax.interpreters.xla.Device: use jax.Device.
- jax.experimental.maps.Mesh: use jax.sharding.Mesh.
- jax.experimental.pjit.NamedSharding: use jax.sharding.NamedSharding.
- jax.experimental.pjit.PartitionSpec: use jax.sharding.PartitionSpec.
- jax.interpreters.pxla.Mesh: use jax.sharding.Mesh.
- jax.interpreters.pxla.PartitionSpec: use jax.sharding.PartitionSpec.
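Migrating off the deprecated names above is usually just an import change. A sketch, assuming a JAX version in which jax.Device and the jax.sharding module are available; the mesh axis name "x" is illustrative:

```python
import numpy as np
import jax

# Deprecated spellings (avoid in new code) and their replacements:
#   jax.xla.Device, jax.interpreters.xla.Device            -> jax.Device
#   jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh -> jax.sharding.Mesh
#   jax.experimental.pjit.NamedSharding                    -> jax.sharding.NamedSharding
#   jax.experimental.pjit.PartitionSpec,
#   jax.interpreters.pxla.PartitionSpec                    -> jax.sharding.PartitionSpec
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Device objects returned by jax.devices() are instances of jax.Device.
device = jax.devices()[0]
assert isinstance(device, jax.Device)

# Build the sharding objects from their new public locations.
mesh = Mesh(np.array([device]), axis_names=("x",))
sharding = NamedSharding(mesh, PartitionSpec("x"))
```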
Breaking changes
The initial argument to reduction functions like jax.numpy.sum
is now required to be a scalar, consistent with the corresponding NumPy API.
The previous behavior of broadcasting the output against non-scalar initial
values was an unintentional implementation detail ([#14446](https://github.com/google/jax/issues/14446)).
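A small sketch of the new rule, assuming a JAX version that includes this change: a scalar initial value is folded into the reduction, while a non-scalar one raises an error instead of being broadcast. The code catches both ValueError and TypeError because the exact exception type is an implementation detail:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# A scalar `initial` is combined with the reduction result: 10 + 1 + 2 + 3.
total = jnp.sum(x, initial=10)
print(total)  # → 16

# A non-scalar `initial` is rejected rather than broadcast against the output.
try:
    jnp.sum(x, initial=jnp.array([1, 2, 3]))
    rejected = False
except (ValueError, TypeError):
    rejected = True

print(rejected)
```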
jaxlib 0.4.4 (Feb 16, 2023)
Breaking changes
Support for NVIDIA Kepler series GPUs has been removed from the default
jaxlib builds. If Kepler support is needed, it is still possible to
build jaxlib from source with Kepler support (via the
--cuda_compute_capabilities=sm_35 option to build.py); however, note
that CUDA 12 has completely dropped support for Kepler GPUs.
jax 0.4.3 (Feb 8, 2023)
Breaking changes
Deleted jax.scipy.linalg.polar_unitary, which was a deprecated JAX
extension to the SciPy API. Use jax.scipy.linalg.polar instead.
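jax.scipy.linalg.polar returns both polar factors; with its default side="right", the input factors as a = u @ p, with u unitary and p positive semidefinite, so code that only needs the unitary factor can keep the first output. A small check on an illustrative matrix:

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.array([[2.0, 1.0], [0.5, 3.0]])

# Right polar decomposition: a = u @ p.
u, p = polar(a)

print(u)  # unitary factor
print(p)  # positive semidefinite factor
```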
Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Bumps jax from 0.2.8 to 0.4.4.
Release notes
Sourced from jax's releases.
... (truncated)
Changelog
Sourced from jax's changelog.
... (truncated)
Commits
- 58e46b4 Prepare for jax and jaxlib 0.4.4 release
- c6a99b6 Remove jax.interpreters.xla.lower_fun.
- a9e886f [jax2tf] Enable all native lowering jax2tf tests
- 454e4de [shape_poly] Fix the lowering for symbolic dimension expressions for division
- d0b42f2 Fix the simple bug on call_tf.replace_non_float and add unittest for floating...
- 26045c4 remove `core.{aval_method,aval_property}`
- d8514d0 Merge pull request #14500 from jakevdp:bcsr-matmul-test
- 0af9fff Replace uses of deprecated JAX sharding APIs with their new names in jax.shar...
- 1b2a318 remove `core.axis_substitution_rules`
- 768960b Fix pytype errors.