1610 lines
59 KiB
Python
1610 lines
59 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
import collections
|
|
from collections.abc import Callable, Hashable, Iterable, Sequence
|
|
import dataclasses
|
|
import difflib
|
|
import functools
|
|
from functools import cached_property, partial
|
|
import operator as op
|
|
import textwrap
|
|
from typing import Any, TypeVar
|
|
|
|
from jax._src import traceback_util
|
|
from jax._src.lib import jaxlib_extension_version
|
|
from jax._src.lib import pytree
|
|
from jax._src.util import safe_zip, set_module
|
|
from jax._src.util import unzip2
|
|
|
|
|
|
# Decorator applied to the public API below. Presumably it rewrites the
# decorated object's `__module__` to 'jax.tree_util' so it is documented and
# pickled under the public module name -- confirm against jax._src.util.
export = set_module('jax.tree_util')

# Register this source file with traceback_util. Presumably this lets JAX's
# traceback filtering hide frames originating from this module -- confirm
# against traceback_util.register_exclusion.
traceback_util.register_exclusion(__file__)
|
|
|
|
# Generic type variables used in the annotations below.
T = TypeVar("T")
# Bound to class objects; lets class decorators such as
# register_pytree_node_class return the same class type they receive.
Typ = TypeVar("Typ", bound=type[Any])
H = TypeVar("H", bound=Hashable)

# Descriptive aliases. `Leaf` and `PyTree` are plain `Any` and exist only to
# document intent in signatures; `PyTreeDef` is the compiled treedef type.
Leaf = Any
PyTree = Any
PyTreeDef = pytree.PyTreeDef
|
|
|
|
default_registry = pytree.default_registry()
# Set __module__ and __name__, which allow this registry to be pickled by
# reference.
default_registry.__module__ = __name__
default_registry.__name__ = "default_registry" # pyrefly: ignore[missing-attribute]

# A registry mirroring the default one, except that None is a leaf
# (enable_none=False). Custom node types are kept in sync with the default
# registry because register_pytree_node registers into every registry listed in
# `_all_registries` below.
none_leaf_registry = pytree.PyTreeRegistry(
    enable_none=False, enable_tuple=True, enable_namedtuple=True,
    enable_list=True, enable_dict=True)
none_leaf_registry.__module__ = __name__
none_leaf_registry.__name__ = "none_leaf_registry" # pyrefly: ignore[missing-attribute]

# A special, internal pytree registry that includes everything in
# `default_registry`, plus internal Python-defined types that we want
# to teach the fast dispatch path ("C++ dispatch") how to flatten and
# unflatten. A key example is PRNG key arrays, which are currently a
# Python-defined class (in `jax._src.prng`). These ought to be a leaf
# node everywhere in the system (e.g. in Jaxpr), but we want to unpack
# and repack them across the fast dispatch boundary. If we were to
# skip registering such types here, the fast dispatch path would not
# know how to handle them as arguments. It would instead always
# indicate a "cache miss" and dispatch on the slow path.
dispatch_registry = pytree.PyTreeRegistry(
    enable_none=True, enable_tuple=True, enable_namedtuple=True,
    enable_list=True, enable_dict=True)
dispatch_registry.__module__ = __name__
dispatch_registry.__name__ = "dispatch_registry" # pyrefly: ignore[missing-attribute]

# Registry built with PyTreeRegistry's default constructor arguments; it is the
# registry tree_leaves_checked flattens with. Which built-in containers the
# defaults enable is decided by the compiled module -- not visible here.
tracing_registry = pytree.PyTreeRegistry()
tracing_registry.__module__ = __name__
tracing_registry.__name__ = "tracing_registry" # pyrefly: ignore[missing-attribute]

# Every registry that register_pytree_node adds a custom node type to.
_all_registries = (
    default_registry,
    none_leaf_registry,
    dispatch_registry,
    tracing_registry,
)
|
|
|
|
|
|
@export
def tree_flatten(tree: Any,
                 is_leaf: Callable[[Any], bool] | None = None
                 ) -> tuple[list[Leaf], PyTreeDef]:
  """Alias of :func:`jax.tree.flatten`.

  Flattens ``tree`` with the default registry and returns the list of leaves
  together with the treedef describing the structure around them.
  """
  leaves, treedef = default_registry.flatten(tree, is_leaf)
  return leaves, treedef
|
|
|
|
|
|
@export
def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
  """Alias of :func:`jax.tree.unflatten`.

  Rebuilds a pytree with structure ``treedef`` from ``leaves`` by delegating to
  the treedef's own ``unflatten`` method.
  """
  rebuilt = treedef.unflatten(leaves)
  return rebuilt
|
|
|
|
|
|
@export
def tree_leaves(tree: Any,
                is_leaf: Callable[[Any], bool] | None = None
                ) -> list[Leaf]:
  """Alias of :func:`jax.tree.leaves`.

  Returns only the leaves from flattening ``tree`` with the default registry;
  the treedef is discarded.
  """
  leaves, _unused_treedef = default_registry.flatten(tree, is_leaf)
  return leaves
|
|
|
|
|
|
@export
def tree_leaves_checked(treedef_expected: PyTreeDef, tree: Any) -> list[Leaf]:
  """Flattens ``tree`` with ``tracing_registry`` and checks its structure.

  Args:
    treedef_expected: the structure ``tree`` is required to have.
    tree: the pytree to flatten.

  Returns:
    The flat list of leaves of ``tree``.

  Raises:
    AssertionError: if the flattened structure differs from
      ``treedef_expected``. This is an ``assert``, so the check is skipped when
      Python runs with ``-O``.
  """
  flat_vals, treedef_actual = tracing_registry.flatten(tree)
  # The message expression is only evaluated when the check fails, so the
  # diagnostic costs nothing on the success path.
  assert treedef_actual == treedef_expected, (
      f"expected pytree structure {treedef_expected}, got {treedef_actual}")
  return flat_vals
|
|
|
|
|
|
@export
def tree_structure(tree: Any,
                   is_leaf: Callable[[Any], bool] | None = None) -> PyTreeDef:
  """Alias of :func:`jax.tree.structure`.

  Returns only the treedef from flattening ``tree`` with the default registry;
  the leaves are discarded.
  """
  _unused_leaves, treedef = default_registry.flatten(tree, is_leaf)
  return treedef
|
|
|
|
|
|
@export
def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef:
  """Builds the treedef of a tuple whose elements have the given structures.

  Args:
    treedefs: iterable of PyTree structures, one per tuple element.

  Returns:
    a single treedef representing a tuple of the structures

  Examples:
    >>> import jax
    >>> x = [1, 2, 3]
    >>> y = {'a': 4, 'b': 5}
    >>> x_tree = jax.tree.structure(x)
    >>> y_tree = jax.tree.structure(y)
    >>> xy_tree = jax.tree_util.treedef_tuple([x_tree, y_tree])
    >>> xy_tree == jax.tree.structure((x, y))
    True

  See Also:
    - :func:`jax.tree_util.treedef_children`
  """
  # The compiled helper needs a concrete list, not an arbitrary iterable.
  child_treedefs = list(treedefs)
  return pytree.treedef_tuple(default_registry, child_treedefs)
|
|
|
|
|
|
@export
def treedef_children(treedef: PyTreeDef) -> list[PyTreeDef]:
  """Returns the treedefs of the immediate children of ``treedef``.

  Args:
    treedef: a single PyTreeDef

  Returns:
    a list of PyTreeDefs representing the children of treedef.

  Examples:
    >>> import jax
    >>> x = [(1, 2), 3, {'a': 4}]
    >>> treedef = jax.tree.structure(x)
    >>> jax.tree_util.treedef_children(treedef)
    [PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})]
    >>> _ == [jax.tree.structure(vals) for vals in x]
    True

  See Also:
    - :func:`jax.tree_util.treedef_tuple`
  """
  children = treedef.children()
  return children
|
|
|
|
|
|
@export
def treedef_is_leaf(treedef: PyTreeDef) -> bool:
  """Reports whether ``treedef`` describes a leaf.

  Args:
    treedef: tree to check

  Returns:
    True if treedef is a leaf (i.e. has a single node); False otherwise.

  Examples:
    >>> import jax
    >>> tree1 = jax.tree.structure(1)
    >>> jax.tree_util.treedef_is_leaf(tree1)
    True
    >>> tree2 = jax.tree.structure([1, 2])
    >>> jax.tree_util.treedef_is_leaf(tree2)
    False
  """
  node_count = treedef.num_nodes
  return node_count == 1
|
|
|
|
|
|
# treedef_is_strict_leaf is not exported.
def treedef_is_strict_leaf(treedef: PyTreeDef) -> bool:
  """Returns True if ``treedef`` is one node that carries exactly one leaf.

  Stricter than :func:`treedef_is_leaf`: a single-node treedef with zero
  leaves does not qualify.
  """
  if treedef.num_nodes != 1:
    return False
  return treedef.num_leaves == 1
|
|
|
|
|
|
@export
def all_leaves(iterable: Iterable[Any],
               is_leaf: Callable[[Any], bool] | None = None) -> bool:
  """Tests whether all elements in the given iterable are all leaves.

  This function is useful in advanced cases, for example if a library allows
  arbitrary map operations on a flat iterable of leaves it may want to check
  if the result is still a flat iterable of leaves.

  Args:
    iterable: Iterable of leaves.
    is_leaf: an optionally specified function that will be called at each
      flattening step. It should return a boolean, with true stopping the
      traversal and the whole subtree being treated as a leaf. When given, the
      elements are flattened with it, and each element must come back
      unchanged (by identity) as exactly one leaf.

  Returns:
    A boolean indicating if all elements in the input are leaves.

  Examples:
    >>> import jax
    >>> tree = {"a": [1, 2, 3]}
    >>> assert all_leaves(jax.tree_util.tree_leaves(tree))
    >>> assert not all_leaves([tree])
  """
  if is_leaf is not None:
    candidates = list(iterable)
    flat = tree_leaves(candidates, is_leaf)
    if len(flat) != len(candidates):
      return False
    # Identity (not equality) shows flattening left each element untouched.
    return all(cand is out
               for cand, out in zip(candidates, flat, strict=True))
  return pytree.all_leaves(default_registry, iterable)
