- Added `jax.extend.ffi.ffi_call` and `jax.extend.ffi.ffi_lowering` to
  support the workflow in the new FFI tutorial (`ffi-tutorial`) for
  interfacing with custom C++ and CUDA code from JAX.
Changes
- The `jax_enable_memories` flag is now set to `True` by default.
- `jax.numpy` now supports v2023.12 of the Python Array API Standard.
  See `python-array-api` for more information.
- Computations on the CPU backend may now be dispatched asynchronously in
  more cases. Previously, non-parallel computations were always dispatched
  synchronously. You can recover the old behavior by setting
  `jax.config.update('jax_cpu_enable_async_dispatch', False)`.
- Added a new `jax.process_indices` function to replace the
  `jax.host_ids()` function, which was deprecated in JAX v0.2.13.
- To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` no longer
  supports complex dtypes.
- `jax.tree_util.register_dataclass` now checks, when `nodetype` is a
  dataclass, that `data_fields` and `meta_fields` together contain exactly
  the dataclass fields that have `init=True`.
- Several `jax.numpy` functions now have full `jax.numpy.ufunc`
  interfaces, including `jax.numpy.add`, `jax.numpy.multiply`,
  `jax.numpy.bitwise_and`, `jax.numpy.bitwise_or`, `jax.numpy.bitwise_xor`,
  `jax.numpy.logical_and`, `jax.numpy.logical_or`, and
  `jax.numpy.logical_xor`. We plan to extend this to other ufuncs in future
  releases.
- Added `jax.lax.optimization_barrier`, which allows users to prevent
  compiler optimizations such as common-subexpression elimination and to
  control scheduling.
Breaking changes
- The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
  `stablehlo` dialect instead.
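Code that previously inspected MHLO can read the StableHLO form of a lowered computation instead. A minimal sketch, assuming a hypothetical function `scale_sin`; it uses the `as_text()` and `compiler_ir(dialect="stablehlo")` methods of the object returned by `jax.jit(...).lower(...)`.

```python
import jax
import jax.numpy as jnp


def scale_sin(x):
    # Hypothetical example function to lower.
    return jnp.sin(x) * 2.0


lowered = jax.jit(scale_sin).lower(jnp.ones(3, dtype=jnp.float32))

# Textual StableHLO (MLIR), the default dialect for as_text().
stablehlo_text = lowered.as_text()

# The same module as an MLIR object, requested explicitly in the StableHLO dialect.
module = lowered.compiler_ir(dialect="stablehlo")

print(stablehlo_text)
```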
Deprecations
- Complex inputs to `jax.numpy.clip` and `jax.numpy.hypot` are no longer
  allowed, after being deprecated since JAX v0.4.27.
- Deprecated the following APIs:
  - `jax.lib.xla_bridge.xla_client`: use `jax.lib.xla_client` directly.
  - `jax.lib.xla_bridge.get_backend`: use `jax.extend.backend.get_backend`.
  - `jax.lib.xla_bridge.default_backend`: use
    `jax.extend.backend.default_backend`.
- The `jax.experimental.array_api` module is deprecated, and importing it is
  no longer required to use the Array API. `jax.numpy` supports the Array API
  directly; see `python-array-api` for more information.
- The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and
  `jax.core.check_valid_jaxtype` are now deprecated and will be removed in
  the future.
- `jax.numpy.round_` has been deprecated, following removal of the
  corresponding API in NumPy 2.0. Use `jax.numpy.round` instead.
Changes
- The `jax_pmap_no_rank_reduction` flag is now set to `True` by default.
  - Indexing a `pmap` result with `array[0]` now introduces a reshape; use
    `array[0:1]` instead.
  - The per-shard shape (accessible via `jax_array.addressable_shards` or
    `jax_array.addressable_data(0)`) now has a leading `(1, ...)` dimension.
    Update code that directly accesses shards accordingly. The rank of the
    per-shard shape now matches that of the global shape, which is the same
    behavior as `jit`. This avoids costly reshapes when passing results from
    `pmap` into `jit`.
Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Bumps jax from 0.4.31 to 0.4.32.
Release notes
Sourced from jax's releases.
... (truncated)
Changelog
Sourced from jax's changelog.
... (truncated)
Commits
- 1594d2f Prepare for v0.4.32 release.
- ed849ff Make sure to call the superclass' init() on a newly created instance in P...
- 2bd1fde Relax test tolerance in pinv test to fix a CI failure on Windows CPU.
- e869a9d Merge pull request #23415 from kaixih:key_value_seq_lengths
- ea68f45 Internal change
- 49dd6ed Disable a pallas export compatibility test that fails on TPU v6e.
- 808003b Update users of jax.tree.map() to be more careful about how they handle Nones.
- e3c4b20 [Pallas] Implement tiled and swizzled Memref loads for Mosaic GPU via "GPUBlo...
- c659dc9 [Pallas] Disable win32 gpu_ops_test.
- 14b8625 Merge pull request #23549 from pschuh:docs-update