google / budoux

https://google.github.io/budoux/
Apache License 2.0
1.44k stars 32 forks

Bump jax from 0.4.31 to 0.4.32 #719

Closed dependabot[bot] closed 1 month ago

dependabot[bot] commented 1 month ago

Bumps jax from 0.4.31 to 0.4.32.

Release notes

Sourced from jax's releases.

JAX release v0.4.32

  • New Functionality

    • Added jax.extend.ffi.ffi_call and jax.extend.ffi.ffi_lowering to support the use of the new ffi-tutorial to interface with custom C++ and CUDA code from JAX.
  • Changes

    • jax_enable_memories flag is set to True by default.
    • jax.numpy now supports v2023.12 of the Python Array API Standard. See python-array-api for more information.
    • Computations on the CPU backend may now be dispatched asynchronously in more cases. Previously non-parallel computations were always dispatched synchronously. You can recover the old behavior by setting jax.config.update('jax_cpu_enable_async_dispatch', False).
    • Added new jax.process_indices function to replace the jax.host_ids() function that was deprecated in JAX v0.2.13.
    • To align with the behavior of numpy.fabs, jax.numpy.fabs has been modified to no longer support complex dtypes.
    • jax.tree_util.register_dataclass now checks that data_fields and meta_fields include all dataclass fields with init=True and only them, if nodetype is a dataclass.
    • Several jax.numpy functions now have full jax.numpy.ufunc interfaces, including jax.numpy.add, jax.numpy.multiply, jax.numpy.bitwise_and, jax.numpy.bitwise_or, jax.numpy.bitwise_xor, jax.numpy.logical_and, jax.numpy.logical_or, and jax.numpy.logical_xor. In future releases we plan to expand these to other ufuncs.
    • Added jax.lax.optimization_barrier, which allows users to prevent compiler optimizations such as common-subexpression elimination and to control scheduling.
  • Breaking changes

    • The MHLO MLIR dialect (jax.extend.mlir.mhlo) has been removed. Use the stablehlo dialect instead.
  • Deprecations

    • Complex inputs to jax.numpy.clip and jax.numpy.hypot are no longer allowed, after being deprecated since JAX v0.4.27.
    • Deprecated the following APIs:
      • jax.lib.xla_bridge.xla_client: use jax.lib.xla_client directly.
      • jax.lib.xla_bridge.get_backend: use jax.extend.backend.get_backend.
      • jax.lib.xla_bridge.default_backend: use jax.extend.backend.default_backend.
    • The jax.experimental.array_api module is deprecated, and importing it is no longer required to use the Array API. jax.numpy supports the array API directly; see python-array-api for more information.
    • The internal utilities jax.core.check_eqn, jax.core.check_type, and jax.core.check_valid_jaxtype are now deprecated, and will be removed in the future.
    • jax.numpy.round_ has been deprecated, following removal of the corresponding

... (truncated)

Changelog

Sourced from jax's changelog.

jax 0.4.32 (September 11, 2024)

  • New Functionality

    • Added jax.extend.ffi.ffi_call and jax.extend.ffi.ffi_lowering to support the use of the new ffi-tutorial to interface with custom C++ and CUDA code from JAX.
  • Changes

    • jax_pmap_no_rank_reduction flag is set to True by default.
      • array[0] on a pmap result now introduces a reshape (use array[0:1] instead).
      • The per-shard shape (accessible via jax_array.addressable_shards or jax_array.addressable_data(0)) now has a leading (1, ...). Update code that directly accesses shards accordingly. The rank of the per-shard shape now matches that of the global shape, which is the same behavior as jit. This avoids costly reshapes when passing results from pmap into jit.
    • jax_enable_memories flag is set to True by default.
    • jax.numpy now supports v2023.12 of the Python Array API Standard. See python-array-api for more information.
    • Computations on the CPU backend may now be dispatched asynchronously in more cases. Previously non-parallel computations were always dispatched synchronously. You can recover the old behavior by setting jax.config.update('jax_cpu_enable_async_dispatch', False).
    • Added new jax.process_indices function to replace the jax.host_ids() function that was deprecated in JAX v0.2.13.
    • To align with the behavior of numpy.fabs, jax.numpy.fabs has been modified to no longer support complex dtypes.
    • jax.tree_util.register_dataclass now checks that data_fields and meta_fields include all dataclass fields with init=True and only them, if nodetype is a dataclass.
    • Several jax.numpy functions now have full jax.numpy.ufunc interfaces, including jax.numpy.add, jax.numpy.multiply, jax.numpy.bitwise_and, jax.numpy.bitwise_or, jax.numpy.bitwise_xor, jax.numpy.logical_and, jax.numpy.logical_or, and jax.numpy.logical_xor. In future releases we plan to expand these to other ufuncs.
    • Added jax.lax.optimization_barrier, which allows users to prevent compiler optimizations such as common-subexpression elimination and to control scheduling.
  • Breaking changes

    • The MHLO MLIR dialect (jax.extend.mlir.mhlo) has been removed. Use the stablehlo dialect instead.
  • Deprecations

    • Complex inputs to jax.numpy.clip and jax.numpy.hypot are no longer allowed, after being deprecated since JAX v0.4.27.
    • Deprecated the following APIs:
      • jax.lib.xla_bridge.xla_client: use jax.lib.xla_client directly.
      • jax.lib.xla_bridge.get_backend: use jax.extend.backend.get_backend.

... (truncated)
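
The register_dataclass check described in the changelog above can be illustrated with a small sketch. The `Layer` dataclass and its fields are hypothetical; the point is that every field with init=True appears in exactly one of data_fields or meta_fields. Per the changelog entry, leaving a field out of both lists would be expected to be rejected from 0.4.32 onward, though that case is not exercised here.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp

@dataclass
class Layer:  # hypothetical example type
    weights: jax.Array
    bias: jax.Array
    name: str

# Every init=True field is listed exactly once: arrays as data (pytree
# leaves), the string as static metadata.
jax.tree_util.register_dataclass(
    Layer,
    data_fields=["weights", "bias"],
    meta_fields=["name"],
)

layer = Layer(weights=jnp.ones((2, 2)), bias=jnp.zeros(2), name="dense")
leaves = jax.tree_util.tree_leaves(layer)            # only the two arrays
doubled = jax.tree_util.tree_map(lambda v: 2 * v, layer)
print(len(leaves), doubled.name)
```

Metadata fields such as `name` pass through tree_map untouched, while the data fields are transformed.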

Commits
  • 1594d2f Prepare for v0.4.32 release.
  • ed849ff Make sure to call the superclass' init() on a newly created instance in P...
  • 2bd1fde Relax test tolerance in pinv test to fix a CI failure on Windows CPU.
  • e869a9d Merge pull request #23415 from kaixih:key_value_seq_lengths
  • ea68f45 Internal change
  • 49dd6ed Disable a pallas export compatibility test that fails on TPU v6e.
  • 808003b Update users of jax.tree.map() to be more careful about how they handle Nones.
  • e3c4b20 [Pallas] Implement tiled and swizzled Memref loads for Mosaic GPU via "GPUBlo...
  • c659dc9 [Pallas] Disable win32 gpu_ops_test.
  • 14b8625 Merge pull request #23549 from pschuh:docs-update
  • Additional commits viewable in compare view

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:

  • `@dependabot rebase` will rebase this PR
  • `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
  • `@dependabot merge` will merge this PR after your CI passes on it
  • `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
  • `@dependabot cancel merge` will cancel a previously requested merge and block automerging
  • `@dependabot reopen` will reopen this PR if it is closed
  • `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
  • `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency
  • `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
  • `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
  • `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)