`jax.tree_util` now contains a set of APIs that allow users to define keys for their
custom pytree nodes. This includes:
- `tree_flatten_with_path`, which flattens a tree and returns not only each leaf but
also its key path.
- `tree_map_with_path`, which maps a function that takes the key path as an argument.
- `register_pytree_with_keys`, which registers what the key paths and leaves should look
like in a custom pytree node.
- `keystr`, which pretty-prints a key path.
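The key-path APIs listed above can be combined roughly as follows. This is a minimal sketch that assumes a JAX release containing these APIs (0.4.6 or later); the `Params` class and its flatten/unflatten helpers are hypothetical and exist only for illustration.

```python
import jax.numpy as jnp
from jax.tree_util import (
    GetAttrKey,
    keystr,
    register_pytree_with_keys,
    tree_flatten_with_path,
    tree_map_with_path,
)


class Params:
    """Hypothetical container with two array fields, used only for illustration."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias


def params_flatten_with_keys(p):
    # Return (key, child) pairs plus static auxiliary data (none here).
    children = ((GetAttrKey("weight"), p.weight), (GetAttrKey("bias"), p.bias))
    return children, None


def params_unflatten(aux_data, children):
    # Rebuild the node from its children, in the same order they were flattened.
    return Params(*children)


register_pytree_with_keys(Params, params_flatten_with_keys, params_unflatten)

tree = {"layer": Params(jnp.ones((2, 2)), jnp.zeros(2))}

# Each entry is (key_path, leaf); keystr renders a key path as a readable string.
paths_and_leaves, treedef = tree_flatten_with_path(tree)
path_strings = [keystr(path) for path, _ in paths_and_leaves]

# The mapped function receives the key path before the leaf.
scaled = tree_map_with_path(
    lambda path, x: x * 2 if keystr(path).endswith(".weight") else x, tree
)
```

With the default formatting, `keystr` renders the two paths here as `['layer'].weight` and `['layer'].bias`.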
`jax2tf.call_tf` has a new parameter `output_shape_dtype` (default `None`)
that can be used to declare the output shape and type of the result. This enables
`jax2tf.call_tf` to work in the presence of shape polymorphism ([#14734](https://github.com/google/jax/issues/14734)).
Deprecations
The old key-path APIs in `jax.tree_util` are deprecated and will be removed 3 months
from Mar 10, 2023:
- `register_keypaths`: use `jax.tree_util.register_pytree_with_keys` instead.
- `AttributeKeyPathEntry`: use `GetAttrKey` instead.
- `GetitemKeyPathEntry`: use `SequenceKey` or `DictKey` instead.
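Code that previously inspected `GetitemKeyPathEntry` or `AttributeKeyPathEntry` objects can dispatch on the replacement key types instead. A minimal sketch, assuming the built-in container behavior of recent JAX releases (dict entries become `DictKey`, list entries `SequenceKey`, and `NamedTuple` fields `GetAttrKey`); the `Point` class and `describe` helper are hypothetical.

```python
from typing import NamedTuple

from jax.tree_util import DictKey, GetAttrKey, SequenceKey, tree_flatten_with_path


class Point(NamedTuple):
    # Hypothetical NamedTuple; its fields are keyed by attribute name.
    x: float
    y: float


tree = {"points": [Point(1.0, 2.0)], "scale": 3.0}

leaves_with_paths, _ = tree_flatten_with_path(tree)


def describe(path):
    # Translate each key entry into a plain (kind, value) pair based on its type.
    parts = []
    for entry in path:
        if isinstance(entry, DictKey):
            parts.append(("dict", entry.key))
        elif isinstance(entry, SequenceKey):
            parts.append(("seq", entry.idx))
        elif isinstance(entry, GetAttrKey):
            parts.append(("attr", entry.name))
    return parts


described = [describe(path) for path, _ in leaves_with_paths]
```

Dict keys are visited in sorted order, so the `points` leaves come before `scale`.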
jaxlib 0.4.6 (Mar 9, 2023)
jax 0.4.5 (Mar 2, 2023)
Deprecations
`jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`.
`jax.sharding.OpShardingSharding` will be removed 3 months from Feb 17, 2023.
The following `jax.Array` methods are deprecated and will be removed 3 months from
Feb 23, 2023:
- `jax.Array.broadcast`: use `jax.lax.broadcast` instead.
- `jax.Array.broadcast_in_dim`: use `jax.lax.broadcast_in_dim` instead.
- `jax.Array.split`: use `jax.numpy.split` instead.
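Migrating off the deprecated `jax.Array` methods means calling the listed functions directly. A minimal sketch using the standard `jax.lax` and `jax.numpy` signatures; the array and shapes are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0)

# Replacement for x.broadcast(sizes): prepends new leading dimensions.
b = lax.broadcast(x, (2,))  # shape (2, 6)

# Replacement for x.broadcast_in_dim(...): operand dim 0 maps to output dim 1.
bid = lax.broadcast_in_dim(x, (3, 6), (1,))  # shape (3, 6)

# Replacement for x.split(...): three equal pieces along axis 0.
parts = jnp.split(x, 3)  # three arrays of shape (2,)
```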
jax 0.4.4 (Feb 16, 2023)
Changes
The implementation of jit and pjit has been merged. Merging jit and pjit
changes the internals of JAX without affecting the public API of JAX.
Before, jit was a final style primitive. Final style means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the jit-pjit implementation merge, jit
becomes an initial-style primitive, which means that we trace to jaxpr
as early as possible. For more information see
this section in autodidax.
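One way to observe the initial-style behavior is to trace a jitted function with `jax.make_jaxpr`: the jit call shows up as a single equation whose parameters already carry the traced inner jaxpr. This is a minimal sketch; it assumes the equation stores that inner program under the `"jaxpr"` parameter, and the primitive is printed as `pjit` in releases around the merge but may be named `jit` in newer ones.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * 2.0


x = jnp.ones(3)

# The outer jaxpr contains one call equation for the jitted function.
outer = jax.make_jaxpr(jax.jit(f))(x)
(call_eqn,) = outer.jaxpr.eqns

# Initial style: the body of f was traced eagerly and is stored in the params.
inner = call_eqn.params["jaxpr"]
inner_prims = [eqn.primitive.name for eqn in inner.jaxpr.eqns]
print(call_eqn.primitive.name, inner_prims)
```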
Moving to initial style should simplify JAX's internals and make
development of features such as dynamic shapes easier.
You can disable it only via the environment variable i.e.
Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Updates the requirements on jax[cpu] to permit the latest version.
Changelog
Sourced from jax[cpu]'s changelog.
... (truncated)
Commits
- f1f4840 Merge pull request #14881 from skye:version
- 1aa08fd Update WORKSPACE and setup.py for jax/jaxlib 0.4.6 release
- 7fd1e2f Split _src/traceback_util.py into its own Bazel target.
- 01b00c4 Increase sharding of shard_map test on CPU.
- 9912a8e Split _src/pretty_printer.py into its own Bazel target.
- d216d98 Remove the disassemble into single devices arrays in ExecuteReplicated.__call...
- 08789fd Exclude "util.py" and "config.py" from the main JAX bazel target.
- 560fe73 [jax2tf] Disable some failing tests
- 0e05a79 Split some submodules out of //jax under Bazel.
- 5c91453 [jax2tf] Add check that native lowering should not include custom calls not g...