asmith26 / jax_toolkit

A collection of jax functions to help with common machine/deep learning related functionality.
https://asmith26.github.io/jax_toolkit/
Apache License 2.0

Bump jax from 0.2.8 to 0.3.10 #297

Closed dependabot[bot] closed 2 years ago

dependabot[bot] commented 2 years ago

Bumps jax from 0.2.8 to 0.3.10.

Release notes

Sourced from jax's releases.

Jax release v0.3.10

No release notes provided.

Jaxlib release v0.3.10

No release notes provided.

Jax release 0.3.9

  • Changes
    • Added support for fully asynchronous checkpointing for GlobalDeviceArray.

JAX release v0.3.8

  • GitHub commits.
  • Changes
    • jax.numpy.linalg.svd on TPUs uses a qdwh-svd solver.
    • jax.numpy.linalg.cond on TPUs now accepts complex input.
    • jax.numpy.linalg.pinv on TPUs now accepts complex input.
    • jax.numpy.linalg.matrix_rank on TPUs now accepts complex input.
    • jax.scipy.cluster.vq.vq has been added.
    • jax.experimental.maps.mesh has been deleted. Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
    • jax.scipy.linalg.qr now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr ([#10452](https://github.com/google/jax/issues/10452)).
    • jax.numpy.take_along_axis now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
    • jax.numpy.take now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
    • Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
    • jax.numpy.take_along_axis now raises a TypeError if its indices are not of an integer type, matching the behavior of numpy.take_along_axis. Previously non-integer indices were silently cast to integers.
    • jax.numpy.ravel_multi_index now raises a TypeError if its dims argument is not of an integer type, matching the behavior of numpy.ravel_multi_index. Previously non-integer dims was silently cast to integers.
    • jax.numpy.split now raises a TypeError if its axis argument is not of an integer type, matching the behavior of numpy.split. Previously non-integer axis was silently cast to integers.
    • jax.numpy.indices now raises a TypeError if its dimensions are not of an integer type, matching the behavior of numpy.indices. Previously non-integer dimensions were silently cast to integers.
    • jax.numpy.diag now raises a TypeError if its k argument is not of an integer type, matching the behavior of numpy.diag. Previously non-integer k was silently cast to integers.
    • Added jax.random.orthogonal.
  • Deprecations
    • Many functions and objects available in jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, and _default_tolerance ([#10389](https://github.com/google/jax/issues/10389)). These, along with previously-deprecated JaxTestCase, JaxTestLoader, and BufferDonationTestCase, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found in e.g. unittest, absl.testing, numpy.testing, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as jax.devices. Many of the deprecated utilities will still exist in jax._src.test_util, but these are not public APIs and as such may be changed or removed without notice in future releases.
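
As a rough illustration of the deprecation note above, a test that previously leaned on jax.test_util helpers could be written with only unittest, numpy.testing, and the public jax.devices API. This is a minimal sketch, not code from this repository: the class name `TanhTest`, the test names, and the choice of jnp.tanh as the function under test are all hypothetical.

```python
import unittest

import jax
import jax.numpy as jnp
import numpy as np


class TanhTest(unittest.TestCase):  # hypothetical test case for illustration
    def test_matches_numpy(self):
        x = np.linspace(-3.0, 3.0, 7, dtype=np.float32)
        # numpy.testing stands in for the deprecated check_close helper.
        np.testing.assert_allclose(
            np.asarray(jnp.tanh(x)), np.tanh(x), rtol=1e-5, atol=1e-6
        )

    def test_accelerator_only(self):
        # jax.devices() stands in for skip_on_devices / device_under_test.
        if jax.devices()[0].platform == "cpu":
            self.skipTest("requires a GPU or TPU backend")
        y = jnp.ones((2, 2)) @ jnp.ones((2, 2))
        np.testing.assert_allclose(np.asarray(y), np.full((2, 2), 2.0))
```

The suite can be run with `python -m unittest`; on a CPU-only machine the second test is reported as skipped rather than failed.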

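The out-of-bounds indexing changes in the v0.3.8 notes above are the ones most likely to alter results silently after this bump. The sketch below passes mode explicitly in every call, so it does not rely on which default a given JAX version applies; the array values and indices are made up for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 5])  # 5 is out of bounds for a length-3 array

# mode="fill" (the jnp.take default as of 0.3.8, per the notes above)
# returns an invalid value (NaN for floats) at the out-of-bounds position.
filled = jnp.take(x, idx, mode="fill")

# mode="clip" restores the earlier behavior: indices are clamped into range.
clipped = jnp.take(x, idx, mode="clip")

# take_along_axis accepts the same optional mode parameter.
along_clipped = jnp.take_along_axis(x, idx, axis=0, mode="clip")


# Scatter with "drop" semantics: the out-of-bounds update is ignored, and
# its cotangent is zero when differentiating with respect to the updates.
def scatter_sum(updates):
    return x.at[idx].set(updates, mode="drop").sum()


grads = jax.grad(scatter_sum)(jnp.array([1.0, 2.0]))
```

Code that relied on the old clamping, for example by indexing one past the end on purpose, would need mode="clip" added at each call site.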
JAX release v0.3.7

  • Fixed a performance problem if the indices passed to jax.numpy.take_along_axis were broadcasted (#10281).
  • jax.scipy.special.expit and jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
  • The DeviceArray.tile() method is deprecated, because numpy arrays do not have a tile() method. As a replacement for this, use jax.numpy.tile (#10266).
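
The last two v0.3.7 items can be sketched as follows. This is an illustrative snippet with made-up inputs: it shows integer arguments to expit being promoted to floating point, and jax.numpy.tile used as the replacement for the deprecated tile() method.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

ints = jnp.array([0, 1, 2])

# Integer inputs are promoted to floating point before the sigmoid is applied.
probs = expit(ints)

# logit is the inverse of expit, so this recovers the inputs as floats.
roundtrip = logit(probs)

# Replacement for the deprecated DeviceArray.tile() method.
tiled = jnp.tile(jnp.array([1, 2]), 3)
```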

Jaxlib v0.3.7

  • Linux wheels are now built conforming to the manylinux2014 standard, instead of manylinux2010.

JAX release v0.3.6

  • Changes:
    • Upgraded libtpu wheel to the fixed version. Fixes #10218.

JAX release v0.3.5

Changes

  • added jax.random.loggamma & improved behavior of jax.random.beta
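
A minimal sketch of calling the new sampler, assuming the signature jax.random.loggamma(key, a, shape); the seed, shape parameter, and sample count are arbitrary. It draws samples of the logarithm of a gamma-distributed variable directly, rather than sampling from jax.random.gamma and taking the log afterwards.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Log-space samples for a small shape parameter a.
log_samples = jax.random.loggamma(key, 0.05, shape=(5,))
```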

... (truncated)

Changelog

Sourced from jax's changelog.

jax 0.3.10 (May 3, 2022)

jaxlib 0.3.10 (May 3, 2022)

  • GitHub commits.
  • Changes
    • TF commit fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs.

jax 0.3.9 (May 2, 2022)

  • GitHub commits.
  • Changes
    • Added support for fully asynchronous checkpointing for GlobalDeviceArray.

jax 0.3.8 (April 29, 2022)

  • GitHub commits.
  • Changes
    • jax.numpy.linalg.svd on TPUs uses a qdwh-svd solver.
    • jax.numpy.linalg.cond on TPUs now accepts complex input.
    • jax.numpy.linalg.pinv on TPUs now accepts complex input.
    • jax.numpy.linalg.matrix_rank on TPUs now accepts complex input.
    • jax.scipy.cluster.vq.vq has been added.
    • jax.experimental.maps.mesh has been deleted. Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
    • jax.scipy.linalg.qr now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr ([#10452](https://github.com/google/jax/issues/10452)).
    • jax.numpy.take_along_axis now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
    • jax.numpy.take now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
    • Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
    • jax.numpy.take_along_axis now raises a TypeError if its indices are not of an integer type, matching the behavior of numpy.take_along_axis. Previously non-integer indices were silently cast to integers.
    • jax.numpy.ravel_multi_index now raises a TypeError if its dims argument is not of an integer type, matching the behavior of numpy.ravel_multi_index. Previously non-integer dims was silently cast to integers.
    • jax.numpy.split now raises a TypeError if its axis argument is not of an integer type, matching the behavior of

... (truncated)

Commits
  • 38ce6d0 Update TF commit for release
  • 7297115 Merge pull request #10546 from jakevdp:unravel-indices
  • 3c2d2b2 jnp.unravel_index: improve test coverage
  • 58320e2 jnp.unravel_index: avoid overflow for large dimension sizes
  • a8c6742 Restrict Bazel visibility of //jaxlib:gpu_support
  • 03c8020 Merge pull request #10532 from mattjj:remove-units-final
  • 8c6f916 Reapply google/jax#10482 now the TF PR is ready for s...
  • 9cd55a2 [remove-units] remove units
  • 97b7fd7 Merge pull request #10564 from mattjj:gate-xla-version
  • 0c5864a add xla_client._version checks for mhlo.ConstOp signature
  • Additional commits viewable in compare view



Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:

  • `@dependabot rebase` will rebase this PR
  • `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
  • `@dependabot merge` will merge this PR after your CI passes on it
  • `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
  • `@dependabot cancel merge` will cancel a previously requested merge and block automerging
  • `@dependabot reopen` will reopen this PR if it is closed
  • `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
  • `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
  • `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
  • `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)

dependabot[bot] commented 2 years ago

Superseded by #299.