|
|
|
|
|
|
@export
def is_tree_node(typ: type) -> bool:
  """Reports whether ``typ`` is a registered PyTree node type.

  Args:
    typ: The type to check.

  Returns:
    True if the type is a registered PyTree node type (built-in or custom)
    or a namedtuple type.
  """
  if jaxlib_extension_version < 431:
    # Older jaxlib versions (presumably lacking `is_node`) fall back to the
    # Python-side mirror; any tuple subclass, including namedtuples, counts as
    # a node.
    return typ in _registry or issubclass(typ, tuple)
  return default_registry.is_node(typ)
|
|
|
|
|
|
# Type variables for the callbacks accepted by the registration functions:
# the children a flatten function returns, and its hashable auxiliary data.
_Children = TypeVar("_Children", bound=Iterable[Any])
_AuxData = TypeVar("_AuxData", bound=Hashable)
# One entry of a key path (e.g. SequenceKey, DictKey, GetAttrKey).
KeyEntry = TypeVar("KeyEntry", bound=Any)
# A (key entry, child) pair as produced by key-aware flatten functions.
KeyLeafPair = tuple[KeyEntry, Any]
KeyLeafPairs = Iterable[KeyLeafPair]
# The sequence of key entries leading from the root to a subtree.
KeyPath = tuple[KeyEntry, ...]
|
|
|
|
|
|
@export
def register_pytree_node(
    nodetype: type[T],
    flatten_func: Callable[[T], tuple[_Children, _AuxData]],
    unflatten_func: Callable[[_AuxData, _Children], T],
    flatten_with_keys_func: (
        Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None
    ) = None,
) -> None:
  """Extends the set of types that are considered internal nodes in pytrees.

  See :ref:`example usage <pytrees>`.

  The type is registered in every registry this module maintains (the default
  registry, the None-as-leaf registry, the dispatch registry and the tracing
  registry).

  Args:
    nodetype: a Python type to register as a pytree.
    flatten_func: a function to be used during flattening, taking a value of
      type ``nodetype`` and returning a pair, with (1) an iterable for the
      children to be flattened recursively, and (2) some hashable auxiliary data
      to be stored in the treedef and to be passed to the ``unflatten_func``.
    unflatten_func: a function taking two arguments: the auxiliary data that was
      returned by ``flatten_func`` and stored in the treedef, and the
      unflattened children. The function should return an instance of
      ``nodetype``.
    flatten_with_keys_func: an optional function like ``flatten_func``, except
      that its first output is an iterable of ``(key, child)`` pairs instead of
      bare children. It should return the children in the same order as
      ``flatten_func`` and the same auxiliary data. When ``None``, no key-aware
      flatten function is passed to the registries.

  See also:
    - :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree.
    - :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass.
    - :func:`~jax.tree_util.register_pytree_with_keys`
    - :func:`~jax.tree_util.register_pytree_node_class`
    - :func:`~jax.tree_util.register_pytree_with_keys_class`

  Examples:
    First we'll define a custom type:

    >>> class MyContainer:
    ...   def __init__(self, size):
    ...     self.x = jnp.zeros(size)
    ...     self.y = jnp.ones(size)
    ...     self.size = size

    If we try using this in a JIT-compiled function, we'll get an error because JAX
    does not yet know how to handle this type:

    >>> m = MyContainer(size=5)
    >>> def f(m):
    ...   return m.x + m.y + jnp.arange(m.size)
    >>> jax.jit(f)(m)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    TypeError: Cannot interpret value of type <class 'jax.tree_util.MyContainer'> as an abstract array; it does not have a dtype attribute

    In order to make our object recognized by JAX, we must register it as
    a pytree:

    >>> def flatten_func(obj):
    ...   children = (obj.x, obj.y)  # children must contain arrays & pytrees
    ...   aux_data = (obj.size,)  # aux_data must contain static, hashable data.
    ...   return (children, aux_data)
    ...
    >>> def unflatten_func(aux_data, children):
    ...   # Here we avoid `__init__` because it has extra logic we don't require:
    ...   obj = object.__new__(MyContainer)
    ...   obj.x, obj.y = children
    ...   obj.size, = aux_data
    ...   return obj
    ...
    >>> jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)

    Now with this defined, we can use instances of this type in JIT-compiled functions.

    >>> jax.jit(f)(m)
    Array([1., 2., 3., 4., 5.], dtype=float32)
  """
  for registry in _all_registries:
    registry.register_node(
        nodetype, flatten_func, unflatten_func, flatten_with_keys_func
    )
  # Mirror the registration in the Python-side registry, which is_tree_node
  # consults on older jaxlib versions.
  _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
|
|
|
|
|
|
@export
def register_pytree_node_class(cls: Typ) -> Typ:
  """Registers a class as a pytree node using its own flatten/unflatten methods.

  This function is a thin wrapper around ``register_pytree_node``, and provides
  a class-oriented interface: the class must define a ``tree_flatten`` method
  and a ``tree_unflatten`` classmethod.

  Args:
    cls: a type to register as a pytree

  Returns:
    The input class ``cls`` is returned unchanged after being added to JAX's pytree
    registry. This return value allows ``register_pytree_node_class`` to be used as
    a decorator.

  See also:
    - :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree.
    - :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass.
    - :func:`~jax.tree_util.register_pytree_node`
    - :func:`~jax.tree_util.register_pytree_with_keys`
    - :func:`~jax.tree_util.register_pytree_with_keys_class`

  Examples:
    Here we'll define a custom container that will be compatible with :func:`jax.jit`
    and other JAX transformations:

    >>> import jax
    >>> @jax.tree_util.register_pytree_node_class
    ... class MyContainer:
    ...   def __init__(self, x, y):
    ...     self.x = x
    ...     self.y = y
    ...   def tree_flatten(self):
    ...     return ((self.x, self.y), None)
    ...   @classmethod
    ...   def tree_unflatten(cls, aux_data, children):
    ...     return cls(*children)
    ...
    >>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
    >>> def f(m):
    ...   return m.x + 2 * m.y
    >>> jax.jit(f)(m)
    Array([0., 2., 4., 6.], dtype=float32)
  """
  # Flattening calls the instance's own `tree_flatten` method.
  flatten = op.methodcaller("tree_flatten")
  unflatten = cls.tree_unflatten  # pyrefly: ignore[missing-attribute]
  register_pytree_node(cls, flatten, unflatten)
  return cls
|
|
|
|
|
|
@export
def tree_map(f: Callable[..., Any],
             tree: Any,
             *rest: Any,
             is_leaf: Callable[[Any], bool] | None = None) -> Any:
  """Alias of :func:`jax.tree.map`.

  ``tree`` fixes the structure; each tree in ``rest`` is flattened up to that
  structure, and ``f`` is applied to corresponding leaves.
  """
  leaves, treedef = tree_flatten(tree, is_leaf)
  # Named to avoid shadowing the module-level `all_leaves` function.
  leaf_columns = [leaves, *(treedef.flatten_up_to(other) for other in rest)]
  mapped = (f(*args) for args in zip(*leaf_columns))
  return treedef.unflatten(mapped)
|
|
|
|
|
|
@export
def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None,
                   pytree_to_transpose: Any) -> Any:
  """Alias of :func:`jax.tree.transpose`.

  When ``inner_treedef`` is None it is inferred from the first outer subtree.
  Raises TypeError if the total leaf count does not equal the product of the
  outer and inner leaf counts.
  """
  flat, treedef = tree_flatten(pytree_to_transpose)
  if inner_treedef is None:
    first_inner = outer_treedef.flatten_up_to(pytree_to_transpose)[0]
    inner_treedef = tree_structure(first_inner)
  n_inner = inner_treedef.num_leaves
  n_outer = outer_treedef.num_leaves
  if treedef.num_leaves != (n_inner * n_outer):
    expected_treedef = outer_treedef.compose(inner_treedef)
    raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")
  # Group the flat leaves into n_outer rows of n_inner leaves, then read them
  # column-wise: each column becomes one outer-structured subtree.
  rows = [flat[row * n_inner:(row + 1) * n_inner] for row in range(n_outer)]
  subtrees = (tree_unflatten(outer_treedef, column) for column in zip(*rows))
  return tree_unflatten(inner_treedef, subtrees)
|
|
|
|
|
|
# TODO(mattjj): remove the Python-side registry when the C++-side registry is
# sufficiently queryable that we can express _replace_nones. That may mean once
# we have a flatten_one function.
#
# Each entry pairs a flatten function (`to_iter`: node -> (children, aux)) with
# an unflatten function (`from_iter`: (aux, children) -> node).
_RegistryEntry = collections.namedtuple("_RegistryEntry", ["to_iter", "from_iter"])
# Python-side table of node types: the built-ins below, plus every custom type
# added by register_pytree_node.
_registry: dict[type[Any], _RegistryEntry] = {
    tuple: _RegistryEntry(lambda xs: (xs, None), lambda _, xs: tuple(xs)),
    list: _RegistryEntry(lambda xs: (xs, None), lambda _, xs: list(xs)),
    # Items are sorted by key; unzip2 presumably splits them into
    # (keys, values), and `[::-1]` swaps that to (values-as-children, keys-as-aux).
    dict: _RegistryEntry(lambda xs: unzip2(sorted(xs.items()))[::-1],
                         lambda keys, xs: dict(zip(keys, xs))),
    # None is a node with no children.
    type(None): _RegistryEntry(lambda z: ((), None), lambda _, xs: None),
}
|
|
|
|
|
|
class Unspecified:
  """Sentinel type marking an optional argument the caller did not supply.

  Used as the default for ``initializer`` in tree_reduce and ``identity`` in
  the associative reductions, so that ``None`` remains a legitimate value.
  """
