A major component of the Cloud TPU runtime has been upgraded. This enables
the following new features on Cloud TPU:
- `jax.debug.print`, `jax.debug.callback`, and `jax.debug.breakpoint()` now work on Cloud TPU
- Automatic TPU memory defragmentation
`jax.experimental.host_callback` is no longer supported on Cloud TPU
with the new runtime component. Please file an issue on the JAX issue
tracker if the new `jax.debug` APIs are insufficient for your use case.
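For code that relied on `host_callback` for logging or host-side inspection, the `jax.debug` APIs named above may cover the same ground. The following is a minimal sketch, not an official migration recipe: the function `step`, the helper `record`, and the list `seen` are hypothetical names chosen for illustration, and `jax.effects_barrier()` is assumed to be available in the installed JAX version.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen = []  # host-side log, filled in by the callback below (illustrative name)


def record(value):
    # Runs on the host and receives the concrete value computed on the device.
    seen.append(float(np.asarray(value)))


@jax.jit
def step(x):
    y = jnp.sin(x) * 2.0
    # Print a traced value from inside a jitted function.
    jax.debug.print("y = {y}", y=y)
    # Hand the same value to arbitrary host-side Python code.
    jax.debug.callback(record, y)
    return y


out = step(jnp.float32(0.0))
# Wait for pending side effects so that the callback has run before `seen` is read.
jax.effects_barrier()
```

Both calls are intended for debugging side effects only; the value returned by `record` is ignored, so anything the host needs to keep has to be stored explicitly, as `seen` does here.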
The old runtime component will remain available for at least the next three
months and can be selected by setting the environment variable
`JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new
runtime for any reason, please let us know on the JAX issue tracker.
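One way to apply this setting from Python is sketched below. It assumes, as is typical for JAX environment flags but not stated in the changelog, that the variable needs to be set before JAX initializes. The helper name `use_old_tpu_runtime` is hypothetical.

```python
import os
import sys


def use_old_tpu_runtime(env=os.environ):
    """Request the old Cloud TPU runtime component.

    Returns True if the variable was set before `jax` was imported in this
    process. The assumption (not confirmed by the changelog) is that setting
    it after import may have no effect.
    """
    jax_already_imported = "jax" in sys.modules
    env["JAX_USE_PJRT_C_API_ON_TPU"] = "false"
    return not jax_already_imported
```

Calling `use_old_tpu_runtime()` at the very top of an entry-point script, before `import jax`, keeps the opt-out in one place. Exporting the variable in the shell that launches the job achieves the same result without code changes.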
Changes
The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
Deprecations
- CUDA 11.4 support has been dropped. JAX GPU wheels only support CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built from source.
- The `global_arg_shapes` argument of `pmap` only worked with `sharded_jit` and has been removed from `pmap`. Please migrate to `pjit` and remove `global_arg_shapes` from `pmap`.
- `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore.
- `jax.experimental.jax2tf.convert` now supports the `native_serialization` parameter to use JAX's native lowering to StableHLO to obtain a StableHLO module for the entire JAX function instead of lowering each JAX primitive to a TensorFlow op. This simplifies the internals and increases the confidence that what you serialize matches the JAX native semantics. See documentation. As part of this change the config flag `--jax2tf_default_experimental_native_lowering` has been renamed to `--jax2tf_native_serialization`.
- JAX now depends on `ml_dtypes`, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
- JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
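Before accepting this update, it may be useful to confirm that the environment already satisfies the new NumPy and SciPy minimums. The sketch below uses only the standard library. The names `MINIMUMS`, `parse_release`, and `check_minimums` are hypothetical, and the version parsing is deliberately simplistic (it keeps only leading numeric components).

```python
from importlib import metadata

# Minimum versions taken from the changelog entry above.
MINIMUMS = {"numpy": (1, 21), "scipy": (1, 7)}


def parse_release(version):
    """Turn '1.24.3' or '1.21.0rc1' into a tuple of leading integers."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_minimums(minimums=MINIMUMS):
    """Map each package name to True/False, or None if it is not installed."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
            continue
        report[name] = parse_release(installed) >= minimum
    return report
```

For anything beyond a quick sanity check, a dedicated version parser such as the one in the `packaging` library handles pre-releases and local version tags more rigorously than this tuple comparison.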
... (truncated)
Commits
f282c25 Add minimal pyproject.toml specifying build system
cfa330b Merge pull request #15283 from JiaYaobo:fix_wald_doc
2d94f76 Merge pull request #15278 from hawkinsp:cudainstall
fbc05ee Remove global_arg_shapes from pmap since it was only used for sharded_jit and...
Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Updates the requirements on jax[cpu] to permit the latest version.
Changelog
Sourced from jax[cpu]'s changelog.
... (truncated)
Commits
f282c25 Add minimal pyproject.toml specifying build system
cfa330b Merge pull request #15283 from JiaYaobo:fix_wald_doc
2d94f76 Merge pull request #15278 from hawkinsp:cudainstall
fbc05ee Remove global_arg_shapes from pmap since it was only used for sharded_jit and...
a964ae7 Internal Code Change
7200d07 Merge pull request #15286 from hawkinsp:testjobs
d9b0f3c Recommend --local_test_jobs in bazel test command line on GPU.
07fc022 Merge pull request #15279 from hawkinsp:versions
3a4d0b3 remove scale in wald docstring
705b5cc Add version constraints to CUDA pip wheel dependencies.