- Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`,
  following their addition in the array API 2023 standard, soon to be
  adopted by NumPy.
- Added a new config option `jax_cpu_collectives_implementation` to select the
  implementation of cross-process collective operations used by the CPU backend.
  Available choices are `'none'` (default), `'gloo'`, and `'mpi'` (requires jaxlib 0.4.26).
  If set to `'none'`, cross-process collective operations are disabled.
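
As a rough illustration of the two new NumPy-style functions named above, the sketch below splits a small array along an axis and computes a running sum. The array values and variable names are made up for the example, and it assumes a JAX version that already ships both functions (0.4.27 or later).

```python
import jax.numpy as jnp

# Illustrative input: a 2x3 integer array.
x = jnp.arange(6).reshape(2, 3)

# unstack splits an array into a tuple of sub-arrays along one axis
# (axis=0 by default), dropping that axis from each piece.
rows = jnp.unstack(x)
cols = jnp.unstack(x, axis=1)

# cumulative_sum is the array API spelling of a running sum; for a 1-D
# input the axis can be left unspecified.
running = jnp.cumulative_sum(jnp.array([1, 2, 3]))

# include_initial=True prepends the additive identity (0) to the result.
with_initial = jnp.cumulative_sum(jnp.array([1, 2, 3]), include_initial=True)
```

Here `rows` holds two arrays of shape `(3,)` and `cols` holds three arrays of shape `(2,)`.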
Changes
- {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`,
  and {func}`jax.debug.callback` now use {class}`jax.Array` instead
  of {class}`np.ndarray` for the arguments passed to the callback. You can
  recover the old behavior by transforming the arguments via
  `jax.tree.map(np.asarray, args)` before passing them to the callback.
- `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
  `False` where `complex_arr` is equal to `0 + 0j`, and `True` otherwise.
- `core.Token` is now a non-trivial class that wraps a `jax.Array`. It can
  be created and threaded in and out of computations to build up dependencies.
  The singleton object `core.token` has been removed; users should now create
  and use fresh `core.Token` objects instead.
- On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
  by default. This choice can improve runtime memory usage at a compile-time
  cost. Prior behavior, which produces a kernel call, can be recovered with
  `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new
  default causes issues, please file a bug. Otherwise, we intend to remove
  this flag in a future release.
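
The callback and `astype(bool)` changes above can be sketched as follows. This is a minimal example rather than the project's own migration recipe: `host_sin` and `f` are hypothetical names, and the conversion inside the callback simply applies the `jax.tree.map(np.asarray, args)` idiom mentioned in the changelog so the host function keeps working on NumPy arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(*args):
    # Hypothetical host-side function. The arguments may arrive as
    # jax.Array objects, so convert them back to NumPy arrays first.
    (x,) = jax.tree.map(np.asarray, args)
    return np.sin(x)


@jax.jit
def f(x):
    # Describe the shape and dtype of the value the callback returns.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_spec, x)


y = f(jnp.arange(4.0))

# Complex-to-bool conversion: only an exact 0 + 0j maps to False.
mask = jnp.array([0 + 0j, 1 + 0j, 0 + 2j]).astype(bool)
```

Converting inside the callback keeps the call site unchanged, which may be convenient when the same host function is shared by several callers.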
Deprecations & Removals
- Pallas now exclusively uses XLA for compiling kernels on GPU. The old
  lowering pass via Triton Python APIs has been removed and the
  `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
- {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
  `a_max` are deprecated in favor of `x` (positional only), `min`, and
  `max` ({jax-issue}`20550`).
- The `device()` method of JAX arrays has been removed, after being deprecated
  since JAX v0.4.21. Use `arr.devices()` instead.
- The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
  is deprecated; empty inputs to softmax are now supported without setting this.
- In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
  now leads to an error rather than a warning.
- The minimum jaxlib version is now 0.4.23.
- The {func}`jax.numpy.hypot` function now issues a deprecation warning when
  passing complex-valued inputs to it. This will raise an error when the
  deprecation is completed.
- Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
  related functions now raise an error, following a similar change in NumPy.
- The config option `jax_cpu_enable_gloo_collectives` is deprecated.
... (truncated)
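
For the `jax.numpy.clip` and `device()` items above, a small migration sketch might look like the following. The input values are illustrative, and it assumes JAX 0.4.27 or later.

```python
import jax.numpy as jnp

# Illustrative input values.
x = jnp.array([-2.0, 0.5, 3.0])

# New-style call: the array is passed positionally and the bounds use
# min/max rather than the deprecated a_min/a_max keywords.
clipped = jnp.clip(x, min=0.0, max=1.0)

# arr.devices() replaces the removed arr.device() method and returns the
# set of devices that hold the array.
devs = x.devices()
```

Because `devices()` returns a set, code that previously expected a single device from `device()` needs to pick an element explicitly, for example with `next(iter(devs))`.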
Commits
500da57 Merge pull request #21077 from merrymercy:patch-1
Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Bumps jax from 0.4.26 to 0.4.27.
Changelog
Sourced from jax's changelog.
... (truncated)
Commits
- 500da57 Merge pull request #21077 from merrymercy:patch-1
- 70b4477 Start jax and jaxlib 0.4.27 release
- 326adc0 [Mosaic GPU] Adjust memref.expand_shape construction to pass in the new args
- 3e5a18f Update XLA dependency to use revision
- cb0c498 Merge pull request #21081 from hawkinsp:sourcemap
- 4de3464 Fix that the insufficient output HBM buffer init would cause the <unk> token ...
- eee2783 Merge pull request #21070 from shuhand0:rel0.0.7
- f6d8852 Merge pull request #20327 from selamw1:add_examples
- aac3679 fix jaxlib config name
- 9caf59d improve documentation for ix_