|
|
|
|
|
|
@export
def tree_reduce(function: Callable[[T, Any], T],
                tree: Any,
                initializer: T | Unspecified = Unspecified(),
                is_leaf: Callable[[Any], bool] | None = None) -> T:
  """Alias of :func:`jax.tree.reduce`.

  Folds ``function`` left-to-right over the leaves of ``tree``, seeded with
  ``initializer`` when one is given.
  """
  leaves = tree_leaves(tree, is_leaf=is_leaf)
  if isinstance(initializer, Unspecified):
    return functools.reduce(function, leaves)
  return functools.reduce(function, leaves, initializer)
|
|
|
|
|
|
def _parallel_reduce_range(
    sequence: list[T],
    operation: Callable[[T, T], T],
    lo: int,
    hi: int,
) -> T:
  """Reduces the non-empty range ``sequence[lo:hi]`` by recursive halving."""
  if hi - lo == 1:
    return sequence[lo]
  # Same split point as halving a slice of length hi - lo (first half gets the
  # floor), so results are identical to slicing without copying the list.
  mid = lo + (hi - lo) // 2
  left = _parallel_reduce_range(sequence, operation, lo, mid)
  right = _parallel_reduce_range(sequence, operation, mid, hi)
  return operation(left, right)


def _parallel_reduce(
    sequence: list[T],
    operation: Callable[[T, T], T],
    identity: T | Unspecified = Unspecified(),
) -> T:
  """Reduces ``sequence`` with ``operation`` as a balanced binary tree.

  The sequence is split in half (first half of size ``len // 2``) recursively,
  and the two halves' results are combined with ``operation``. This requires
  ``operation`` to be associative for the result to match a left fold.

  Args:
    sequence: the items to reduce.
    operation: binary associative operation.
    identity: value returned for an empty ``sequence``.

  Returns:
    The reduced value; the single element if ``sequence`` has length 1.

  Raises:
    TypeError: if ``sequence`` is empty and no ``identity`` was given.
  """
  length = len(sequence)
  if length == 0:
    if isinstance(identity, Unspecified):
      raise TypeError("Must specify identity for parallel reduction of empty sequence.")
    return identity
  # Recurse over index ranges instead of slices to avoid O(n log n) copying.
  return _parallel_reduce_range(sequence, operation, 0, length)
|
|
|
|
|
|
@export
def tree_reduce_associative(
    operation: Callable[[T, T], T],
    tree: Any,
    *,
    identity: T | Unspecified = Unspecified(),
    is_leaf: Callable[[Any], bool] | None = None,
) -> T:
  """Alias of :func:`jax.tree.reduce_associative`.

  Combines the leaves of ``tree`` pairwise in a balanced binary tree, so
  ``operation`` must be associative. Raises TypeError when the tree has no
  leaves and ``identity`` is not given.
  """
  return _parallel_reduce(tree_leaves(tree, is_leaf=is_leaf), operation, identity)
|
|
|
|
|
|
@export
def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
  """Alias of :func:`jax.tree.all`.

  True when every leaf of ``tree`` is truthy (and for a tree with no leaves).
  """
  leaves = tree_leaves(tree, is_leaf=is_leaf)
  return all(leaves)
|
|
|
|
|
|
class _HashableCallableShim:
|
|
"""Object that delegates __call__, __hash__, and __eq__ to another object."""
|
|
|
|
def __init__(self, fun):
|
|
self.fun = fun
|
|
|
|
def __call__(self, *args, **kw):
|
|
return self.fun(*args, **kw)
|
|
|
|
def __hash__(self):
|
|
return hash(self.fun)
|
|
|
|
def __eq__(self, other):
|
|
if isinstance(other, _HashableCallableShim):
|
|
return self.fun == other.fun
|
|
return self.fun == other
|
|
|
|
def __repr__(self):
|
|
return f'_HashableCallableShim({self.fun!r})'
|
|
|
|
|
|
@export
class Partial(functools.partial):
  """A version of functools.partial that works in pytrees.

  Use it for partial function evaluation in a way that is compatible with JAX's
  transformations, e.g., ``Partial(func, *args, **kwargs)``.

  (You need to explicitly opt-in to this behavior because we didn't want to give
  functools.partial different semantics than normal function closures.)

  For example, here is a basic usage of ``Partial`` in a manner similar to
  ``functools.partial``:

  >>> import jax.numpy as jnp
  >>> add_one = Partial(jnp.add, 1)
  >>> add_one(2)
  Array(3, dtype=int32, weak_type=True)

  Pytree compatibility means that the resulting partial function can be passed
  as an argument within transformed JAX functions, which is not possible with a
  standard ``functools.partial`` function:

  >>> from jax import jit
  >>> @jit
  ... def call_func(f, *args):
  ...   return f(*args)
  ...
  >>> call_func(add_one, 2)
  Array(3, dtype=int32, weak_type=True)

  Passing zero arguments to ``Partial`` effectively wraps the original function,
  making it a valid argument in JAX transformed functions:

  >>> call_func(Partial(jnp.add), 1, 2)
  Array(3, dtype=int32, weak_type=True)

  Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in
  a ``TypeError``.

  Note that if the result of ``Partial`` is used in the context where the
  value is traced, it results in all bound arguments being traced when passed
  to the partially-evaluated function:

  >>> print_zero = Partial(print, 0)
  >>> print_zero()
  0
  >>> call_func(print_zero)  # doctest:+ELLIPSIS
  JitTracer(~int32[])
  """

  def __new__(klass, func, *args, **kw):
    # In Python 3.10+, if func is itself a functools.partial instance,
    # functools.partial.__new__ would merge the arguments of this Partial
    # instance with the arguments of the func. We box func in a class that does
    # not (yet) have a `func` attribute to defeat this optimization, since we
    # care exactly which arguments are considered part of the pytree.
    if isinstance(func, functools.partial):
      original_func = func
      func = _HashableCallableShim(original_func)
      out = super().__new__(klass, func, *args, **kw)
      # Only now, after construction, copy the wrapped partial's func/args/
      # keywords onto the shim, so it exposes the same attributes as the
      # partial it wraps. Doing this before super().__new__ would re-enable the
      # merging described above.
      func.func = original_func.func  # pyrefly: ignore[missing-attribute]
      func.args = original_func.args  # pyrefly: ignore[missing-attribute]
      func.keywords = original_func.keywords  # pyrefly: ignore[missing-attribute]
      return out
    else:
      return super().__new__(klass, func, *args, **kw)
|
|
|
|
|
|
# Register Partial as a pytree node: the bound positional arguments and keyword
# arguments are its children, and the wrapped function is the auxiliary data
# stored in the treedef.
register_pytree_node(
    Partial,
    lambda partial_: ((partial_.args, partial_.keywords), partial_.func),
    lambda func, xs: Partial(func, *xs[0], **xs[1]),
)
|
|
|
|
|
|
@export
def tree_broadcast(prefix_tree: Any, full_tree: Any,
                   is_leaf: Callable[[Any], bool] | None = None
                   ) -> Any:
  """Alias of :func:`jax.tree.broadcast`.

  Repeats each leaf of ``prefix_tree`` over its matching subtree of
  ``full_tree`` and returns a pytree shaped like ``full_tree``.
  """
  leaves = broadcast_prefix(prefix_tree, full_tree, is_leaf=is_leaf)
  full_treedef = tree_structure(full_tree)
  return full_treedef.unflatten(leaves)
|
|
|
|
|
|
# broadcast_prefix is not exported
def broadcast_prefix(prefix_tree: Any, full_tree: Any,
                     is_leaf: Callable[[Any], bool] | None = None
                     ) -> list[Any]:
  """Broadcasts tree prefix leaves into the full set of leaves for a given full tree.

  Args:
    prefix_tree: a pytree that is a tree prefix of full_tree.
    full_tree: a pytree with the structure to broadcast the prefix leaves into.
    is_leaf: an optionally specified function that will be called at each
      flattening step for prefix_tree. It should return a boolean, with true
      stopping the traversal and the whole subtree being treated as a leaf,
      and false indicating the flattening should traverse the current object.

  Returns:
    A list of leaves matching the expected count for the full tree,
    with the leaf of each prefix tree being duplicated to match the count of
    its corresponding subtree.

  Raises:
    ValueError: if ``prefix_tree`` is not a prefix of ``full_tree``. When
      prefix_errors can explain the mismatch, its first error is raised;
      otherwise the original error from tree_map propagates.
  """
  result: list[Any] = []

  def add_leaves(x, subtree):
    # Repeat the prefix leaf once per leaf of the matching full-tree subtree.
    result.extend([x] * tree_structure(subtree).num_leaves)

  try:
    tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
  except ValueError:
    # Diagnose with the same `is_leaf` that tree_map used, so the explanation
    # describes the traversal that actually failed.
    errors = prefix_errors(prefix_tree, full_tree, is_leaf)
    if not errors:
      # No structural explanation found: re-raise the original error instead of
      # failing with an unrelated unpacking error.
      raise
    raise errors[0]('broadcast_prefix prefix_tree') from None
  return result
