jax.sharding.OpShardingSharding has been renamed to jax.sharding.GSPMDSharding.
jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
The following jax.Array methods are deprecated and will be removed 3 months from
Feb 23, 2023:
- jax.Array.broadcast: use jax.lax.broadcast instead.
- jax.Array.broadcast_in_dim: use jax.lax.broadcast_in_dim instead.
- jax.Array.split: use jax.numpy.split instead.
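A minimal migration sketch for the three deprecated methods, calling the replacement functions directly. The input array and target shapes are arbitrary illustrations, not taken from the changelog.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0)  # shape (6,)

# Replacement for x.broadcast(sizes): prepend new leading dimensions.
y = lax.broadcast(x, (2,))  # shape (2, 6)

# Replacement for x.broadcast_in_dim(...): operand dim 0 maps to output dim 1.
z = lax.broadcast_in_dim(x, shape=(3, 6), broadcast_dimensions=(1,))  # shape (3, 6)

# Replacement for x.split(...): split into 3 equal pieces along axis 0.
parts = jnp.split(x, 3)  # list of three arrays of shape (2,)
```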
jax 0.4.4 (Feb 16, 2023)
Changes
The implementation of jit and pjit has been merged. Merging jit and pjit
changes the internals of JAX without affecting the public API of JAX.
Previously, jit was a final-style primitive. Final style means that the creation
of the jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the jit-pjit implementation merge, jit
becomes an initial-style primitive, which means that we trace to a jaxpr
as early as possible. For more information, see
this section in autodidax.
Moving to initial style should simplify JAX's internals and make
development of features such as dynamic shapes easier.
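One way to inspect what jit records at trace time is jax.make_jaxpr. This sketch assumes JAX 0.4.4 or later, where the jit call shows up as a single equation whose parameter named `jaxpr` holds the already-traced body. The primitive's printed name (for example `pjit`) differs across versions, so the sketch only inspects structure; the function `f` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) * 2.0

# Trace an outer function that calls the jitted f.
closed = jax.make_jaxpr(lambda x: f(x) + 1.0)(1.0)

# Equations that carry a traced inner jaxpr (assumed param name: "jaxpr").
call_eqns = [e for e in closed.jaxpr.eqns if "jaxpr" in e.params]
inner = call_eqns[0].params["jaxpr"]  # ClosedJaxpr for the body of f
print(inner)
```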
The merge can be disabled only through an environment variable:
os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'.
It must be an environment variable because the merge affects JAX
at import time, so it has to be disabled before jax is imported.
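A sketch of the required ordering. It assumes a JAX release around 0.4.4 that still honors this opt-out variable; later releases may ignore it entirely, in which case the code runs but changes nothing.

```python
import os

# Must run before the first `import jax` in the process: JAX reads this
# variable at import time, so setting it afterwards has no effect.
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"

import jax  # noqa: E402  (deliberately imported after setting the variable)

print(jax.__version__)
```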
The axis_resources argument of with_sharding_constraint is deprecated.
Please use shardings instead. No change is needed if you were passing
axis_resources positionally. If you were passing it as a keyword argument,
please use shardings instead. axis_resources will be removed 3 months
after Feb 13, 2023.
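A sketch of the keyword migration. It assumes a JAX release that exposes the function as jax.lax.with_sharding_constraint (in 0.4.4 it lived in jax.experimental.pjit) and uses a one-device mesh so it runs on CPU; the mesh axis name "x" and the function `double` are arbitrary illustrations.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-device mesh with a single, arbitrarily named axis "x".
mesh = Mesh(np.array(jax.devices()[:1]), ("x",))
sharding = NamedSharding(mesh, P("x"))

@jax.jit
def double(a):
    b = a * 2.0
    # Deprecated: with_sharding_constraint(b, axis_resources=...)
    # Preferred: pass the sharding positionally or as `shardings=`.
    return jax.lax.with_sharding_constraint(b, shardings=sharding)

out = double(jnp.arange(4.0))
```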
Added the jax.typing module, with tools for type annotations of JAX
functions.
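A small sketch of annotating a function with the jax.typing aliases. The function `scaled_sum` is hypothetical; note that ArrayLike covers arrays and scalars but deliberately not Python lists, so the example passes a jnp array.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike, DTypeLike

def scaled_sum(x: ArrayLike, scale: ArrayLike = 1.0,
               dtype: DTypeLike = jnp.float32) -> jax.Array:
    """Cast x to dtype, sum all elements, then multiply by scale."""
    arr = jnp.asarray(x, dtype=dtype)
    return jnp.sum(arr) * scale

result = scaled_sum(jnp.array([1, 2, 3]), scale=2.0)
```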
The following names have been deprecated:
- jax.xla.Device and jax.interpreters.xla.Device: use jax.Device.
- jax.experimental.maps.Mesh: use jax.sharding.Mesh.
- jax.experimental.pjit.NamedSharding: use jax.sharding.NamedSharding.
- jax.experimental.pjit.PartitionSpec: use jax.sharding.PartitionSpec.
- jax.interpreters.pxla.Mesh: use jax.sharding.Mesh.
- jax.interpreters.pxla.PartitionSpec: use jax.sharding.PartitionSpec.
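A sketch using only the non-deprecated locations. The mesh axis name "data" and the array length (two elements per device) are arbitrary choices so the array divides evenly across however many devices are present.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# jax.Device replaces jax.xla.Device / jax.interpreters.xla.Device.
devices = jax.devices()

# jax.sharding.Mesh / NamedSharding / PartitionSpec replace the
# jax.experimental.* and jax.interpreters.pxla.* aliases.
mesh = Mesh(np.array(devices), ("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

# Shard a 1-D array evenly along the "data" mesh axis.
x = jax.device_put(jnp.arange(2.0 * len(devices)), sharding)
```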
Breaking Changes
The initial argument to reduction functions like jax.numpy.sum
is now required to be a scalar, consistent with the corresponding NumPy API.
The previous behavior of broadcasting the output against non-scalar initial
values was an unintentional implementation detail ([#14446](https://github.com/google/jax/issues/14446)).
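A sketch of the new requirement with an arbitrary 2x2 input. It assumes the rejection of a non-scalar initial surfaces as a ValueError, which is how current jax.numpy reductions report it.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# A scalar `initial` is combined with the reduction result.
total = jnp.sum(x, initial=10.0)        # 10 + 10 -> 20.0
rows = jnp.sum(x, axis=1, initial=1.0)  # [3, 7] + 1 -> [4.0, 8.0]

# A non-scalar `initial` is rejected rather than broadcast against the output.
try:
    jnp.sum(x, axis=1, initial=jnp.array([1.0, 1.0]))
    rejected = False
except ValueError:
    rejected = True
```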
Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Updates the requirements on jax[cpu] to permit the latest version.
Changelog
Sourced from jax[cpu]'s changelog.
... (truncated)
Commits
- a002643 Fix stale reference to util.prod.
- a9421a8 Merge pull request #14728 from mattjj:custom-vjp-bwd-wrapped-fun
- bf07395 [custom_vjp] bwd function should not be WrappedFun, may run multiple times
- abc6c9b [sparse] adjust tolerance on bcoo_dot_general_sampled
- 3abae68 Rollforward of Add a fastpath to pmap_lib for sharding np.ndarray directly in...
- c73cc49 Use in_shardings and out_shardings since those are the new arguments that pji...
- a21bdad Merge pull request #14738 from jakevdp:update-sphinx
- 55d9c06 DOC: update sphinx & sphinx-autodoc-typehints
- b348fce Merge pull request #14736 from jakevdp:fix-rtd
- e8a4e64 Merge pull request #14737 from hawkinsp:kaggle