SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
google/jax (jax)
### [`v0.4.30`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0430-June-18-2024)
[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.29...jax-v0.4.30)
- Changes
  - JAX supports ml_dtypes >= 0.2. In the 0.4.29 release, the ml_dtypes version was
    bumped to 0.4.0, but this has been rolled back in this release to give users
    of both TensorFlow and JAX more time to migrate to a newer TensorFlow
    release.
  - `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
  - jax now depends on jaxlib directly. This change was enabled by the CUDA
    plugin switch: there are no longer multiple jaxlib variants. You can install
    a CPU-only jax with `pip install jax`, no extras required.
  - Added an API for exporting and serializing JAX functions. This used
    to exist in `jax.experimental.export` (which is being deprecated),
    and now lives in `jax.export`.
    See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html).
- Deprecations
  - Internal pretty-printing tools `jax.core.pp_*` are deprecated and will be removed
    in a future release.
  - Hashing of tracers is deprecated and will lead to a `TypeError` in a future JAX
    release. This previously was the case, but there was an inadvertent regression in
    the last several JAX releases.
  - `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead.
    See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export).
  - Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays
    `x` and `y`, `x.astype(y)` will raise a warning. To silence it, use `x.astype(y.dtype)`.
  - `jax.xla_computation` is deprecated and will be removed in a future release.
    Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
    - `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with
      `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
    - You can also use the `.out_info` property of `jax.stages.Lowered` to get the
      output information (like tree structure, shape and dtype).
    - For cross-backend lowering, you can replace
      `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with
      `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
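The snippet below is an illustrative sketch and is not part of the upstream changelog. It shows how the v0.4.30 migration paths might look in practice, assuming JAX 0.4.30 or newer is installed. The function `f` and the input values are hypothetical. It covers three items: replacing `jax.xla_computation` with the AOT lowering API, passing a dtype (not an array) to `astype`, and a round trip through `jax.export`.

```python
import jax
import jax.numpy as jnp
from jax import export  # new home of the API formerly in jax.experimental.export


def f(x):
    # Hypothetical example function, used only for illustration.
    return jnp.sin(x) * 2.0


x = jnp.arange(3, dtype=jnp.float32)

# Instead of the deprecated jax.xla_computation(f)(x): lower via jit (AOT API).
lowered = jax.jit(f).lower(x)
hlo_text = lowered.compiler_ir("hlo").as_hlo_text()  # XLA HLO as text
out_info = lowered.out_info  # output shape/dtype information

# Pass a dtype, not an array, to astype to avoid the new deprecation warning.
y = jnp.zeros(3, dtype=jnp.int32)
x_as_int = x.astype(y.dtype)

# Export, serialize, deserialize, and call the function again.
exported = export.export(jax.jit(f))(x)
blob = exported.serialize()
restored = export.deserialize(blob)
result = restored.call(x)
```

The serialized `blob` is a byte array. It can be stored and reloaded later, which is the use case `jax.export` targets.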
### [`v0.4.29`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0429-June-10-2024)
[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.28...jax-v0.4.29)
- Changes
  - We anticipate that this will be the last release of JAX and jaxlib
    supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
    plugin jaxlib (e.g. `pip install jax[cuda12]`).
  - JAX now requires ml_dtypes version 0.4.0 or newer.
  - Removed backwards-compatibility support for old usage of the
    `jax.experimental.export` API. It is no longer possible to use
    `from jax.experimental.export import export`; instead, use
    `from jax.experimental import export`.
    The removed functionality has been deprecated since 0.4.24.
  - Added an `is_leaf` argument to {func}`jax.tree.all` and {func}`jax.tree_util.tree_all`.
- Deprecations
  - `jax.sharding.XLACompatibleSharding` is deprecated. Please use
    `jax.sharding.Sharding`.
  - `jax.experimental.Exported.in_shardings` has been renamed to
    `jax.experimental.Exported.in_shardings_hlo`. The same applies to `out_shardings`.
    The old names will be removed after 3 months.
  - Removed a number of previously-deprecated APIs:
    - from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
    - from {mod}`jax.lax`: `tie_in`
    - from {mod}`jax.nn`: `normalize`
    - from {mod}`jax.interpreters.xla`: `backend_specific_translations`,
      `translations`, `register_translation`, `xla_destructure`,
      `TranslationRule`, `TranslationContext`, `XlaOp`.
  - The `tol` argument of {func}`jax.numpy.linalg.matrix_rank` is being
    deprecated and will soon be removed. Use `rtol` instead.
  - The `rcond` argument of {func}`jax.numpy.linalg.pinv` is being
    deprecated and will soon be removed. Use `rtol` instead.
  - The deprecated `jax.config` submodule has been removed. To configure JAX,
    use `import jax` and then reference the config object via `jax.config`.
  - {mod}`jax.random` APIs no longer accept batched keys, where previously
    some did unintentionally. Going forward, we recommend explicit use of
    {func}`jax.vmap` in such cases.
  - In {func}`jax.scipy.special.beta`, the `x` and `y` parameters have been
    renamed to `a` and `b` for consistency with other `beta` APIs.
- New Functionality
  - Added {func}`jax.experimental.Exported.in_shardings_jax` to construct
    shardings that can be used with the JAX APIs from the HloShardings
    that are stored in the `Exported` objects.
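The following is a rough sketch, not taken from the changelog, of how code might adapt to several v0.4.29 changes. It assumes JAX 0.4.29 or newer. The tree, keys, matrix, and config value are illustrative choices only. Setting `jax_numpy_rank_promotion` to `"allow"` just restates the default and is used only to show the `jax.config` access pattern.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import beta

# Configure JAX through the jax.config object (the jax.config submodule is gone).
jax.config.update("jax_numpy_rank_promotion", "allow")

# jax.tree.all with the new is_leaf argument: by default None is treated as an
# empty subtree and skipped; with is_leaf it becomes a (falsy) leaf.
tree = {"a": True, "b": None}
default_all = jax.tree.all(tree)                              # True
strict_all = jax.tree.all(tree, is_leaf=lambda n: n is None)  # False

# jax.random no longer accepts batched keys; map over them explicitly with vmap.
keys = jax.random.split(jax.random.key(0), 4)
samples = jax.vmap(lambda k: jax.random.normal(k, (2,)))(keys)  # shape (4, 2)

# matrix_rank takes rtol instead of the deprecated tol; the tiny singular
# value 1e-9 falls below the cutoff, so only one direction is counted.
rank = jnp.linalg.matrix_rank(jnp.diag(jnp.array([3.0, 1e-9])), rtol=1e-6)

# jax.scipy.special.beta now names its parameters a and b.
b_val = beta(a=2.0, b=3.0)  # Gamma(2) * Gamma(3) / Gamma(5) = 1/12
```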
### [`v0.4.28`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0428-May-9-2024)
[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.27...jax-v0.4.28)
- Bug fixes
  - Reverted a change to `make_jaxpr` that was breaking Equinox ([#21116](https://togithub.com/google/jax/issues/21116)).
- Deprecations & removals
  - The `kind` argument to {func}`jax.numpy.sort` and {func}`jax.numpy.argsort`
    is now removed. Use `stable=True` or `stable=False` instead.
  - Removed `get_compute_capability` from the `jax.experimental.pallas.gpu`
    module. Use the `compute_capability` attribute of a GPU device, returned
    by {func}`jax.devices` or {func}`jax.local_devices`, instead.
  - The `newshape` argument to {func}`jax.numpy.reshape` is being deprecated
    and will soon be removed. Use `shape` instead.
- Changes
  - The minimum jaxlib version of this release is 0.4.27.
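As an illustrative sketch only, assuming JAX 0.4.28 or newer, the replacements for the removed and deprecated v0.4.28 arguments might look like this. The input arrays are arbitrary examples. The `compute_capability` lookup is guarded so that it only touches GPU devices and yields an empty list on a CPU-only install.

```python
import jax
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 1])

# stable=True replaces the removed kind="stable" argument.
sorted_x = jnp.sort(x, stable=True)
order = jnp.argsort(x, stable=True)  # equal elements keep their original order

# shape= replaces the deprecated newshape= argument.
reshaped = jnp.reshape(jnp.arange(6), shape=(2, 3))

# compute_capability is an attribute of GPU devices (replacing
# get_compute_capability); skip non-GPU devices.
gpu_caps = [d.compute_capability for d in jax.devices() if d.platform == "gpu"]
```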
### [`v0.4.27`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0427-May-7-2024)
[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.26...jax-v0.4.27)
- New Functionality
  - Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`,
    following their addition in the array API 2023 standard, soon to be
    adopted by NumPy.
  - Added a new config option `jax_cpu_collectives_implementation` to select the
    implementation of cross-process collective operations used by the CPU backend.
    Choices available are `'none'` (default), `'gloo'` and `'mpi'` (requires jaxlib 0.4.26).
    If set to `'none'`, cross-process collective operations are disabled.
- Changes
  - {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`
    and {func}`jax.debug.callback` now use {class}`jax.Array` instead
    of {class}`np.ndarray`. You can recover the old behavior by transforming
    the arguments via `jax.tree.map(np.asarray, args)` before passing them
    to the callback.
  - `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
    False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
  - `core.Token` is now a non-trivial class which wraps a `jax.Array`. It can
    be created and threaded in and out of computations to build up dependencies.
    The singleton object `core.token` has been removed; users should now create
    and use fresh `core.Token` objects instead.
  - On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
    by default. This choice can improve runtime memory usage at a compile-time
    cost. Prior behavior, which produces a kernel call, can be recovered with
    `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new
    default causes issues, please file a bug. Otherwise, we intend to remove
    this flag in a future release.
- Deprecations & Removals
  - Pallas now exclusively uses XLA for compiling kernels on GPU. The old
    lowering pass via Triton Python APIs has been removed and the
    `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
  - {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
    `a_max` are deprecated in favor of `x` (positional only), `min`, and
    `max` ({jax-issue}`20550`).
  - The `device()` method of JAX arrays has been removed, after being deprecated
    since JAX v0.4.21. Use `arr.devices()` instead.
  - The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
    is deprecated; empty inputs to softmax are now supported without setting this.
  - In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames`
    now leads to an error rather than a warning.
  - The minimum jaxlib version is now 0.4.23.
  - The {func}`jax.numpy.hypot` function now issues a deprecation warning when
    passing complex-valued inputs to it. This will raise an error when the
    deprecation is completed.
  - Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
    related functions now raise an error, following a similar change in NumPy.
  - The config option `jax_cpu_enable_gloo_collectives` is deprecated.
    Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead.
  - The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have
    been removed after being deprecated in JAX v0.4.22. Instead use
    {attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
  - The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now
    positional-only, following deprecation of the keywords in JAX v0.4.21.
  - Non-array arguments to functions in {mod}`jax.lax.linalg` now must be
    specified by keyword. Previously, this raised a DeprecationWarning.
  - Array-like arguments are now required in several {mod}`jax.numpy` APIs,
    including {func}`~jax.numpy.apply_along_axis`,
    {func}`~jax.numpy.apply_over_axes`, {func}`~jax.numpy.inner`,
    {func}`~jax.numpy.outer`, {func}`~jax.numpy.cross`,
    {func}`~jax.numpy.kron`, and {func}`~jax.numpy.lexsort`.
- Bug fixes
  - {func}`jax.numpy.astype` will now always return a copy when `copy=True`.
    Previously, no copy would be made when the output array would have the same
    dtype as the input array. This may result in some increased memory usage.
    The default value is set to `copy=False` to preserve backwards compatibility.
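Below is a hedged sketch, not from the upstream notes, of several v0.4.27 changes in use, assuming JAX 0.4.27 or newer. The arrays and the helper `host_sum` are hypothetical. The callback converts its argument with `np.asarray`, mirroring the changelog's suggestion for code that still expects NumPy arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# New array-API functions.
rows = jnp.unstack(x)                    # tuple of two arrays of shape (3,)
running = jnp.cumulative_sum(x, axis=1)  # [[0, 1, 3], [3, 7, 12]]

# New clip signature: x is positional-only; min/max replace a_min/a_max.
clipped = jnp.clip(x, min=1, max=4)

# complex -> bool follows NumPy: False only where the value is 0 + 0j.
flags = jnp.array([0 + 0j, 0 + 1j, 2 + 0j]).astype(bool)  # [False, True, True]

# arr.devices() replaces arr.device(); addressable_shards replaces device_buffers.
devs = x.devices()
shard_data = [s.data for s in x.addressable_shards]


def host_sum(a):
    # Hypothetical host-side helper: callbacks now receive jax.Array,
    # so convert explicitly when NumPy behavior is needed.
    a = np.asarray(a)
    return np.sum(a).astype(a.dtype)


total = jax.pure_callback(host_sum, jax.ShapeDtypeStruct((), x.dtype), x)
```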
Configuration
📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).
🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.
♻ Rebasing: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.
🔕 Ignore: Close this PR and you won't be reminded about this update again.
- [ ] If you want to rebase/retry this PR, check this box
This PR has been generated by Mend Renovate. View repository job log here.
This PR contains the following updates:

- jax: `>=0.4.16, <=0.4.26` -> `>=0.4.30, <=0.4.30`