|
|
|
|
|
|
# broadcast_flattened_prefix_with_treedef is not exported
def broadcast_flattened_prefix_with_treedef(
    prefix_leaves: list[Any],
    prefix_treedef: PyTreeDef,
    full_treedef: PyTreeDef,
) -> list[Any]:
  """Broadcasts tree prefix leaves into the full set of leaves for a given full treedef.

  Args:
    prefix_leaves: the leaves of a pytree that is a tree prefix
      of full_treedef.
    prefix_treedef: the PyTreeDef of a pytree that is a tree prefix of
      full_treedef.
    full_treedef: a PyTreeDef with the structure to broadcast the prefix
      leaves into.

  Returns:
    A list of leaves matching the expected count for the full tree,
    with each leaf of prefix tree being duplicated to match the count of
    its corresponding subtree.

  Raises:
    ValueError: if ``prefix_treedef`` is not a prefix of ``full_treedef``
      (a full-tree leaf is reached where the prefix still has structure, node
      data differs, or the child counts differ, via ``zip(..., strict=True)``).
  """
  # NOTE: At the moment, `broadcast_flattened_prefix_with_treedef` is only
  # called from `api_util.flatten_axes`, which replaces any raised exception
  # with its own exception and error message. The errors raised from this
  # function should probably be improved before this function is used in
  # more places.
  #
  # TODO(jburnim): Merge `broadcast_prefix` with this function?
  # (Presumably callers obtain the inputs as:
  #  prefix_leaves, prefix_treedef = tree_flatten(prefix_tree, is_leaf))
  ret = []

  # TODO(jburnim): Should this traversal be done in C++?
  # Walks prefix_treedef and treedef in lockstep. prefix_leaves[leaf_start:leaf_end]
  # are the prefix leaves belonging to the current prefix subtree.
  def _broadcast(broadcast_fn, leaf_start, leaf_end, prefix_treedef, treedef):
    if treedef_is_strict_leaf(prefix_treedef):
      # We have encountered a leaf in the prefix, so we repeat the prefix leaf
      # for each leaf in the corresponding part of the tree.
      assert (leaf_end - leaf_start) == 1
      ret.extend(prefix_leaves[leaf_start:leaf_end] * treedef.num_leaves)
      return

    if treedef_is_strict_leaf(treedef):
      raise ValueError('`prefix_treedef` is not a prefix of `full_treedef`')

    prefix_node_data = prefix_treedef.node_data()
    node_data = treedef.node_data()
    if prefix_node_data != node_data:
      raise ValueError(f'expected {node_data}, got {prefix_node_data}')

    # Recurse into children, handing each child its contiguous slice of the
    # prefix leaves.
    prefix_i = leaf_start
    for prefix_child, tree_child in zip(
        prefix_treedef.children(), treedef.children(), strict=True):
      broadcast_fn(broadcast_fn, prefix_i, prefix_i + prefix_child.num_leaves,
                   prefix_child, tree_child,
      )
      prefix_i += prefix_child.num_leaves

  # Pass _broadcast as arg to avoid it being a free variable within its own
  # closure, which creates a reference cycle.
  _broadcast(_broadcast, 0, len(prefix_leaves), prefix_treedef, full_treedef)
  return ret
|
|
|
|
|
|
# flatten_one_level is not exported.
def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]:
  """Flattens only the outermost node of ``tree``.

  Args:
    tree: A valid pytree node, either built-in or registered via
      :func:`register_pytree_node` or related functions.

  Returns:
    A pair of the pytrees flattened children and its hashable metadata.

  Raises:
    ValueError: If the given pytree is not a built-in or registered container
      via ``register_pytree_node`` or ``register_pytree_with_keys``.

  Examples:
    >>> import jax
    >>> from jax._src.tree_util import flatten_one_level
    >>> flattened, meta = flatten_one_level({'a': [1, 2], 'b': {'c': 3}})
    >>> flattened
    ([1, 2], {'c': 3})
    >>> meta
    ('a', 'b')
  """
  result = default_registry.flatten_one_level(tree)
  if result is not None:
    return result
  # The registry signals "not a node" by returning None.
  raise ValueError(f"can't tree-flatten type: {type(tree)}")
|
|
|
|
|
|
# flatten_one_level_with_keys is not exported.
def flatten_one_level_with_keys(
    tree: Any,
) -> tuple[Iterable[KeyLeafPair], Hashable]:
  """Flattens only the outermost node of ``tree``, pairing children with keys.

  Raises:
    ValueError: if ``tree`` is not a built-in or registered pytree node.
  """
  result = default_registry.flatten_one_level_with_keys(tree)
  if result is not None:
    return result
  raise ValueError(f"can't tree-flatten type: {type(tree)}")
|
|
|
|
|
|
# prefix_errors is not exported
def prefix_errors(prefix_tree: Any, full_tree: Any,
                  is_leaf: Callable[[Any], bool] | None = None,
                  ) -> list[Callable[[str], ValueError]]:
  """Collects the reasons ``prefix_tree`` is not a prefix of ``full_tree``.

  Returns:
    A list of error factories; each takes a name for the prefix argument and
    returns the ValueError to raise. Empty if no mismatch is found.
  """
  errors = _prefix_error((), prefix_tree, full_tree, is_leaf)
  return list(errors)
|
|
|
|
|
|
# equality_errors is not exported
def equality_errors(
    tree1: Any, tree2: Any, is_leaf: Callable[[Any], bool] | None = None,
) -> Iterable[tuple[KeyPath, str, str, str]]:
  """Helper to describe structural differences between two pytrees.

  Args:
    tree1, tree2: pytrees known to have different structure.
    is_leaf: optional predicate, used when computing each subtree's structure,
      to treat matching subtrees as leaves.

  Yields:
    ``(path, thing1, thing2, explanation)`` tuples, one per mismatch found.

  Usage:

    raise Exception(
        "Value 1 and value 2 must have the same pytree structure, but they have "
        "the following structural differences:\n" +
        ("\n".join(
            f"   - {keystr(path)} is a {thing1} in value 1 and a {thing2} in "
            f" value 2, so {explanation}.\n"
            for path, thing1, thing2, explanation
            in equality_errors(val1, val2))))
  """
  root_path: KeyPath = ()
  yield from _equality_errors(root_path, tree1, tree2, is_leaf)
|
|
|
|
def equality_errors_pytreedef(
    tree1: PyTreeDef,
    tree2: PyTreeDef) -> Iterable[tuple[KeyPath, str, str, str]]:
  """Like `equality_errors` but invoked on PyTreeDef.

  Each treedef is filled with a placeholder leaf whose type prints as
  "pytree leaf", and the resulting pytrees are compared.
  """
  # TODO(mattjj): make equality_errors not print type name, avoid metaclass
  leaf_meta = type("LeafMeta", (type,), dict(__repr__=lambda _: "pytree leaf"))
  leaf_cls = leaf_meta("Leaf", (), {})
  placeholder = leaf_cls()
  filled1 = tree_unflatten(tree1, [placeholder] * tree1.num_leaves)
  filled2 = tree_unflatten(tree2, [placeholder] * tree2.num_leaves)
  return equality_errors(filled1, filled2)
|
|
|
|
# TODO(mattjj): maybe share some logic with _prefix_error?
def _equality_errors(path, t1, t2, is_leaf):
  """Yields ``(path, thing1, thing2, explanation)`` for structural mismatches.

  Recursively compares ``t1`` and ``t2`` below the key path ``path`` and stops
  descending into a subtree at the first mismatch found at its root.
  """
  # If both are leaves, this isn't a structure equality error.
  if (treedef_is_strict_leaf(tree_structure(t1, is_leaf=is_leaf)) and
      treedef_is_strict_leaf(tree_structure(t2, is_leaf=is_leaf))):
    return

  # The trees may disagree because they are different types:
  if type(t1) != type(t2):
    yield path, str(type(t1)), str(type(t2)), 'their Python types differ'
    return  # no more errors to find

  # Or they may disagree because their roots have different numbers or keys of
  # children (with special-case handling of list/tuple):
  if isinstance(t1, (list, tuple)):
    assert type(t1) == type(t2)
    if len(t1) != len(t2):
      yield (path,
             f'{type(t1).__name__} of length {len(t1)}',
             f'{type(t2).__name__} of length {len(t2)}',
             'the lengths do not match')
      return  # no more errors to find
  t1_children, t1_meta = flatten_one_level(t1)
  t2_children, t2_meta = flatten_one_level(t2)
  t1_children = tuple(t1_children)
  t2_children = tuple(t2_children)
  t1_keys, t2_keys = _child_keys(t1), _child_keys(t2)
  try:
    diff = ' '.join(repr(k.key) for k in
                    set(t1_keys).symmetric_difference(set(t2_keys)))
  except Exception:
    # Best effort only: key entries may lack a `.key` attribute or be
    # unhashable. Not a bare `except`, so KeyboardInterrupt/SystemExit still
    # propagate.
    diff = ''
  n1, n2 = len(t1_children), len(t2_children)
  if n1 != n2:
    yield (path,
           f'{type(t1)} with {n1} {"child" if n1 == 1 else "children"}',
           f'{type(t2)} with {n2} {"child" if n2 == 1 else "children"}',
           'the numbers of children do not match' +
           (diff and f', with the symmetric difference of key sets: {{{diff}}}')
           )
    return  # no more errors to find

  # Or they may disagree if their roots have different pytree metadata:
  if t1_meta != t2_meta:
    yield (path,
           f'{type(t1)} with pytree metadata {t1_meta}',
           f'{type(t2)} with pytree metadata {t2_meta}',
           'the pytree node metadata does not match')
    return  # no more errors to find

  # If the root types and numbers of children agree, there must be a mismatch in
  # a subtree, so recurse:
  assert t1_keys == t2_keys, \
      f"equal pytree nodes gave different tree keys: {t1_keys} and {t2_keys}"
  for k, c1, c2 in zip(t1_keys, t1_children, t2_children):
    yield from _equality_errors((*path, k), c1, c2, is_leaf)
