Operations with dimensions in the presence of jax2tf shape polymorphism have
been generalized to work in more scenarios, by converting the symbolic
dimension to JAX arrays. Operations involving symbolic dimensions and
np.ndarray can now raise errors when the result is used as a shape value
({jax-issue}[#14106](https://github.com/google/jax/issues/14106)).
jaxpr objects now raise an error on attribute setting, to avoid
problematic mutations ({jax-issue}[#14102](https://github.com/google/jax/issues/14102)).
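A minimal sketch of what the jaxpr entry means in practice. It assumes a JAX release in which Jaxpr fields such as eqns are read-only properties, so Python raises AttributeError on assignment. The function f is hypothetical and exists only to be traced.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function, used only to produce a jaxpr.
    return jnp.sin(x) * 2.0

closed = jax.make_jaxpr(f)(1.0)  # a ClosedJaxpr
jaxpr = closed.jaxpr             # the underlying Jaxpr object

mutation_rejected = False
try:
    jaxpr.eqns = []  # attempt an in-place mutation of the jaxpr
except AttributeError:
    # Assumed behavior: eqns is a property without a setter.
    mutation_rejected = True

print(mutation_rejected)
```

Code that needs a modified jaxpr should therefore build a new object rather than assigning to fields of an existing one.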
Changes
{func}jax2tf.call_tf has a new parameter has_side_effects (default True)
that can be used to declare whether an instance can be removed or replicated
by JAX optimizations such as dead-code elimination ({jax-issue}[#13980](https://github.com/google/jax/issues/13980)).
jaxlib 0.4.2 (Jan 20, 2023)
Changes
Set JAX_USE_PJRT_C_API_ON_TPU=1 to enable the new Cloud TPU runtime, which
features automatic device memory defragmentation.
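A minimal sketch of opting in from Python. The variable name comes from the entry above. When jaxlib reads it is an assumption here, so the sketch sets it before JAX is imported. From a shell, the equivalent would be prefixing the command, e.g. JAX_USE_PJRT_C_API_ON_TPU=1 python train.py, where train.py is a hypothetical script.

```python
import os

# Opt in to the new Cloud TPU runtime. Assumption: the flag is read when
# JAX initializes its TPU backend, so set it before importing jax.
os.environ["JAX_USE_PJRT_C_API_ON_TPU"] = "1"

# import jax  # import JAX only after the variable is set
```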
jax 0.4.1 (Dec 13, 2022)
Changes
Support for Python 3.7 has been dropped, in accordance with JAX's
{ref}version-support-policy.
We introduce jax.Array, a unified array type that subsumes the
DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX.
The jax.Array type helps make parallelism a core feature of JAX,
simplifies and unifies JAX internals, and allows us to unify jit and
pjit. jax.Array is enabled by default in JAX 0.4 and makes some
breaking changes to the pjit API. The jax.Array migration guide can
help you migrate your codebase to jax.Array. You can also look at the
Distributed arrays and automatic parallelization tutorial to understand
the new concepts.
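A minimal sketch of the unified type, assuming JAX 0.4 or later. Arrays created by jax.numpy, arrays returned from a jitted function, and arrays placed explicitly with jax.device_put are all instances of jax.Array. The variable names are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.arange(8.0)                            # created by jax.numpy
y = jax.jit(lambda a: a * 2)(x)                # output of a jitted computation
z = jax.device_put(np.array([1.0, 2.0, 3.0]))  # explicitly placed on a device

# All three share the single unified array type.
for arr in (x, y, z):
    print(isinstance(arr, jax.Array))
```

Because of this, a single isinstance(value, jax.Array) check replaces separate checks against the older DeviceArray-style types.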
PartitionSpec and Mesh are now out of experimental. The new API endpoints
are jax.sharding.PartitionSpec and jax.sharding.Mesh.
jax.experimental.maps.Mesh and jax.experimental.PartitionSpec are
deprecated and will be removed in 3 months.
The new public endpoint for with_sharding_constraint is
jax.lax.with_sharding_constraint.
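A minimal sketch that uses the new endpoints together. It assumes NamedSharding is also importable from jax.sharding in your JAX version, which the entry above does not state. The axis name "data" and the function scale are hypothetical. With a single device, the sharding is trivial, but the same code expresses the layout for larger meshes.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over all available devices; "data" is an illustrative axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading array dimension across the "data" mesh axis.
sharding = NamedSharding(mesh, PartitionSpec("data"))

@jax.jit
def scale(x):
    y = x * 2.0
    # Ask the compiler to lay out this intermediate with the sharding above.
    return jax.lax.with_sharding_constraint(y, sharding)

# The leading dimension must be divisible by the number of devices.
x = jnp.arange(4.0 * len(devices))
out = scale(x)
print(out.sharding)
```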
If using ABSL flags together with jax.config, the ABSL flag values are no
longer read or written after the JAX configuration options are initially
populated from the ABSL flags. This change improves performance of reading
jax.config options, which are used pervasively in JAX.
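A minimal sketch of the consequence for code that changes options after startup. Because ABSL flag values are no longer re-read once the JAX options are populated, later changes should go through jax.config directly. The sketch itself does not use ABSL, and the choice of the jax_debug_nans option is illustrative.

```python
import jax

# Change the option through jax.config itself; per the entry above, editing
# the ABSL flag value at this point would not be picked up by JAX.
jax.config.update("jax_debug_nans", True)
enabled = jax.config.jax_debug_nans  # reads the JAX option, not the ABSL flag
print(enabled)

jax.config.update("jax_debug_nans", False)  # restore the default
```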
For TF lowering, the jax2tf.call_tf function now uses the first TF device
of the same platform as the embedding JAX computation.
Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Updates the requirements on jax to permit the latest version.
Changelog
Sourced from jax's changelog.
... (truncated)
Commits
- 838bc45 Merge pull request #14148 from skye:version
- c4ad27c Update libtpu version for jaxlib 0.4.2 release (again)
- 09794be Use `jax.config` instead of `config` because pickle does not like using the c...
- 5aea7d9 [sparse] Add function that fixes out-of-bound indices.
- 7d2031b Merge pull request #14135 from skye:version
- 9a7f29a Update WORKSPACE for jaxlib 0.4.2 release (again)
- b621373 Cache the creation of ClosedJaxpr in pjit_transpose which if not cached break...
- bbccf55 Merge pull request #14131 from jakevdp:mypy-update
- 94af71a CI: fix mypy jaxlib version
- 7064be1 Skip unneccessary unflattening of avals in pjit lowering path.