google/jax
### [`v0.4.13`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0413-June-22-2023)
[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.12...jax-v0.4.13)
- Changes
- `jax.jit` now allows `None` to be passed to `in_shardings` and
`out_shardings`. The semantics are as follows:
- For `in_shardings`, JAX will mark it as replicated, but this behavior
can change in the future.
- For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
- `jax.experimental.pjit.pjit` also allows `None` to be passed to
`in_shardings` and `out_shardings`. The semantics are as follows:
- If the mesh context manager is *not* provided, JAX has the freedom to
choose whatever sharding it wants.
- For `in_shardings`, JAX will mark it as replicated, but this behavior
can change in the future.
- For out_shardings, we will rely on the XLA GSPMD partitioner to
determine the output shardings.
- If the mesh context manager is provided, `None` implies that the value
will be replicated on all devices of the mesh.
- `Executable.cost_analysis()` now works on Cloud TPU.
- Added a warning if a non-allowlisted `jaxlib` plugin is in use.
- Added `jax.tree_util.tree_leaves_with_path`.
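The `None` semantics described above for `jax.jit` can be sketched as follows. This is a minimal illustration, assuming a single default device; the function `scale` and the input values are made up for the example, and the sharding that is actually chosen depends on the backend.

```python
import jax
import jax.numpy as jnp


def scale(x):
    # Hypothetical function used only for this illustration.
    return x * 2.0


# Passing None leaves the sharding choice to JAX (inputs) and to the
# compiler (outputs), as described in the release notes.
scaled = jax.jit(scale, in_shardings=None, out_shardings=None)

x = jnp.arange(8.0)
y = scaled(x)

# Inspect the sharding that was picked for the result.
print(y.sharding)
```

Because the notes say the `in_shardings` behavior may change, code that depends on a specific layout should pass an explicit sharding instead of `None`.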
- Bug fixes
- Fixed incorrect wheel name in CUDA 12 releases ([#16362](https://togithub.com/google/jax/issues/16362)); the correct wheel
is named `cudnn89` instead of `cudnn88`.
- Deprecations
- The `native_serialization_strict_checks` parameter to
{func}`jax.experimental.jax2tf.convert` is deprecated in favor of the
new `native_serialization_disabled_checks` ({jax-issue}`#16347`).
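As a short sketch of the `jax.tree_util.tree_leaves_with_path` function added in this release: it returns each leaf of a pytree together with the key path that leads to it. The `params` dictionary below is a made-up example, not something from the changelog.

```python
import jax
import jax.numpy as jnp

# Hypothetical nested parameter structure.
params = {"dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}}

# Each entry is a (key_path, leaf) pair; dict keys are visited in sorted order.
leaves = jax.tree_util.tree_leaves_with_path(params)

for path, leaf in leaves:
    # keystr turns a key path into a readable string.
    print(jax.tree_util.keystr(path), leaf.shape)

shapes = [leaf.shape for _, leaf in leaves]
```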
### [`v0.4.12`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0412-June-8-2023)
[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.11...jax-v0.4.12)
- Changes
- Added {class}`jax.scipy.spatial.transform.Rotation` and {class}`jax.scipy.spatial.transform.Slerp`.
- Deprecations
- `jax.abstract_arrays` and its contents are now deprecated. See related
functionality in :mod:`jax.core`.
- `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation
of `numpy.alltrue` in NumPy version 1.25.0.
- `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation
of `numpy.sometrue` in NumPy version 1.25.0.
- `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation
of `numpy.product` in NumPy version 1.25.0.
- `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation
of `numpy.cumproduct` in NumPy version 1.25.0.
- `jax.sharding.OpShardingSharding` has been removed, 3 months after it
was deprecated.
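The four `jax.numpy` renames listed above are drop-in replacements. A minimal sketch with made-up inputs:

```python
import jax.numpy as jnp

mask = jnp.array([True, False, True])
values = jnp.array([1, 2, 3, 4])

all_true = jnp.all(mask)       # replaces jnp.alltrue
any_true = jnp.any(mask)       # replaces jnp.sometrue
total = jnp.prod(values)       # replaces jnp.product
running = jnp.cumprod(values)  # replaces jnp.cumproduct
```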
### [`v0.4.11`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0411-May-31-2023)
[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.10...jax-v0.4.11)
- Deprecations
- The following APIs have been removed after a 3 month deprecation period, in
accordance with the {ref}`api-compatibility` policy:
- `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
- `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`
- `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
- `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
- `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects
as input and remove the optional `in_shardings` argument to `pjit`.
- `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
- `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
- `jax.interpreters.xla.Buffer`: use `jax.Array`.
- `jax.interpreters.xla.Device`: use `jax.Device`.
- `jax.interpreters.xla.DeviceArray`: use `jax.Array`.
- `jax.interpreters.xla.device_put`: use `jax.device_put`.
- `jax.interpreters.xla.xla_call_p`: use `jax.experimental.pjit.pjit_p`.
- `axis_resources` argument of `with_sharding_constraint` is removed. Please
use `shardings` instead.
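A migration sketch for the removals above, assuming whatever devices `jax.devices()` reports (a single CPU device is enough). The axis name `"data"` and the function `double` are illustrative choices, not part of the changelog.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec  # new import location

# Build a 1-D mesh over all available devices.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

# jax.device_put replaces jax.interpreters.xla.device_put.
x = jax.device_put(np.arange(len(devices) * 2, dtype=np.float32), sharding)


@jax.jit
def double(v):
    # `shardings` replaces the removed `axis_resources` argument.
    v = jax.lax.with_sharding_constraint(v, shardings=sharding)
    return v * 2


y = double(x)

# jax.Array replaces Buffer and DeviceArray in isinstance checks.
print(isinstance(y, jax.Array), y.sharding)
```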
### [`v0.4.10`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0410-May-11-2023)
[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.9...jax-v0.4.10)
### [`v0.4.9`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-049-May-9-2023)
[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.8...jax-v0.4.9)
- Changes
- The flags `experimental_cpp_jit`, `experimental_cpp_pjit` and
`experimental_cpp_pmap` have been removed.
They are now always on.
- Accuracy of singular value decomposition (SVD) on TPU has been improved
(requires jaxlib 0.4.9).
- Deprecations
- `jax.experimental.gda_serialization` is deprecated and has been renamed to
`jax.experimental.array_serialization`.
Please change your imports to use `jax.experimental.array_serialization`.
- The `in_axis_resources` and `out_axis_resources` arguments of pjit have been
deprecated. Please use `in_shardings` and `out_shardings` respectively.
- The function `jax.numpy.msort` has been removed. It has been deprecated since
JAX v0.4.1. Use `jnp.sort(a, axis=0)` instead.
- `in_parts` and `out_parts` arguments have been removed from `jax.xla_computation`
since they were only used with sharded_jit and sharded_jit is long gone.
- `instantiate_const_outputs` argument has been removed from `jax.xla_computation`
since it has been unused for a very long time.
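For the `jax.numpy.msort` removal above, the suggested replacement sorts along the first axis, so each column is sorted independently. A small sketch with a made-up array:

```python
import jax.numpy as jnp

a = jnp.array([[3, 1],
               [2, 4],
               [1, 0]])

# Equivalent of the removed jnp.msort(a): sort along axis 0.
sorted_cols = jnp.sort(a, axis=0)
print(sorted_cols)
```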
### [`v0.4.8`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-048-March-29-2023)
[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.7...jax-v0.4.8)
- Breaking changes
- A major component of the Cloud TPU runtime has been upgraded. This enables
the following new features on Cloud TPU:
- {func}`jax.debug.print`, {func}`jax.debug.callback`, and
{func}`jax.debug.breakpoint()` now work on Cloud TPU
- Automatic TPU memory defragmentation
{func}`jax.experimental.host_callback` is no longer supported on Cloud TPU
with the new runtime component. Please file an issue on the [JAX issue
tracker](https://togithub.com/google/jax/issues) if the new `jax.debug` APIs
are insufficient for your use case.
The old runtime component will be available for at least the next three
months by setting the environment variable
`JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new
runtime for any reason, please let us know on the [JAX issue
tracker](https://togithub.com/google/jax/issues).
- Changes
- The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
- Deprecations
- CUDA 11.4 support has been dropped. JAX GPU wheels only support
CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
from source.
- `global_arg_shapes` argument of pmap only worked with sharded_jit and has
been removed from pmap. Please migrate to pjit and remove global_arg_shapes
from pmap.
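Where `jax.experimental.host_callback` was used only to print intermediate values, `jax.debug.print` is one possible replacement. This is a minimal sketch that runs on any backend; the function `f` is a made-up example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    y = jnp.sin(x)
    # Prints the runtime value even inside a jit-compiled function.
    jax.debug.print("y = {y}", y=y)
    return y * 2


out = f(jnp.float32(1.0))
out.block_until_ready()
```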
### [`v0.4.7`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-047-March-27-2023)
[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.6...jax-v0.4.7)
- Changes
- As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration
`jax.config.jax_array` cannot be disabled anymore.
- `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore.
- {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization`
parameter to use JAX's native lowering to StableHLO to obtain a
StableHLO module for the entire JAX function instead of lowering each JAX
primitive to a TensorFlow op. This simplifies the internals and increases
the confidence that what you serialize matches the JAX native semantics.
See [documentation](https://togithub.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
As part of this change the config flag `--jax2tf_default_experimental_native_lowering`
has been renamed to `--jax2tf_native_serialization`.
- JAX now depends on `ml_dtypes`, which contains definitions of NumPy types
like bfloat16. These definitions were previously internal to JAX, but have
been split into a separate package to facilitate sharing them with other
projects.
- JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
- Deprecations
- The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,
for which it is an alias.
- The type `jax.interpreters.pxla.ShardedDeviceArray` is deprecated. Use
`jax.Array` instead.
- Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated.
For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`
- `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
- `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
- `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded
jax.Arrays as input and remove the `in_shardings` argument to pjit since
it is optional.
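Two of the deprecations above, sketched with made-up data: type checks against `jax.Array` instead of `DeviceArray`, and keyword arguments for the `.at[...]` methods.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10.0)
idx = jnp.array([1, 3, 5])

# Use jax.Array rather than jax.numpy.DeviceArray for type checks.
is_array = isinstance(x, jax.Array)

# Pass options to .at[...].get by keyword, not by position.
picked = x.at[idx].get(indices_are_sorted=True)
print(is_array, picked)
```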
Configuration
📅 Schedule: Branch creation - "before 4am on Monday" (UTC), Automerge - At any time (no schedule defined).
🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.
♻ Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.
🔕 Ignore: Close this PR and you won't be reminded about this update again.
[ ] If you want to rebase/retry this PR, check this box
This PR has been generated by Mend Renovate. View repository job log here.
This PR contains the following updates: 0.4.6 -> 0.4.13