secretflow / spu

SPU (Secure Processing Unit) aims to be a provable, measurable secure computation device, which provides computation ability while keeping your private data protected.
https://www.secretflow.org.cn/docs/spu/en/
Apache License 2.0
208 stars, 95 forks

chore(deps): update dependency jax to >=0.4.30, <=0.4.30 #722

Open. renovate[bot] opened this pull request 2 weeks ago.

renovate[bot] commented 2 weeks ago


This PR contains the following updates:

| Package | Change |
|---|---|
| jax | `>=0.4.16, <=0.4.26` -> `>=0.4.30, <=0.4.30` |
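The new constraint `>=0.4.30, <=0.4.30` admits exactly one version, so in practice it pins `jax==0.4.30`, whereas the old range admitted 0.4.16 through 0.4.26. The sketch below shows how such a comma-separated constraint evaluates. It assumes plain dotted numeric versions (no pre-release or local tags), and `satisfies` is a hypothetical helper for illustration, not part of SPU or Renovate. For real projects, `packaging.specifiers.SpecifierSet` is the standard tool.

```python
import re

# Hypothetical helper for illustration only: handles plain dotted numeric
# versions such as "0.4.30" (no pre-releases, no local version tags).
_CLAUSE = re.compile(r"\s*(>=|<=|==|>|<)\s*(\d+(?:\.\d+)*)\s*")

_OPS = {
    ">=": lambda a, b: a >= b,
    "<=": lambda a, b: a <= b,
    "==": lambda a, b: a == b,
    ">": lambda a, b: a > b,
    "<": lambda a, b: a < b,
}


def _parse(version: str) -> tuple[int, ...]:
    # "0.4.30" -> (0, 4, 30)
    return tuple(int(part) for part in version.split("."))


def _pad(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[tuple[int, ...], tuple[int, ...]]:
    # Pad with zeros so that "0.4" and "0.4.0" compare as equal.
    n = max(len(a), len(b))
    return a + (0,) * (n - len(a)), b + (0,) * (n - len(b))


def satisfies(version: str, spec: str) -> bool:
    """Return True if `version` meets every comma-separated clause in `spec`."""
    v = _parse(version)
    for clause in spec.split(","):
        m = _CLAUSE.fullmatch(clause)
        if m is None:
            raise ValueError(f"unsupported clause: {clause!r}")
        op, bound = m.groups()
        lhs, rhs = _pad(v, _parse(bound))
        if not _OPS[op](lhs, rhs):
            return False
    return True


old_pin = ">=0.4.16, <=0.4.26"
new_pin = ">=0.4.30, <=0.4.30"
print(satisfies("0.4.26", old_pin))  # True
print(satisfies("0.4.26", new_pin))  # False
print(satisfies("0.4.30", new_pin))  # True
```

To check an existing environment, `importlib.metadata.version("jax")` returns the installed version string, which can be passed straight to `satisfies`.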

Release Notes

google/jax (jax)

### [`v0.4.30`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0430-June-18-2024)

[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.29...jax-v0.4.30)

- Changes
  - JAX supports ml_dtypes >= 0.2. In the 0.4.29 release, the ml_dtypes version was bumped to 0.4.0, but this has been rolled back in this release to give users of both TensorFlow and JAX more time to migrate to a newer TensorFlow release.
  - `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
  - jax now depends on jaxlib directly. This change was enabled by the CUDA plugin switch: there are no longer multiple jaxlib variants. You can install a CPU-only jax with `pip install jax`, no extras required.
  - Added an API for exporting and serializing JAX functions. This used to exist in `jax.experimental.export` (which is being deprecated), and will now live in `jax.export`. See the [documentation](https://jax.readthedocs.io/en/latest/export/index.html).
- Deprecations
  - Internal pretty-printing tools `jax.core.pp_*` are deprecated and will be removed in a future release.
  - Hashing of tracers is deprecated and will lead to a `TypeError` in a future JAX release. This was previously the case, but there was an inadvertent regression in the last several JAX releases.
  - `jax.experimental.export` is deprecated. Use `jax.export` instead. See the [migration guide](https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export).
  - Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays `x` and `y`, `x.astype(y)` will raise a warning. To silence it, use `x.astype(y.dtype)`.
  - `jax.xla_computation` is deprecated and will be removed in a future release. Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
    - `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
    - You can also use the `.out_info` property of `jax.stages.Lowered` to get the output information (like tree structure, shape and dtype).
    - For cross-backend lowering, you can replace `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.

### [`v0.4.29`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0429-June-10-2024)

[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.28...jax-v0.4.29)

- Changes
  - We anticipate that this will be the last release of JAX and jaxlib supporting a monolithic CUDA jaxlib. Future releases will use the CUDA plugin jaxlib (e.g. `pip install jax[cuda12]`).
  - JAX now requires ml_dtypes version 0.4.0 or newer.
  - Removed backwards-compatibility support for old usage of the `jax.experimental.export` API. It is no longer possible to use `from jax.experimental.export import export`; use `from jax.experimental import export` instead. The removed functionality has been deprecated since 0.4.24.
  - Added an `is_leaf` argument to `jax.tree.all` and `jax.tree_util.tree_all`.
- Deprecations
  - `jax.sharding.XLACompatibleSharding` is deprecated. Please use `jax.sharding.Sharding`.
  - `jax.experimental.Exported.in_shardings` has been renamed to `jax.experimental.Exported.in_shardings_hlo`. The same applies to `out_shardings`. The old names will be removed after 3 months.
  - Removed a number of previously-deprecated APIs:
    - from `jax.core`: `non_negative_dim`, `DimSize`, `Shape`
    - from `jax.lax`: `tie_in`
    - from `jax.nn`: `normalize`
    - from `jax.interpreters.xla`: `backend_specific_translations`, `translations`, `register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`, `XlaOp`.
  - The `tol` argument of `jax.numpy.linalg.matrix_rank` is being deprecated and will soon be removed. Use `rtol` instead.
  - The `rcond` argument of `jax.numpy.linalg.pinv` is being deprecated and will soon be removed. Use `rtol` instead.
  - The deprecated `jax.config` submodule has been removed. To configure JAX, use `import jax` and then reference the config object via `jax.config`.
  - `jax.random` APIs no longer accept batched keys, where previously some did unintentionally. Going forward, we recommend explicit use of `jax.vmap` in such cases.
  - In `jax.scipy.special.beta`, the `x` and `y` parameters have been renamed to `a` and `b` for consistency with other `beta` APIs.
- New Functionality
  - Added `jax.experimental.Exported.in_shardings_jax` to construct shardings that can be used with the JAX APIs from the HloShardings that are stored in the `Exported` objects.

### [`v0.4.28`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0428-May-9-2024)

[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.27...jax-v0.4.28)

- Bug fixes
  - Reverted a change to `make_jaxpr` that was breaking Equinox ([#21116](https://togithub.com/google/jax/issues/21116)).
- Deprecations & removals
  - The `kind` argument to `jax.numpy.sort` and `jax.numpy.argsort` is now removed. Use `stable=True` or `stable=False` instead.
  - Removed `get_compute_capability` from the `jax.experimental.pallas.gpu` module. Use the `compute_capability` attribute of a GPU device, returned by `jax.devices` or `jax.local_devices`, instead.
  - The `newshape` argument to `jax.numpy.reshape` is being deprecated and will soon be removed. Use `shape` instead.
- Changes
  - The minimum jaxlib version of this release is 0.4.27.
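The migrations called for in the v0.4.28 to v0.4.30 notes above can be sketched as follows. This assumes `jax==0.4.30` (the version this PR pins) is installed; `f`, `x`, `y` and the sample arrays are hypothetical placeholders. It only uses the replacements the changelog itself names: `jax.jit(...).lower(...).compiler_ir('hlo')` in place of `jax.xla_computation`, `astype(y.dtype)` in place of `astype(y)`, `shape=` for `jnp.reshape`, `stable=` for sorting, and `jax.vmap` over a batch of PRNG keys.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical example function.
    return jnp.sin(x) * y


x = jnp.arange(6, dtype=jnp.float32)
y = jnp.float32(2.0)

# v0.4.30: instead of jax.xla_computation(f)(x, y), go through the AOT APIs.
lowered = jax.jit(f).lower(x, y)
hlo = lowered.compiler_ir("hlo")    # HLO form, as named in the changelog
stablehlo_text = lowered.as_text()  # StableHLO text of the same lowering

# v0.4.30: pass a dtype rather than an array to astype.
ints = jnp.arange(3, dtype=jnp.int32)
as_float = ints.astype(x.dtype)     # not ints.astype(x)

# v0.4.28: jnp.reshape takes `shape` instead of the deprecated `newshape`.
grid = jnp.reshape(x, shape=(2, 3))

# v0.4.28: `kind` is gone from jnp.sort/jnp.argsort; request stability explicitly.
order = jnp.argsort(jnp.array([3, 1, 2]), stable=True)

# v0.4.29: jax.random no longer accepts batched keys; vmap over them instead.
keys = jax.random.split(jax.random.key(0), 4)
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)

print(grid.shape, order.tolist(), samples.shape)  # (2, 3) [1, 2, 0] (4, 3)
```

Going through `jax.jit(...).lower(...)` keeps one entry point for both inspection and compilation: the same `lowered` object can later be compiled with `.compile()`.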
### [`v0.4.27`](https://togithub.com/google/jax/blob/HEAD/CHANGELOG.md#jax-0427-May-7-2024)

[Compare Source](https://togithub.com/google/jax/compare/jax-v0.4.26...jax-v0.4.27)

- New Functionality
  - Added `jax.numpy.unstack` and `jax.numpy.cumulative_sum`, following their addition in the array API 2023 standard, soon to be adopted by NumPy.
  - Added a new config option `jax_cpu_collectives_implementation` to select the implementation of cross-process collective operations used by the CPU backend. Choices available are `'none'` (default), `'gloo'` and `'mpi'` (requires jaxlib 0.4.26). If set to `'none'`, cross-process collective operations are disabled.
- Changes
  - `jax.pure_callback`, `jax.experimental.io_callback` and `jax.debug.callback` now use `jax.Array` instead of `np.ndarray`. You can recover the old behavior by transforming the arguments via `jax.tree.map(np.asarray, args)` before passing them to the callback.
  - `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
  - `core.Token` is now a non-trivial class which wraps a `jax.Array`. It can be created and threaded in and out of computations to build up dependency. The singleton object `core.token` has been removed; users should now create and use fresh `core.Token` objects instead.
  - On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new default causes issues, please file a bug. Otherwise, we intend to remove this flag in a future release.
- Deprecations & Removals
  - Pallas now exclusively uses XLA for compiling kernels on GPU. The old lowering pass via Triton Python APIs has been removed and the `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
  - `jax.numpy.clip` has a new argument signature: `a`, `a_min`, and `a_max` are deprecated in favor of `x` (positional only), `min`, and `max` (#20550).
  - The `device()` method of JAX arrays has been removed, after being deprecated since JAX v0.4.21. Use `arr.devices()` instead.
  - The `initial` argument to `jax.nn.softmax` and `jax.nn.log_softmax` is deprecated; empty inputs to softmax are now supported without setting this.
  - In `jax.jit`, passing invalid `static_argnums` or `static_argnames` now leads to an error rather than a warning.
  - The minimum jaxlib version is now 0.4.23.
  - The `jax.numpy.hypot` function now issues a deprecation warning when passing complex-valued inputs to it. This will raise an error when the deprecation is completed.
  - Scalar arguments to `jax.numpy.nonzero`, `jax.numpy.where`, and related functions now raise an error, following a similar change in NumPy.
  - The config option `jax_cpu_enable_gloo_collectives` is deprecated. Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead.
  - The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have been removed after being deprecated in JAX v0.4.22. Instead use `jax.Array.addressable_shards` and `jax.Array.addressable_data`.
  - The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now positional-only, following deprecation of the keywords in JAX v0.4.21.
  - Non-array arguments to functions in `jax.lax.linalg` now must be specified by keyword. Previously, this raised a DeprecationWarning.
  - Array-like arguments are now required in several `jax.numpy` APIs, including `jax.numpy.apply_along_axis`, `jax.numpy.apply_over_axes`, `jax.numpy.inner`, `jax.numpy.outer`, `jax.numpy.cross`, `jax.numpy.kron`, and `jax.numpy.lexsort`.
- Bug fixes
  - `jax.numpy.astype` will now always return a copy when `copy=True`. Previously, no copy would be made when the output array would have the same dtype as the input array. This may result in some increased memory usage. The default value is set to `copy=False` to preserve backwards compatibility.
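A minimal sketch of the v0.4.27 changes listed above, assuming `jax==0.4.30` is installed and running on a single default device. `host_sin` and `via_callback` are hypothetical names. The callback follows the changelog's suggestion of `jax.tree.map(np.asarray, args)` to get NumPy arrays back, since callbacks now receive `jax.Array`.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.array([-2.0, 0.5, 3.0], dtype=jnp.float32)

# jnp.clip: `a`, `a_min`, `a_max` are deprecated in favour of `x`, `min`, `max`.
clipped = jnp.clip(x, min=0.0, max=1.0)

# arr.device() was removed; arr.devices() returns the set of devices.
devices = x.devices()


# Callbacks now receive jax.Array; convert to NumPy to keep the old behaviour.
def host_sin(*args):
    (arr,) = jax.tree.map(np.asarray, args)
    return np.sin(arr).astype(arr.dtype)


@jax.jit
def via_callback(v):
    # The callback's result shape/dtype must be declared up front.
    out_type = jax.ShapeDtypeStruct(v.shape, v.dtype)
    return jax.pure_callback(host_sin, out_type, v)


y = via_callback(x)

# complex -> bool now follows NumPy: False only for exactly 0 + 0j.
flags = jnp.array([0 + 0j, 1j, 2 + 0j]).astype(bool)

# New array-API functions.
running = jnp.cumulative_sum(jnp.array([1, 2, 3]))
rows = jnp.unstack(jnp.arange(6).reshape(2, 3))

print(clipped.tolist(), flags.tolist(), running.tolist())
# [0.0, 0.5, 1.0] [False, True, True] [1, 3, 6]
```

Note that `1j` maps to `True` even though its real part is zero, which is the NumPy semantics the changelog describes.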

Configuration

📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).

🚦 Automerge: Disabled by config. Please merge this manually once you are satisfied.

Rebasing: Whenever the PR becomes conflicted, or when you tick the rebase/retry checkbox.

🔕 Ignore: Close this PR and you won't be reminded about this update again.



This PR has been generated by Mend Renovate.