|
|
|
|
|
|
# Key-entry types from the compiled pytree module, re-exported under this
# module's names. They are annotated as `Any` here; _simple_entrystr matches on
# their `idx` / `key` / `name` attributes.
SequenceKey: Any = pytree.SequenceKey
DictKey: Any = pytree.DictKey
GetAttrKey: Any = pytree.GetAttrKey
FlattenedIndexKey: Any = pytree.FlattenedIndexKey
|
|
|
|
|
|
@export
def keystr(keys: KeyPath, *, simple: bool = False, separator: str = '') -> str:
  """Renders a key path as a single string.

  Args:
    keys: A tuple of ``KeyEntry`` or any class that can be converted to string.
    simple: If True, use a simplified string representation for keys. The
      simple representation of keys will be more compact than the default, but
      is ambiguous in some cases (for example "0" might refer to the first item
      in a list or a dictionary key for the integer 0 or string "0").
    separator: The separator to use to join string representations of the keys.

  Returns:
    A string that joins all string representations of the keys.

  Examples:
    >>> import jax
    >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
    >>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path))
    ['foo']['bar']['bat'][0]
    ['foo']['bar']['bat'][1]
    ['foo']['bar']['baz']
    >>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
    foo/bar/bat/0
    foo/bar/bat/1
    foo/bar/baz
  """
  to_str = str if not simple else _simple_entrystr
  return separator.join(to_str(entry) for entry in keys)
|
|
|
|
|
|
def _simple_entrystr(key: KeyEntry) -> str:
|
|
match key:
|
|
case (
|
|
SequenceKey(idx=key)
|
|
| DictKey(key=key)
|
|
| GetAttrKey(name=key)
|
|
| FlattenedIndexKey(key=key)
|
|
):
|
|
return str(key)
|
|
case _:
|
|
return str(key)
|
|
|
|
|
|
@export
def register_pytree_with_keys(
    nodetype: type[T],
    flatten_with_keys: Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]],
    unflatten_func: Callable[[_AuxData, Iterable[Any]], T],
    flatten_func: None | (Callable[[T], tuple[Iterable[Any], _AuxData]]) = None,
):
  """Extends the set of types that are considered internal nodes in pytrees.

  This is a more powerful alternative to ``register_pytree_node`` that allows
  you to access each pytree leaf's key path when flattening and tree-mapping.

  Args:
    nodetype: a Python type to treat as an internal pytree node.
    flatten_with_keys: a function to be used during flattening, taking a value
      of type ``nodetype`` and returning a pair, with (1) an iterable of
      ``(key, child)`` tuples, where ``key`` is a key entry (for example a
      :class:`~jax.tree_util.GetAttrKey`) identifying that child, and (2) some
      hashable auxiliary data to be stored in the treedef and to be passed to
      the ``unflatten_func``.
    unflatten_func: a function taking two arguments: the auxiliary data that was
      returned by ``flatten_with_keys`` (or ``flatten_func``) and stored in the
      treedef, and the unflattened children. The function should return an
      instance of ``nodetype``.
    flatten_func: an optional function similar to ``flatten_with_keys``, but
      returns only children and auxiliary data. It must return the children
      in the same order as ``flatten_with_keys``, and return the same aux data.
      This argument is optional and only needed for faster traversal when
      calling functions without keys like ``tree_map`` and ``tree_flatten``.
      If omitted, one is derived from ``flatten_with_keys`` by dropping keys.

  Examples:
    First we'll define a custom type:

    >>> class MyContainer:
    ...   def __init__(self, size):
    ...     self.x = jnp.zeros(size)
    ...     self.y = jnp.ones(size)
    ...     self.size = size

    Now register it using a key-aware flatten function:

    >>> from jax.tree_util import GetAttrKey
    >>> def flatten_with_keys(obj):
    ...   children = [(GetAttrKey('x'), obj.x),
    ...               (GetAttrKey('y'), obj.y)]  # children must contain arrays & pytrees
    ...   aux_data = (obj.size,)  # aux_data must contain static, hashable data.
    ...   return children, aux_data
    ...
    >>> def unflatten(aux_data, children):
    ...   # Here we avoid `__init__` because it has extra logic we don't require:
    ...   obj = object.__new__(MyContainer)
    ...   obj.x, obj.y = children
    ...   obj.size, = aux_data
    ...   return obj
    ...
    >>> jax.tree_util.register_pytree_with_keys(MyContainer, flatten_with_keys, unflatten)

    Now this can be used with functions like :func:`~jax.tree_util.tree_flatten_with_path`:

    >>> m = MyContainer(4)
    >>> leaves, treedef = jax.tree_util.tree_flatten_with_path(m)
  """
  # Use an explicit ``None`` check: a user-supplied callable object should
  # never be replaced just because it happens to be falsy.
  if flatten_func is None:
    def flatten_func_impl(tree):
      # Strip the keys, keeping children in the same order and the same aux
      # data, as required of a user-provided ``flatten_func``.
      key_children, aux_data = flatten_with_keys(tree)
      return [child for _, child in key_children], aux_data
    flatten_func = flatten_func_impl

  register_pytree_node(
      nodetype, flatten_func, unflatten_func, flatten_with_keys
  )
|
|
|
|
|
|
@export
def register_pytree_with_keys_class(cls: Typ) -> Typ:
  """Extends the set of types that are considered internal nodes in pytrees.

  This function is similar to ``register_pytree_node_class``, but requires a
  class that defines how it could be flattened with keys.

  It is a thin wrapper around ``register_pytree_with_keys``, and provides a
  class-oriented interface. ``cls`` must define:

  - a ``tree_flatten_with_keys(self)`` method, used as ``flatten_with_keys``;
  - a ``tree_unflatten(aux_data, children)`` callable accessible on the class
    (typically a classmethod), used as ``unflatten_func``.

  If ``cls`` also defines a ``tree_flatten(self)`` method, it is used as the
  optional ``flatten_func`` for faster traversal when keys are not needed.

  Args:
    cls: a type to register as a pytree

  Returns:
    The input class ``cls`` is returned unchanged after being added to JAX's
    pytree registry. This return value allows
    ``register_pytree_with_keys_class`` to be used as a decorator.

  See also:
    - :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree.
    - :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass.
    - :func:`~jax.tree_util.register_pytree_node`
    - :func:`~jax.tree_util.register_pytree_with_keys`
    - :func:`~jax.tree_util.register_pytree_node_class`

  Examples:
    >>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
    >>> @register_pytree_with_keys_class
    ... class Special:
    ...   def __init__(self, x, y):
    ...     self.x = x
    ...     self.y = y
    ...   def tree_flatten_with_keys(self):
    ...     return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
    ...   @classmethod
    ...   def tree_unflatten(cls, aux_data, children):
    ...     return cls(*children)
  """
  # ``tree_flatten`` is optional; without it the keyed flattener is used for
  # key-less traversal too.
  flatten_func = (
      op.methodcaller("tree_flatten") if hasattr(cls, "tree_flatten") else None
  )
  register_pytree_with_keys(
      cls, op.methodcaller("tree_flatten_with_keys"),
      cls.tree_unflatten,  # pyrefly: ignore[missing-attribute]
      flatten_func
  )
  return cls
