asmith26 / jax_toolkit

A collection of jax functions to help with common machine/deep learning related functionality.
https://asmith26.github.io/jax_toolkit/
Apache License 2.0

Bump jax from 0.2.8 to 0.3.8 #293

Closed dependabot[bot] closed 2 years ago

dependabot[bot] commented 2 years ago

Bumps jax from 0.2.8 to 0.3.8.

Release notes

Sourced from jax's releases.

JAX release v0.3.8

  • GitHub commits.
  • Changes
    • jax.numpy.linalg.svd on TPUs uses a qdwh-svd solver.
    • jax.numpy.linalg.cond on TPUs now accepts complex input.
    • jax.numpy.linalg.pinv on TPUs now accepts complex input.
    • jax.numpy.linalg.matrix_rank on TPUs now accepts complex input.
    • jax.scipy.cluster.vq.vq has been added.
    • jax.experimental.maps.mesh has been deleted. Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
    • jax.scipy.linalg.qr now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr ([#10452](https://github.com/google/jax/issues/10452))
    • jax.numpy.take_along_axis now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
    • jax.numpy.take now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
    • Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
    • jax.numpy.take_along_axis now raises a TypeError if its indices are not of an integer type, matching the behavior of numpy.take_along_axis. Previously non-integer indices were silently cast to integers.
    • jax.numpy.ravel_multi_index now raises a TypeError if its dims argument is not of an integer type, matching the behavior of numpy.ravel_multi_index. Previously non-integer dims was silently cast to integers.
    • jax.numpy.split now raises a TypeError if its axis argument is not of an integer type, matching the behavior of numpy.split. Previously non-integer axis was silently cast to integers.
    • jax.numpy.indices now raises a TypeError if its dimensions are not of an integer type, matching the behavior of numpy.indices. Previously non-integer dimensions were silently cast to integers.
    • jax.numpy.diag now raises a TypeError if its k argument is not of an integer type, matching the behavior of numpy.diag. Previously non-integer k was silently cast to integers.
    • Added jax.random.orthogonal.
  • Deprecations
    • Many functions and objects available in jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, and _default_tolerance ([#10389](https://github.com/google/jax/issues/10389)). These, along with previously-deprecated JaxTestCase, JaxTestLoader, and BufferDonationTestCase, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found in e.g. unittest, absl.testing, numpy.testing, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as jax.devices. Many of the deprecated utilities will still exist in jax._src.test_util, but these are not public APIs and as such may be changed or removed without notice in future releases.
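
The out-of-bounds indexing changes in the v0.3.8 notes above are the ones most likely to change results silently after this bump. The following is a minimal sketch of the described behavior, assuming a JAX version at or after 0.3.8; the array values and variable names are illustrative only and are not taken from jax_toolkit.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 5])  # index 5 is out of bounds for a length-3 array

# mode="fill" returns an invalid value (NaN for floats) for the bad index.
filled = jnp.take(x, idx, mode="fill")

# mode="clip" clamps the bad index to the last valid position,
# which is the behavior the notes describe for earlier versions.
clipped = jnp.take(x, idx, mode="clip")

# take_along_axis accepts the same mode argument.
clipped_along = jnp.take_along_axis(x, idx, axis=0, mode="clip")

# An out-of-bounds scatter update is dropped, leaving the values unchanged.
scattered = x.at[5].set(99.0)
```

Code that relied on the old clamping can pass mode="clip" explicitly rather than depending on the default.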

JAX release v0.3.7

  • Fixed a performance problem if the indices passed to jax.numpy.take_along_axis were broadcasted (#10281).
  • jax.scipy.special.expit and jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
  • The DeviceArray.tile() method is deprecated, because numpy arrays do not have a tile() method. As a replacement for this, use jax.numpy.tile (#10266).
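
The two API notes above can be sketched as follows, assuming a recent JAX install; the input values are illustrative only. The sketch uses the function form jax.numpy.tile in place of the deprecated method, and passes an integer array to expit to show the promotion to floating point.

```python
import jax.numpy as jnp
from jax.scipy.special import expit

x = jnp.arange(3)

# Replacement for the deprecated x.tile(2) method call.
tiled = jnp.tile(x, 2)

# Integer inputs are promoted to floating point before the sigmoid is applied.
probs = expit(jnp.array([0, 1]))
```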

Jaxlib v0.3.7

  • Linux wheels are now built conforming to the manylinux2014 standard, instead of manylinux2010.

JAX release v0.3.6

  • Changes:
    • Upgraded libtpu wheel to the fixed version. Fixes #10218.

JAX release v0.3.5

Changes

  • added jax.random.loggamma & improved behavior of jax.random.beta and jax.random.dirichlet for small parameter values (#9906).
  • the private lax_numpy submodule is no longer exposed in the jax.numpy namespace (#10029).
  • added array creation routines jax.numpy.frombuffer, jax.numpy.fromfunction, and jax.numpy.fromstring (#10049).
  • DeviceArray.copy() now returns a DeviceArray rather than a np.ndarray (#10069)
  • added jax.scipy.linalg.rsf2csf
  • Deprecations:
    • jax.nn.normalize is being deprecated. Use jax.nn.standardize instead (#9899).
    • jax.tree_util.tree_multimap is deprecated. Use jax.tree_util.tree_map instead (#5746).
    • jax.experimental.sharded_jit is deprecated. Use pjit instead.

... (truncated)
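
For the v0.3.5 deprecations listed above, a migration might look like the sketch below. The params and grads dictionaries and the 0.1 step size are hypothetical stand-ins, not code from jax_toolkit; only jax.tree_util.tree_map and jax.nn.standardize are the replacements named in the release notes.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytrees standing in for model parameters and gradients.
params = {"w": jnp.ones((2,)), "b": jnp.zeros((1,))}
grads = {"w": jnp.full((2,), 0.5), "b": jnp.ones((1,))}

# tree_map accepts several pytrees, so it covers what tree_multimap did.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# jax.nn.standardize replaces the deprecated jax.nn.normalize.
standardized = jax.nn.standardize(jnp.array([1.0, 2.0, 3.0, 4.0]))
```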

Changelog

Sourced from jax's changelog.

jaxlib 0.3.8 (Unreleased)

jax 0.3.8 (April 29 2022)

  • GitHub commits.
  • Changes
    • jax.numpy.linalg.svd on TPUs uses a qdwh-svd solver.
    • jax.numpy.linalg.cond on TPUs now accepts complex input.
    • jax.numpy.linalg.pinv on TPUs now accepts complex input.
    • jax.numpy.linalg.matrix_rank on TPUs now accepts complex input.
    • jax.scipy.cluster.vq.vq has been added.
    • jax.experimental.maps.mesh has been deleted. Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
    • jax.scipy.linalg.qr now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr ([#10452](https://github.com/google/jax/issues/10452))
    • jax.numpy.take_along_axis now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
    • jax.numpy.take now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
    • Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
    • jax.numpy.take_along_axis now raises a TypeError if its indices are not of an integer type, matching the behavior of numpy.take_along_axis. Previously non-integer indices were silently cast to integers.
    • jax.numpy.ravel_multi_index now raises a TypeError if its dims argument is not of an integer type, matching the behavior of numpy.ravel_multi_index. Previously non-integer dims was silently cast to integers.
    • jax.numpy.split now raises a TypeError if its axis argument is not of an integer type, matching the behavior of numpy.split. Previously non-integer axis was silently cast to integers.
    • jax.numpy.indices now raises a TypeError if its dimensions are not of an integer type, matching the behavior of numpy.indices. Previously non-integer dimensions were silently cast to integers.
    • jax.numpy.diag now raises a TypeError if its k argument is not of an integer type, matching the behavior of numpy.diag. Previously non-integer k was silently cast to integers.
    • Added jax.random.orthogonal.
  • Deprecations

... (truncated)
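
The jax.scipy.linalg.qr entry in the changelog above changes the return type for mode='r', which can break callers that used the result directly as an array. A minimal sketch of the adjusted call, assuming JAX 0.3.8 or later and an illustrative 2x2 input:

```python
import jax.numpy as jnp
from jax.scipy.linalg import qr

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# With mode="r" the result is a length-1 tuple, as in scipy.linalg.qr.
result = qr(a, mode="r")

# Unpack the single element to get the upper-triangular factor R.
(r,) = result
```

Callers that previously wrote r = qr(a, mode="r") would need the unpacking step after this bump.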

Commits
  • 290c90d Trivial change for backward compatibility stub.
  • a682a37 Merge pull request #10509 from mattjj:avoid-units-linear-transpose
  • c01fed9 Merge pull request #10508 from mattjj:avoid-units-lax-reduce
  • e3a1b7f Merge pull request #10507 from mattjj:avoid-units-new-remat
  • 874374c Raise a better error when assert fails in mesh_sharding_specs
  • 0208490 [linalg] Add tpu svd lowering rule.
  • b90df4b Merge pull request #10506 from mattjj:remove-partial-eval-to-jaxpr-dynamic
  • 65bff3c [remove-units] avoid unit-generating function in jax.linear_transpose
  • f970f58 [remove-units] avoid unit-generating function in lax.reduce
  • fb4731d Merge pull request #10502 from jakevdp:fix-myst
  • Additional commits viewable in compare view


Dependabot compatibility score

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:

  • `@dependabot rebase` will rebase this PR
  • `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
  • `@dependabot merge` will merge this PR after your CI passes on it
  • `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
  • `@dependabot cancel merge` will cancel a previously requested merge and block automerging
  • `@dependabot reopen` will reopen this PR if it is closed
  • `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
  • `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
  • `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
  • `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)

dependabot[bot] commented 2 years ago

Superseded by #295.