hand
This commit is contained in:
@@ -0,0 +1,506 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
import dataclasses
|
||||
from typing import Any, TypeVar, TYPE_CHECKING
|
||||
|
||||
from jax._src import tree_util
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
|
||||
"""Call all() over the leaves of a tree.
|
||||
|
||||
Args:
|
||||
tree: the pytree to evaluate
|
||||
is_leaf : an optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, which indicates whether the
|
||||
flattening should traverse the current object, or if it should be stopped
|
||||
immediately, with the whole subtree being treated as a leaf.
|
||||
|
||||
Returns:
|
||||
result: boolean True or False
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> jax.tree.all([True, {'a': True, 'b': (True, True)}])
|
||||
True
|
||||
>>> jax.tree.all([False, (True, False)])
|
||||
False
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.reduce`
|
||||
- :func:`jax.tree.leaves`
|
||||
"""
|
||||
return tree_util.tree_all(tree, is_leaf=is_leaf)
|
||||
|
||||
|
||||
def flatten(tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> tuple[list[tree_util.Leaf], tree_util.PyTreeDef]:
|
||||
"""Flattens a pytree.
|
||||
|
||||
The flattening order (i.e. the order of elements in the output list)
|
||||
is deterministic, corresponding to a left-to-right depth-first tree
|
||||
traversal.
|
||||
|
||||
Args:
|
||||
tree: a pytree to flatten.
|
||||
is_leaf: an optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, with true stopping the
|
||||
traversal and the whole subtree being treated as a leaf, and false
|
||||
indicating the flattening should traverse the current object.
|
||||
|
||||
Returns:
|
||||
A pair where the first element is a list of leaf values and the second
|
||||
element is a treedef representing the structure of the flattened tree.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
|
||||
>>> vals
|
||||
[1, 2, 3, 4, 5]
|
||||
>>> treedef
|
||||
PyTreeDef([*, (*, *), [*, *]])
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.leaves`
|
||||
- :func:`jax.tree.structure`
|
||||
- :func:`jax.tree.unflatten`
|
||||
"""
|
||||
return tree_util.tree_flatten(tree, is_leaf)
|
||||
|
||||
|
||||
def leaves(tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> list[tree_util.Leaf]:
|
||||
"""Gets the leaves of a pytree.
|
||||
|
||||
Args:
|
||||
tree: the pytree for which to get the leaves
|
||||
is_leaf : an optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, which indicates whether the
|
||||
flattening should traverse the current object, or if it should be stopped
|
||||
immediately, with the whole subtree being treated as a leaf.
|
||||
|
||||
Returns:
|
||||
leaves: a list of tree leaves.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> jax.tree.leaves([1, (2, 3), [4, 5]])
|
||||
[1, 2, 3, 4, 5]
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.flatten`
|
||||
- :func:`jax.tree.structure`
|
||||
- :func:`jax.tree.unflatten`
|
||||
"""
|
||||
return tree_util.tree_leaves(tree, is_leaf)
|
||||
|
||||
|
||||
def map(f: Callable[..., Any],
|
||||
tree: Any,
|
||||
*rest: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None) -> Any:
|
||||
"""Maps a multi-input function over pytree args to produce a new pytree.
|
||||
|
||||
Args:
|
||||
f: function that takes ``1 + len(rest)`` arguments, to be applied at the
|
||||
corresponding leaves of the pytrees.
|
||||
tree: a pytree to be mapped over, with each leaf providing the first
|
||||
positional argument to ``f``.
|
||||
rest: a tuple of pytrees, each of which has the same structure as ``tree``
|
||||
or has ``tree`` as a prefix.
|
||||
is_leaf: an optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, which indicates whether the
|
||||
flattening should traverse the current object, or if it should be stopped
|
||||
immediately, with the whole subtree being treated as a leaf.
|
||||
|
||||
Returns:
|
||||
A new pytree with the same structure as ``tree`` but with the value at each
|
||||
leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding
|
||||
leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in
|
||||
``rest``.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> import jax
|
||||
>>> jax.tree.map(lambda x: x + 1, {"x": 7, "y": 42})
|
||||
{'x': 8, 'y': 43}
|
||||
|
||||
If multiple inputs are passed, the structure of the tree is taken from the
|
||||
first input; subsequent inputs need only have ``tree`` as a prefix:
|
||||
|
||||
>>> jax.tree.map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
|
||||
[[5, 7, 9], [6, 1, 2]]
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.leaves`
|
||||
- :func:`jax.tree.reduce`
|
||||
"""
|
||||
return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
|
||||
|
||||
|
||||
def reduce(function: Callable[[T, Any], T],
|
||||
tree: Any,
|
||||
initializer: T | tree_util.Unspecified = tree_util.Unspecified(),
|
||||
is_leaf: Callable[[Any], bool] | None = None) -> T:
|
||||
"""Call reduce() over the leaves of a tree.
|
||||
|
||||
Args:
|
||||
function: the reduction function
|
||||
tree: the pytree to reduce over
|
||||
initializer: the optional initial value
|
||||
is_leaf : an optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, which indicates whether the
|
||||
flattening should traverse the current object, or if it should be stopped
|
||||
immediately, with the whole subtree being treated as a leaf.
|
||||
|
||||
Returns:
|
||||
result: the reduced value.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> import operator
|
||||
>>> jax.tree.reduce(operator.add, [1, (2, 3), [4, 5, 6]])
|
||||
21
|
||||
|
||||
Notes:
|
||||
**Tip**: You can exclude leaves from the reduction by first mapping them to
|
||||
``None`` using :func:`jax.tree.map`. This causes them to not be counted as
|
||||
leaves after that.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.reduce_associative`
|
||||
- :func:`jax.tree.leaves`
|
||||
- :func:`jax.tree.map`
|
||||
"""
|
||||
return tree_util.tree_reduce(function, tree, initializer, is_leaf=is_leaf)
|
||||
|
||||
|
||||
def reduce_associative(
|
||||
operation: Callable[[T, T], T],
|
||||
tree: Any,
|
||||
*,
|
||||
identity: T | tree_util.Unspecified = tree_util.Unspecified(),
|
||||
is_leaf: Callable[[Any], bool] | None = None,
|
||||
) -> T:
|
||||
"""Perform a reduction over a pytree with an associative binary operation.
|
||||
|
||||
This function exploits the fact that the operation is associative to perform
|
||||
the reduction in parallel (logarithmic depth).
|
||||
|
||||
Args:
|
||||
operation: the associative binary operation
|
||||
tree: the pytree to reduce
|
||||
identity: the identity element of the associative binary operation.
|
||||
This is used only when the tree is empty. It is optional otherwise.
|
||||
is_leaf: an optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, which indicates whether the
|
||||
flattening should traverse the current object, or if it should be stopped
|
||||
immediately, with the whole subtree being treated as a leaf.
|
||||
|
||||
Returns:
|
||||
result: the reduced value
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> import operator
|
||||
>>> jax.tree.reduce_associative(operator.add, [1, (2, 3), [4, 5, 6]])
|
||||
21
|
||||
|
||||
Notes:
|
||||
**Tip**: You can exclude leaves from the reduction by first mapping them to
|
||||
``None`` using :func:`jax.tree.map`. This causes them to not be counted as
|
||||
leaves after that.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.reduce`
|
||||
"""
|
||||
return tree_util.tree_reduce_associative(
|
||||
operation,
|
||||
tree,
|
||||
identity=identity,
|
||||
is_leaf=is_leaf,
|
||||
)
|
||||
|
||||
|
||||
def structure(tree: Any,
|
||||
is_leaf: None | (Callable[[Any], bool]) = None) -> tree_util.PyTreeDef:
|
||||
"""Gets the treedef for a pytree.
|
||||
|
||||
Args:
|
||||
tree: the pytree for which to get the leaves
|
||||
is_leaf : an optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, which indicates whether the
|
||||
flattening should traverse the current object, or if it should be stopped
|
||||
immediately, with the whole subtree being treated as a leaf.
|
||||
|
||||
Returns:
|
||||
pytreedef: a PyTreeDef representing the structure of the tree.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> jax.tree.structure([1, (2, 3), [4, 5]])
|
||||
PyTreeDef([*, (*, *), [*, *]])
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.flatten`
|
||||
- :func:`jax.tree.leaves`
|
||||
- :func:`jax.tree.unflatten`
|
||||
"""
|
||||
return tree_util.tree_structure(tree, is_leaf)
|
||||
|
||||
|
||||
def transpose(outer_treedef: tree_util.PyTreeDef,
|
||||
inner_treedef: tree_util.PyTreeDef | None,
|
||||
pytree_to_transpose: Any) -> Any:
|
||||
"""Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).
|
||||
|
||||
Args:
|
||||
outer_treedef: PyTreeDef representing the outer tree.
|
||||
inner_treedef: PyTreeDef representing the inner tree.
|
||||
If None, then it will be inferred from outer_treedef and the structure of
|
||||
pytree_to_transpose.
|
||||
pytree_to_transpose: the pytree to be transposed.
|
||||
|
||||
Returns:
|
||||
transposed_pytree: the transposed pytree.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> tree = [(1, 2, 3), (4, 5, 6)]
|
||||
>>> inner_structure = jax.tree.structure(('*', '*', '*'))
|
||||
>>> outer_structure = jax.tree.structure(['*', '*'])
|
||||
>>> jax.tree.transpose(outer_structure, inner_structure, tree)
|
||||
([1, 4], [2, 5], [3, 6])
|
||||
|
||||
Inferring the inner structure:
|
||||
|
||||
>>> jax.tree.transpose(outer_structure, None, tree)
|
||||
([1, 4], [2, 5], [3, 6])
|
||||
"""
|
||||
return tree_util.tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose)
|
||||
|
||||
|
||||
def unflatten(treedef: tree_util.PyTreeDef,
|
||||
leaves: Iterable[tree_util.Leaf]) -> Any:
|
||||
"""Reconstructs a pytree from the treedef and the leaves.
|
||||
|
||||
The inverse of :func:`tree_flatten`.
|
||||
|
||||
Args:
|
||||
treedef: the treedef to reconstruct
|
||||
leaves: the iterable of leaves to use for reconstruction. The iterable must
|
||||
match the leaves of the treedef.
|
||||
|
||||
Returns:
|
||||
The reconstructed pytree, containing the ``leaves`` placed in the structure
|
||||
described by ``treedef``.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
|
||||
>>> newvals = [100, 200, 300, 400, 500]
|
||||
>>> jax.tree.unflatten(treedef, newvals)
|
||||
[100, (200, 300), [400, 500]]
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.flatten`
|
||||
- :func:`jax.tree.leaves`
|
||||
- :func:`jax.tree.structure`
|
||||
"""
|
||||
return tree_util.tree_unflatten(treedef, leaves)
|
||||
|
||||
|
||||
def flatten_with_path(
|
||||
tree: Any, is_leaf: Callable[..., bool] | None = None,
|
||||
is_leaf_takes_path: bool = False,
|
||||
) -> tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]:
|
||||
"""Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
|
||||
|
||||
Args:
|
||||
tree: a pytree to flatten. If it contains a custom type, it is recommended
|
||||
to be registered with ``register_pytree_with_keys``.
|
||||
|
||||
Returns:
|
||||
A pair which the first element is a list of key-leaf pairs, each of
|
||||
which contains a leaf and its key path. The second element is a treedef
|
||||
representing the structure of the flattened tree.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> path_vals, treedef = jax.tree.flatten_with_path([1, {'x': 3}])
|
||||
>>> path_vals
|
||||
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
|
||||
>>> treedef
|
||||
PyTreeDef([*, {'x': *}])
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.flatten`
|
||||
- :func:`jax.tree.map_with_path`
|
||||
- :func:`jax.tree_util.register_pytree_with_keys`
|
||||
"""
|
||||
return tree_util.tree_flatten_with_path(tree, is_leaf, is_leaf_takes_path)
|
||||
|
||||
|
||||
def leaves_with_path(
|
||||
tree: Any, is_leaf: Callable[..., bool] | None = None,
|
||||
is_leaf_takes_path: bool = False,
|
||||
) -> list[tuple[tree_util.KeyPath, Any]]:
|
||||
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
|
||||
|
||||
Args:
|
||||
tree: a pytree. If it contains a custom type, it is recommended to be
|
||||
registered with ``register_pytree_with_keys``.
|
||||
|
||||
Returns:
|
||||
A list of key-leaf pairs, each of which contains a leaf and its key path.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> jax.tree.leaves_with_path([1, {'x': 3}])
|
||||
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.leaves`
|
||||
- :func:`jax.tree.flatten_with_path`
|
||||
- :func:`jax.tree_util.register_pytree_with_keys`
|
||||
"""
|
||||
return tree_util.tree_leaves_with_path(tree, is_leaf, is_leaf_takes_path)
|
||||
|
||||
|
||||
def map_with_path(
|
||||
f: Callable[..., Any],
|
||||
tree: Any,
|
||||
*rest: Any,
|
||||
is_leaf: Callable[..., bool] | None = None,
|
||||
is_leaf_takes_path: bool = False,
|
||||
) -> Any:
|
||||
"""Maps a multi-input function over pytree key path and args to produce a new pytree.
|
||||
|
||||
This is a more powerful alternative of ``tree_map`` that can take the key path
|
||||
of each leaf as input argument as well.
|
||||
|
||||
Args:
|
||||
f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
|
||||
each corresponding leaves of the pytrees.
|
||||
tree: a pytree to be mapped over, with each leaf's key path as the first
|
||||
positional argument and the leaf itself as the second argument to ``f``.
|
||||
*rest: a tuple of pytrees, each of which has the same structure as ``tree``
|
||||
or has ``tree`` as a prefix.
|
||||
|
||||
Returns:
|
||||
A new pytree with the same structure as ``tree`` but with the value at each
|
||||
leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
|
||||
the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
|
||||
the tuple of values at corresponding nodes in ``rest``.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3])
|
||||
[1, 3, 5]
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.map`
|
||||
- :func:`jax.tree.flatten_with_path`
|
||||
- :func:`jax.tree.leaves_with_path`
|
||||
- :func:`jax.tree_util.register_pytree_with_keys`
|
||||
"""
|
||||
return tree_util.tree_map_with_path(
|
||||
f, tree, *rest, is_leaf=is_leaf, is_leaf_takes_path=is_leaf_takes_path
|
||||
)
|
||||
|
||||
|
||||
def broadcast(prefix_tree: Any, full_tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> Any:
|
||||
"""Broadcasts a tree prefix into the full structure of a given tree.
|
||||
|
||||
Args:
|
||||
prefix_tree: a pytree that is a tree prefix of full_tree.
|
||||
full_tree: a pytree with the structure to broadcast the prefix leaves into.
|
||||
is_leaf: an optionally specified function that will be called at each
|
||||
flattening step. It should return a boolean, with true stopping the
|
||||
traversal and the whole subtree being treated as a leaf, and false
|
||||
indicating the flattening should traverse the current object.
|
||||
|
||||
Returns:
|
||||
A pytree matching the structure of full_tree where the leaves of prefix_tree have been
|
||||
broadcasted into the leaves of each corresponding subtree.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> prefix = (1, 2, 3)
|
||||
>>> full = (0, {'a': 0, 'b': 0}, (0, 0))
|
||||
>>> jax.tree.broadcast(prefix, full)
|
||||
(1, {'a': 2, 'b': 2}, (3, 3))
|
||||
|
||||
See Also:
|
||||
- :func:`jax.tree.leaves`
|
||||
- :func:`jax.tree.structure`
|
||||
"""
|
||||
return tree_util.tree_broadcast(prefix_tree, full_tree, is_leaf=is_leaf)
|
||||
|
||||
|
||||
# dataclasses.field is specially handled by static type checkers
|
||||
# (see https://peps.python.org/pep-0681/). In order to make static()
|
||||
# usable with built-in dataclass, the type checker needs to recognize
|
||||
# that it's identical to `dataclasses.field`. Since there is no
|
||||
# general registration mechanism for this, we define it this way:
|
||||
if TYPE_CHECKING:
|
||||
static = dataclasses.field
|
||||
else:
|
||||
def static(**kwargs):
|
||||
"""Convenience wrapper to declare a static pytree attribute.
|
||||
|
||||
Arguments are the same as those of :func:`dataclasses.field`, but
|
||||
:func:`static` will automatically populate `metadata` with
|
||||
`static = True`, as used by :func:`jax.tree_util.register_dataclass`.
|
||||
|
||||
Example:
|
||||
|
||||
>>> import jax
|
||||
>>> from dataclasses import dataclass
|
||||
...
|
||||
>>> @jax.tree_util.register_dataclass
|
||||
... @dataclass
|
||||
... class MyOp:
|
||||
... x: jax.Array
|
||||
... y: jax.Array
|
||||
... op: str = jax.tree.static(default="add") # static string field
|
||||
...
|
||||
>>> m = MyOp(x=jnp.ones(3), y=jnp.arange(3))
|
||||
>>> m
|
||||
MyOp(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
|
||||
|
||||
>>> leaves, treedef = jax.tree.flatten(m)
|
||||
>>> leaves
|
||||
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
|
||||
|
||||
>>> treedef
|
||||
PyTreeDef(CustomNode(MyOp[('add',)], [*, *]))
|
||||
|
||||
>>> jax.tree.unflatten(treedef, leaves)
|
||||
MyOp(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
|
||||
|
||||
See also:
|
||||
- :func:`jax.tree_util.register_dataclass`
|
||||
"""
|
||||
metadata = {"static": True, **(kwargs.pop('metadata', {}) or {})}
|
||||
return dataclasses.field(metadata=metadata, **kwargs)
|
||||
Reference in New Issue
Block a user