|
|
|
|
|
|
@export
def register_dataclass(
    nodetype: Typ,
    data_fields: Sequence[str] | None = None,
    meta_fields: Sequence[str] | None = None,
    drop_fields: Sequence[str] = (),
) -> Typ:
  """Extends the set of types that are considered internal nodes in pytrees.

  This differs from ``register_pytree_with_keys_class`` in that the C++
  registries use the optimized C++ dataclass builtin instead of the argument
  functions.

  See :ref:`pytrees-custom-pytree-nodes` for more information about registering pytrees.

  Args:
    nodetype: a Python type to treat as an internal pytree node. This is assumed
      to have the semantics of a :obj:`~dataclasses.dataclass`: namely, class
      attributes represent the whole of the object state, and can be passed
      as keywords to the class constructor to create a copy of the object.
      All defined attributes should be listed among ``meta_fields`` or ``data_fields``.
    data_fields: data field names: these are attributes which will be treated as non-static
      when this pytree is passed to :func:`jax.jit`. ``data_fields`` is optional only if
      ``nodetype`` is a dataclass, in which case fields are assumed data fields unless
      marked via :func:`dataclasses.field` (see examples below).
      Data fields *must* be JAX-compatible objects such as arrays (:class:`jax.Array`
      or :class:`numpy.ndarray`), scalars, or pytrees whose leaves are arrays or scalars.
      Note that ``None`` is a valid data field, as JAX recognizes this as an empty pytree.
    meta_fields: metadata field names: these are attributes which will be treated as
      :term:`static` when this pytree is passed to :func:`jax.jit`. ``meta_fields`` is
      optional only if ``nodetype`` is a dataclass, in which case individual fields can
      be marked static via :func:`dataclasses.field` (see examples below).
      Metadata fields *must* be static, hashable, immutable objects, as these objects
      are used to generate JIT cache keys. In particular, metadata fields cannot contain
      :class:`jax.Array` or :class:`numpy.ndarray` objects.
    drop_fields: only referenced if ``nodetype`` is a dataclass. Specify a sequence of
      field names from among ``dataclasses.fields(nodetype)`` to be excluded from pytree
      registration.

  When ``data_fields`` and ``meta_fields`` are both omitted, they are inferred
  from the dataclass fields that are accepted by ``__init__`` (``init=True``)
  and not listed in ``drop_fields``. Fields declared with ``init=False`` cannot
  be passed back to the constructor when unflattening, so they are skipped, in
  the same way :func:`dataclasses.replace` skips them.

  Returns:
    The input class ``nodetype`` is returned unchanged after being added to JAX's
    pytree registry, so that :func:`register_dataclass` can be used as a decorator.

  Examples:
    In JAX v0.4.35 or older, you must specify ``data_fields`` and ``meta_fields``
    in order to use this decorator:

    >>> import jax
    >>> from dataclasses import dataclass
    >>> from functools import partial
    ...
    >>> @partial(jax.tree_util.register_dataclass,
    ...          data_fields=['x', 'y'],
    ...          meta_fields=['op'])
    ... @dataclass
    ... class MyStruct:
    ...   x: jax.Array
    ...   y: jax.Array
    ...   op: str
    ...
    >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
    >>> m
    MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

    Starting in JAX v0.4.36, the ``data_fields`` and ``meta_fields`` arguments are optional
    for :func:`~dataclasses.dataclass` inputs, with fields defaulting to ``data_fields``
    unless marked as static using `static` metadata in :func:`dataclasses.field`.

    >>> import jax
    >>> from dataclasses import dataclass, field
    ...
    >>> @jax.tree_util.register_dataclass
    ... @dataclass
    ... class MyStruct:
    ...   x: jax.Array  # defaults to non-static data field
    ...   y: jax.Array  # defaults to non-static data field
    ...   op: str = field(metadata=dict(static=True))  # marked as static meta field.
    ...
    >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
    >>> m
    MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

    Once this class is registered, it can be used with functions in :mod:`jax.tree` and
    :mod:`jax.tree_util`:

    >>> leaves, treedef = jax.tree.flatten(m)
    >>> leaves
    [Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
    >>> treedef
    PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
    >>> jax.tree.unflatten(treedef, leaves)
    MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

    In particular, this registration allows ``m`` to be passed seamlessly through code
    wrapped in :func:`jax.jit` and other JAX transformations, with ``data_fields`` being
    treated as dynamic arguments, and ``meta_fields`` being treated as static arguments:

    >>> @jax.jit
    ... def compiled_func(m):
    ...   if m.op == 'add':
    ...     return m.x + m.y
    ...   else:
    ...     raise ValueError(f"{m.op=}")
    ...
    >>> compiled_func(m)
    Array([1., 2., 3.], dtype=float32)
  """
  if data_fields is None or meta_fields is None:
    if (data_fields is None) != (meta_fields is None):
      raise TypeError("register_dataclass: data_fields and meta_fields must both be specified"
                      f" when either is specified. Got {data_fields=} {meta_fields=}.")
    if not dataclasses.is_dataclass(nodetype):
      raise TypeError("register_dataclass: data_fields and meta_fields are required when"
                      f" nodetype is not a dataclass. Got {nodetype=}.")
    # Infer exactly the set that the validation below expects: fields accepted
    # by ``__init__`` minus ``drop_fields``. Previously ``init=False`` fields
    # and dropped fields were included here, which always failed validation.
    dropped = set(drop_fields)
    inferred = [
        f for f in dataclasses.fields(nodetype)
        if f.init and f.name not in dropped
    ]
    data_fields = [
        f.name for f in inferred if not f.metadata.get("static", False)
    ]
    meta_fields = [
        f.name for f in inferred if f.metadata.get("static", False)
    ]

  assert meta_fields is not None
  assert data_fields is not None

  # Store inputs as immutable tuples in this scope, because we close over them
  # for later evaluation. This prevents potentially confusing behavior if the
  # caller were to pass in lists that are later mutated.
  meta_fields = tuple(meta_fields)
  data_fields = tuple(data_fields)

  if dataclasses.is_dataclass(nodetype):
    # Unflattening calls ``nodetype(**kwargs)``, so the registered fields must
    # be exactly the constructor-accepted fields that were not dropped.
    init_fields = {f.name for f in dataclasses.fields(nodetype) if f.init}
    init_fields.difference_update(drop_fields)
    if {*meta_fields, *data_fields} != init_fields:
      msg = (
          "data_fields and meta_fields must include all dataclass fields with"
          " ``init=True`` and only them."
      )
      if missing := init_fields - {*meta_fields, *data_fields}:
        msg += (
            f" Missing fields: {missing}. Add them to drop_fields to suppress"
            " this error."
        )
      if unexpected := {*meta_fields, *data_fields} - init_fields:
        msg += f" Unexpected fields: {unexpected}."
      raise ValueError(msg)

  if overlap := set(data_fields) & set(meta_fields):
    raise ValueError(
        "data_fields and meta_fields must not overlap. Overlapping fields:"
        f" {overlap}."
    )

  def unflatten_func(meta, data):
    # Rebuild the object by passing every registered field as a keyword.
    meta_args = tuple(zip(meta_fields, meta))
    data_args = tuple(zip(data_fields, data))
    kwargs = dict(meta_args + data_args)
    return nodetype(**kwargs)

  def flatten_func(x):
    # Children are the data fields; aux data is the tuple of meta values.
    meta = tuple(getattr(x, name) for name in meta_fields)
    data = tuple(getattr(x, name) for name in data_fields)
    return data, meta

  for registry in _all_registries:
    registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
  _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
  return nodetype
|
|
|
|
|
|
def _flatten_ordereddict_with_keys(od):
  """Flattens an ``OrderedDict`` into ``(DictKey, value)`` pairs.

  Children follow the dict's own key order (no sorting), and that key order is
  recorded as the auxiliary data.
  """
  keyed_children = tuple((DictKey(k), od[k]) for k in od.keys())
  return keyed_children, tuple(od.keys())


register_pytree_with_keys(
    collections.OrderedDict,
    _flatten_ordereddict_with_keys,
    lambda keys, values: collections.OrderedDict(safe_zip(keys, values)),
)
|
|
|
|
def _flatten_defaultdict_with_keys(d):
  """Flattens a ``defaultdict`` into ``(DictKey, value)`` pairs.

  Keys are visited in sorted order. The auxiliary data is the pair
  ``(default_factory, sorted_keys)``, which is everything needed to rebuild
  the dict.
  """
  sorted_keys = tuple(sorted(d))
  keyed_children = tuple((DictKey(k), d[k]) for k in sorted_keys)
  return keyed_children, (d.default_factory, sorted_keys)
|
|
|
|
def _unflatten_defaultdict(aux, values):
  """Rebuilds a ``defaultdict`` from ``(default_factory, keys)`` aux data."""
  default_factory = aux[0]
  keys = aux[1]
  return collections.defaultdict(default_factory, safe_zip(keys, values))


register_pytree_with_keys(
    collections.defaultdict,
    _flatten_defaultdict_with_keys,
    _unflatten_defaultdict,
)
|
|
|
|
|
|
@export
def register_static(cls: type[H]) -> type[H]:
  """Registers `cls` as a pytree with no leaves.

  Instances are treated as static by :func:`jax.jit`, :func:`jax.pmap`, etc. This can
  be an alternative to labeling inputs as static using ``jit``'s ``static_argnums``
  and ``static_argnames`` kwargs, ``pmap``'s ``static_broadcasted_argnums``, etc.

  Args:
    cls: type to be registered as static. Must be hashable, as defined in
      https://docs.python.org/3/glossary.html#term-hashable.

  Returns:
    The input class ``cls`` is returned unchanged after being added to JAX's
    pytree registry. This allows ``register_static`` to be used as a decorator.

  Examples:
    >>> import jax
    >>> @jax.tree_util.register_static
    ... class StaticStr(str):
    ...   pass

    This static string can now be used directly in :func:`jax.jit`-compiled
    functions, without marking the variable static using ``static_argnums``:

    >>> @jax.jit
    ... def f(x, y, s):
    ...   return x + y if s == 'add' else x - y
    ...
    >>> f(1, 2, StaticStr('add'))
    Array(3, dtype=int32, weak_type=True)
  """
  def flatten_without_children(instance):
    # No children: the instance itself is stored as the auxiliary data.
    return (), instance

  def rebuild_from_aux(instance, empty_iter_children):
    # The stored auxiliary data already is the object; nothing to reassemble.
    del empty_iter_children
    return instance

  register_pytree_with_keys(cls, flatten_without_children, rebuild_from_aux)
  return cls
|
|
|
|
|
|
@export
def tree_flatten_with_path(
    tree: Any, is_leaf: Callable[..., bool] | None = None,
    is_leaf_takes_path: bool = False,
) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
  """Alias of :func:`jax.tree.flatten_with_path`."""
  # The registry always calls the predicate as ``pred(path, node)``; adapt a
  # one-argument ``is_leaf`` unless the caller says it already takes a path.
  if is_leaf is None or is_leaf_takes_path:
    path_aware_is_leaf: Callable[[Any, Any], bool] | None = is_leaf
  else:
    def path_aware_is_leaf(_path, node):
      return is_leaf(node)
  return default_registry.flatten_with_path(tree, path_aware_is_leaf)
|
|
|
|
|
|
@export
def tree_leaves_with_path(
    tree: Any, is_leaf: Callable[..., bool] | None = None,
    is_leaf_takes_path: bool = False,
) -> list[tuple[KeyPath, Any]]:
  """Alias of :func:`jax.tree.leaves_with_path`."""
  path_leaf_pairs, _treedef = tree_flatten_with_path(
      tree, is_leaf, is_leaf_takes_path)
  return path_leaf_pairs

# The same function, also exposed under this alternative name.
generate_key_paths = tree_leaves_with_path
|
|
|
|
|
|
@export
def tree_map_with_path(
    f: Callable[..., Any],
    tree: Any,
    *rest: Any,
    is_leaf: Callable[..., bool] | None = None,
    is_leaf_takes_path: bool = False,
) -> Any:
  """Alias of :func:`jax.tree.map_with_path`."""
  path_leaf_pairs, treedef = tree_flatten_with_path(
      tree, is_leaf, is_leaf_takes_path)
  # Transpose [(path, leaf), ...] into the two columns [paths, leaves], then
  # append one column of leaves per extra tree, flattened up to `treedef`.
  columns = list(zip(*path_leaf_pairs))
  columns.extend(treedef.flatten_up_to(other) for other in rest)
  return treedef.unflatten(f(*row) for row in zip(*columns))
|
|
|
|
|
|
def _child_keys(pytree: Any) -> KeyPath:
  """Returns the key entries of the immediate children of ``pytree``.

  ``pytree`` must not be a strict leaf. The keys come back in the order
  returned by ``flatten_one_level_with_keys``.
  """
  assert not treedef_is_strict_leaf(tree_structure(pytree))
  keyed_children = flatten_one_level_with_keys(pytree)[0]
  return tuple(key for key, _child in keyed_children)
|
|
|
|
|
|
def _prefix_error(
    key_path: KeyPath,
    prefix_tree: Any,
    full_tree: Any,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Iterable[Callable[[str], ValueError]]:
  """Yields error factories describing why ``prefix_tree`` is not a prefix.

  Each yielded value is a callable that takes the user-facing name of the
  prefix pytree and returns a ``ValueError`` describing one mismatch. At most
  one error is produced per mismatching subtree; matching subtrees are
  recursed into.

  Args:
    key_path: key path from the root to the subtrees being compared; only used
      to build error messages.
    prefix_tree: the candidate prefix (sub)tree.
    full_tree: the full (sub)tree it is compared against.
    is_leaf: optional predicate; subtrees of ``prefix_tree`` for which it
      returns True are treated as leaves, i.e. as valid prefixes of anything.
      It is applied at every level of the recursion.
  """
  # A leaf is a valid prefix of any tree:
  if treedef_is_strict_leaf(tree_structure(prefix_tree, is_leaf=is_leaf)):
    return

  # The subtrees may disagree because their roots are of different types:
  if type(prefix_tree) is not type(full_tree):
    yield lambda name: ValueError(
      "pytree structure error: different types at key path\n"
      f" {name}{keystr(key_path)}\n"
      f"At that key path, the prefix pytree {name} has a subtree of type\n"
      f" {type(prefix_tree)}\n"
      f"but at the same key path the full pytree has a subtree of different type\n"
      f" {type(full_tree)}.")
    return  # don't look for more errors in this subtree

  # Or they may disagree if their roots have different numbers or keys of
  # children. Because both prefix_tree and full_tree have the same type at this
  # point, and because prefix_tree is not a leaf, each can be flattened once:
  prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
  full_tree_children, full_tree_meta = flatten_one_level(full_tree)
  prefix_tree_children = tuple(prefix_tree_children)
  full_tree_children = tuple(full_tree_children)
  prefix_tree_keys = _child_keys(prefix_tree)
  full_tree_keys = _child_keys(full_tree)
  # First we check special case types (list and tuple, though if they were
  # pytrees we could check strings and sets here, basically Sequences) so that
  # we can report length disagreement rather than integer keys:
  if isinstance(prefix_tree, (list, tuple)):
    if len(prefix_tree) != len(full_tree):
      ty = type(prefix_tree)
      yield lambda name: ValueError(
          f"pytree structure error: different lengths of {ty.__name__} at key path\n"
          f" {name}{keystr(key_path)}\n"
          f"At that key path, the prefix pytree {name} has a subtree of type "
          f"{ty.__name__} of length {len(prefix_tree)}, but the full pytree "
          f"has a subtree of the same type but of length {len(full_tree)}.")
      return  # don't look for more errors in this subtree
  else:
    # Next we handle the general case of checking child keys.
    try:
      diff = set(prefix_tree_keys).symmetric_difference(set(full_tree_keys))
    except Exception:  # best effort, e.g. unhashable custom key entries
      diff = None
    if len(prefix_tree_children) != len(full_tree_children):
      # Keys are rendered with `_simple_entrystr`, which handles every key
      # entry type; `k.key` only exists on DictKey/FlattenedIndexKey and
      # raised AttributeError for SequenceKey/GetAttrKey/custom keys.
      yield lambda name: ValueError(
        "pytree structure error: different numbers of pytree children at key path\n"
        f" {name}{keystr(key_path)}\n"
        f"At that key path, the prefix pytree {name} has a subtree of type\n"
        f" {type(prefix_tree)}\n"
        f"with {len(prefix_tree_children)} child keys\n"
        f" {' '.join(_simple_entrystr(k) for k in prefix_tree_keys)}\n"
        f"but at the same key path the full pytree has a subtree of the same "
        f"type but with {len(full_tree_children)} child keys\n"
        f" {' '.join(_simple_entrystr(k) for k in full_tree_keys)}\n"
        + ("" if diff is None else
           f"so the symmetric difference on key sets is\n"
           f" {' '.join(_simple_entrystr(k) for k in diff)}"))
      return  # don't look for more errors in this subtree

  # Or they may disagree if their roots have different pytree metadata:
  if prefix_tree_meta != full_tree_meta:
    prefix_tree_meta_str = str(prefix_tree_meta)
    full_tree_meta_str = str(full_tree_meta)
    metadata_diff = textwrap.indent(
        "\n".join(
            difflib.ndiff(prefix_tree_meta_str.splitlines(),
                          full_tree_meta_str.splitlines())),
        prefix=" ")
    yield lambda name: ValueError(
      "pytree structure error: different pytree metadata at key path\n"
      f" {name}{keystr(key_path)}\n"
      f"At that key path, the prefix pytree {name} has a subtree of type\n"
      f" {type(prefix_tree)}\n"
      f"with metadata\n"
      f" {prefix_tree_meta_str}\n"
      f"but at the same key path the full pytree has a subtree of the same "
      f"type but with metadata\n"
      f" {full_tree_meta_str}\n"
      f"so the diff in the metadata at these pytree nodes is\n"
      f"{metadata_diff}")
    return  # don't look for more errors in this subtree

  # If the root types and numbers of children agree, there must be an error
  # in a subtree, so recurse (propagating `is_leaf` so that subtrees it marks
  # as leaves are accepted as prefixes at every depth):
  assert prefix_tree_keys == full_tree_keys, \
    ("equal pytree nodes gave differing prefix_tree_keys: "
     f"{prefix_tree_keys} and {full_tree_keys}")
  for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children):
    yield from _prefix_error((*key_path, k), t1, t2, is_leaf)
|
|
|
|
# === flat tree ===
|
|
|
|
class FlatTree:
  """A FlatTree stores a treedef and a flat list of values. It's meant to be
  isomorphic to the corresponding pytree but we can map over it more easily.
  Compared to `tree_map`, FlatTree.map has these benefits:
    1. It doesn't touch user flatten/unflatten code (which shouldn't have side
       effects but sometimes does in practice).
    2. It can be faster, because it skips the recursive traversal.
    3. It actually obeys the functor rules. For example,
       `flat_tree.map(lambda x: (f(x), g(x))).unzip2()[0]` will give
       the same result as `flat_tree.map(f)`, whereas in the `tree_map` version
       the tuple-returning function would change the tree structure and `unzip`
       wouldn't be able to recover it.

  Attributes:
    vals: tuple of leaf values, in the treedef's flattening order.
    tree: the ``PyTreeDef`` describing the structure; static parts appear in
      it as ``Static`` wrapper nodes.
    statics: ``False`` (nothing static), ``True`` (this subtree is a ``Static``
      wrapper), or a tuple/dict prefix tree of such booleans.
    registry: the pytree registry used to flatten and rebuild trees.
  """
  # `FlatTree` constructor is private. Use `FlatTree.flatten` instead
  def __init__(self, vals, treedef: PyTreeDef, statics,
               registry=tracing_registry):
    self.registry = registry
    assert isinstance(treedef, pytree.PyTreeDef)
    # `vals` may be any iterable (e.g. a generator from `update`); store it as
    # a tuple so the FlatTree is immutable and hashable.
    self.vals = tuple(vals)
    self.tree = treedef
    self.statics = statics  # tree-prefix tuple-dict-tree of bools

  def __eq__(self, other):
    return (isinstance(other, FlatTree) and self.vals == other.vals
            and self.tree == other.tree and self.statics == other.statics
            and self.registry is other.registry)

  def __hash__(self):
    # Hashes a subset of what `__eq__` compares, so equal objects hash equally.
    return hash((self.vals, self.tree))

  def map(self, f: Callable) -> FlatTree:
    """Applies ``f`` to every leaf value, keeping the structure unchanged."""
    return self.update(f(x) for x in self.vals)

  def map2(self: FlatTree, f: Callable, t2: Sequence[Any]) -> FlatTree:
    """Applies ``f`` leafwise to this tree and a same-length sequence."""
    n = len(self)
    assert len(t2) == n
    return self.update(f(x1, x2) for x1, x2 in zip(self.vals, list(t2)))

  def map3(
      self: FlatTree, f: Callable, t2: Sequence[Any], t3: Sequence[Any]) -> FlatTree:
    """Applies ``f`` leafwise to this tree and two same-length sequences."""
    n = len(self)
    assert len(t2) == n and len(t3) == n
    return self.update(f(x1, x2, x3)
                       for x1, x2, x3 in zip(self.vals, list(t2), list(t3)))

  def unzip2(self: FlatTree) -> tuple[FlatTree, FlatTree]:
    """Splits a tree of pairs into two trees with the same structure."""
    ys = []
    zs = []
    for y, z in self.vals:
      ys.append(y)
      zs.append(z)
    return self.update(ys), self.update(zs)

  # TODO: add other helpers like zip, unzip3 etc. as needed

  @staticmethod
  def pack(tree, registry=tracing_registry):
    """Combines a tuple (possibly nested) of FlatTrees into one FlatTree."""
    # We could generalize this to arbitrary pytrees of FlatTree but tuples/dicts
    # are sufficient for now.
    if isinstance(tree, FlatTree):
      return tree
    elif isinstance(tree, tuple):
      vals = []
      trees = []
      staticss = []
      for child_tree in tree:
        child = FlatTree.pack(child_tree, registry=registry)
        vals.extend(child.vals)
        trees.append(child.tree)
        staticss.append(child.statics)
      return FlatTree(vals, pytree.treedef_tuple(registry, trees),
                      tuple(staticss), registry=registry)
    elif isinstance(tree, dict):
      # only empty case handled for now
      if tree == {}:
        return FlatTree.flatten({}, registry=registry)
      else:
        assert False, "FlatTree.pack: only empty dicts are supported"
    else:
      assert False, type(tree)

  def unpack(self: FlatTree) -> tuple[FlatTree, ...]:
    """Splits a tuple-rooted FlatTree into one FlatTree per child."""
    # TODO: this is O(N) not O(1) (with N as the number of leaves). If it
    # becomes a problem we can fix it with a fancier data structure.
    # TODO(dougalm): assert that we're dealing with a tuple
    trees = treedef_children(self.tree)
    children = []
    offset = 0
    for i, tree in enumerate(trees):
      statics = False if isinstance(self.statics, bool) else self.statics[i]
      new_offset = offset + tree.num_leaves
      children.append(FlatTree(self.vals[offset:new_offset], tree, statics,
                               registry=self.registry))
      offset = new_offset
    return tuple(children)

  def with_aux(self: FlatTree, aux: Any) -> FlatTree:
    """Packs ``(self, Static(aux))`` using this tree's registry."""
    # Use `self.registry` throughout, like `pack`/`unpack` do; previously the
    # default registry was used regardless of the one this tree was built with.
    aux_tree = FlatTree.flatten(Static(aux), registry=self.registry)
    return FlatTree.pack((self, aux_tree), registry=self.registry)

  def unpack_aux(self: FlatTree) -> tuple[FlatTree, Any]:
    """Inverse of `with_aux`: returns the wrapped tree and the aux value."""
    x, aux = self.unpack()
    return x, aux.unflatten().val

  @staticmethod
  def flatten(tree: PyTree, is_leaf=None, registry=tracing_registry) -> FlatTree:
    vals, tree = registry.flatten(tree, is_leaf)
    return FlatTree(vals, tree, False, registry=registry)

  @staticmethod
  def flatten_args(*arg_trees: PyTree, registry=tracing_registry) -> FlatTree:
    return FlatTree.flatten((arg_trees, {}), registry=registry)

  @staticmethod
  def flatten_static_argnums(args, static_argnums, registry=tracing_registry):
    if not static_argnums:
      return FlatTree.flatten(args, registry=registry)
    else:
      assert isinstance(args, tuple)
      num_args = len(args)
      # Resolve negative indices; a set gives O(1) membership tests below.
      static_argnums = {i % num_args if i < 0 else i for i in static_argnums}
      statics = tuple(i in static_argnums for i in range(num_args))
      tree_with_statics = tuple(
          Static(x) if static else x for static, x in zip(statics, args))
      vals, treedef = registry.flatten(tree_with_statics)
      return FlatTree(vals, treedef, statics=statics, registry=registry)

  @staticmethod
  def flatten_static_argnames(kwargs, static_argnames,
                              registry=tracing_registry):
    if not static_argnames:
      return FlatTree.flatten(kwargs, registry=registry)
    else:
      assert isinstance(kwargs, dict)
      statics = {k : k in static_argnames for k, _ in kwargs.items()}
      tree_with_statics = {k : Static(v) if statics[k] else v
                           for k, v in kwargs.items()}
      vals, treedef = registry.flatten(tree_with_statics)
      return FlatTree(vals, treedef, statics=statics, registry=registry)

  @staticmethod
  def flatten_static_argnums_argnames(
      args, kwargs, static_argnums, static_argnames,
      registry=tracing_registry):
    return FlatTree.pack((
        FlatTree.flatten_static_argnums(args, static_argnums,
                                        registry=registry),
        FlatTree.flatten_static_argnames(kwargs, static_argnames,
                                         registry=registry)),
        registry=registry)

  def unflatten(self) -> PyTree:
    """Rebuilds the pytree and unwraps any `Static` parts."""
    pytree = self.tree.unflatten(self.vals)
    return unwrap_statics(pytree, self.statics)

  @property
  def tree_without_statics(self):
    return filter_statics_from_treedef(self.registry, self.tree, self.statics)

  def update(self, new_vals) -> FlatTree:
    """Returns a FlatTree with the same structure and new leaf values."""
    # `new_vals` can be a generator because `FlatTree` forces it to a tuple
    new = FlatTree(new_vals, self.tree, self.statics, registry=self.registry)
    assert len(self.vals) == len(new.vals)
    return new

  @cached_property
  def paths(self) -> FlatTree:
    """A FlatTree whose leaves are the key paths of this tree's leaves.

    Falls back to empty paths if rebuilding the pytree fails.
    """
    # TODO(dougalm): find a way to do this without roundtripping
    try:
      paths, _ = unzip2(self.registry.flatten_with_path(self.unflatten())[0])
      assert len(paths) == len(self.vals)
      return self.update(paths)
    except Exception:  # best effort: user flatten/unflatten code may fail
      return self.update([()] * len(self.vals))  # not our fault

  def __len__(self):
    return self.len

  @cached_property
  def len(self):
    return self.tree.num_leaves

  def __iter__(self):
    return self.vals.__iter__()

  def __getitem__(self, i):
    assert False, "todo"
|
|
|
|
def unwrap_statics(pytree, statics):
  """Replaces ``Static`` wrappers in ``pytree`` with the values they hold.

  ``statics`` is a prefix tree of booleans mirroring ``pytree``: ``False``
  leaves the subtree untouched, ``True`` means the subtree is a ``Static``
  whose ``.val`` is returned, and tuples/dicts recurse elementwise.
  """
  if statics is False:
    return pytree
  if statics is True:
    return pytree.val  # pytree should be a `Static` object
  if isinstance(pytree, tuple):
    return tuple(map(unwrap_statics, pytree, statics))
  if isinstance(pytree, dict):
    unwrapped = {}
    for key, subtree in pytree.items():
      unwrapped[key] = unwrap_statics(subtree, statics[key])
    return unwrapped
  assert False, "unreachable"
|
|
|
|
def filter_statics_from_treedef(registry, treedef, statics):
  """Returns ``treedef`` with every subtree marked static removed.

  ``statics`` has the same shape as used by ``unwrap_statics``: ``False``
  keeps the subtree as is, while tuple/dict nodes drop children whose entry
  is ``True`` and recurse into the rest. Dict nodes also drop the
  corresponding keys from the node data.
  """
  if statics is False:
    return treedef
  if statics is True:
    assert False, "unreachable"
  if isinstance(statics, tuple):
    kept_children = []
    for child, child_statics in zip(treedef.children(), statics):
      if child_statics is True:
        continue
      kept_children.append(
          filter_statics_from_treedef(registry, child, child_statics))
    return treedef.from_node_data_and_children(
        registry, treedef.node_data(), tuple(kept_children))
  if isinstance(statics, dict):
    node_type, all_keys = treedef.node_data()
    kept_pairs = (
        (key, filter_statics_from_treedef(registry, child, statics[key]))
        for child, key in zip(treedef.children(), all_keys)
        if statics[key] is not True)
    kept_keys, kept_children = unzip2(kept_pairs)
    return treedef.from_node_data_and_children(
        registry, (node_type, kept_keys), kept_children)
  assert False, "unreachable"
|
|
|
|
@register_static
@dataclasses.dataclass(frozen=True)
class Static:
  """Wraps a value so that it is carried as pytree auxiliary data.

  Equality additionally requires the wrapped values to have exactly the same
  type, so e.g. ``Static(1) != Static(1.0)``.
  """
  val: Any

  def __eq__(self, other):
    if type(other) is not Static:
      return False
    mine, theirs = self.val, other.val
    return type(mine) is type(theirs) and mine == theirs
|
|
|
|
|
|
def _ensure_inbounds(allow_invalid: bool, num_args: int, argnums: Sequence[int]
                     ) -> tuple[int, ...]:
  """Ensure argnum is within bounds. Also resolves negative argnums.

  Indices ``>= num_args`` are silently skipped when ``allow_invalid`` is True;
  any other out-of-range index raises ``ValueError``. Valid negative indices
  are converted to their non-negative equivalents.
  """
  resolved: list[int] = []
  for argnum in argnums:
    too_large = argnum >= num_args
    if too_large and allow_invalid:
      continue
    if too_large or argnum < -num_args:
      raise ValueError(
          "Positional argument indices, e.g. for `static_argnums`, must have "
          "value greater than or equal to -len(args) and less than len(args), "
          f"but got value {argnum} for len(args) == {num_args}.")
    resolved.append(argnum % num_args)  # Resolve negative
  return tuple(resolved)
|