797 lines
25 KiB
Python
797 lines
25 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import abc
|
|
from collections.abc import Callable, Iterable, Iterator, Sequence
|
|
import functools
|
|
import threading
|
|
from functools import partial
|
|
import itertools as it
|
|
import logging
|
|
import math
|
|
import operator
|
|
from typing import (Any, Generic, ParamSpec, Protocol, SupportsIndex,
|
|
Type, TypeVar, overload, TYPE_CHECKING, cast)
|
|
import weakref
|
|
|
|
import numpy as np
|
|
|
|
from jax._src import config
|
|
from jax._src.lib import pytree as lib_pytree
|
|
from jax._src.lib import weakref_lru_cache as lib_weakref_lru_cache
|
|
from jax._src.lib import utils as jaxlib_utils
|
|
from jax._src.lib import jaxlib_extension_version
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
Seq = Sequence
|
|
|
|
# TODO(jakevdp): fix import cycles and import Array.
|
|
Array = Any
|
|
|
|
T = TypeVar("T")
|
|
T1 = TypeVar("T1")
|
|
T2 = TypeVar("T2")
|
|
T3 = TypeVar("T3")
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
# safe_zip cannot yet be fully annotated, so we use a strategy similar
|
|
# to that used for builtins.zip in python/typeshed. This supports
|
|
# return types matching input types for up to three arguments.
|
|
@overload
|
|
def safe_zip(__arg1: Iterable[T1], /) -> list[tuple[T1]]:
|
|
...
|
|
@overload
|
|
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], /) -> list[tuple[T1, T2]]:
|
|
...
|
|
@overload
|
|
def safe_zip(__arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3], /) -> list[tuple[T1, T2, T3]]:
|
|
...
|
|
@overload
|
|
def safe_zip(__arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], /, *args) -> list[tuple[Any, ...]]:
|
|
...
|
|
|
|
def safe_zip(*args):
|
|
"""
|
|
Like builtin :func:`zip`, but with additional safety checks.
|
|
|
|
The differences from :func:`zip` are:
|
|
|
|
- :func:`safe_zip` checks that at least one argument is provided.
|
|
- :func:`safe_zip` checks that all arguments have the same length.
|
|
- :func:`safe_zip` returns an eagerly-evaluated list instead of a
|
|
lazily-evaluated iterator.
|
|
"""
|
|
if not args:
|
|
raise TypeError("safe_zip requires at least 1 argument.")
|
|
return list(zip(*args, strict=True))
|
|
else:
|
|
safe_zip = jaxlib_utils.safe_zip
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
# safe_map cannot yet be fully annotated, so we use a strategy similar
|
|
# to that used for builtins.map in python/typeshed. This supports
|
|
# checking input types for the callable with up to three arguments.
|
|
@overload
|
|
def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1], /) -> list[T]: ...
|
|
|
|
@overload
|
|
def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2], /) -> list[T]: ...
|
|
|
|
@overload
|
|
def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3], /) -> list[T]: ...
|
|
|
|
@overload
|
|
def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], /, *args) -> list[T]: ...
|
|
|
|
def safe_map(f, *args):
|
|
args = list(map(list, args))
|
|
n = len(args[0])
|
|
for arg in args[1:]:
|
|
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
|
|
return list(map(f, *args))
|
|
|
|
else:
|
|
safe_map = jaxlib_utils.safe_map
|
|
|
|
if TYPE_CHECKING:
|
|
@overload
|
|
def foreach(f: Callable[[T1], Any], __arg1: Iterable[T1], /) -> None: ...
|
|
|
|
@overload
|
|
def foreach(f: Callable[[T1, T2], Any], __arg1: Iterable[T1], __arg2: Iterable[T2], /) -> None: ...
|
|
|
|
@overload
|
|
def foreach(f: Callable[[T1, T2, T3], Any], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3], /) -> None: ...
|
|
|
|
@overload
|
|
def foreach(f: Callable[..., Any], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], /, *args) -> None: ...
|
|
|
|
def foreach(f, *args):
|
|
safe_map(f, *args)
|
|
return None
|
|
|
|
else:
|
|
foreach = jaxlib_utils.foreach
|
|
|
|
|
|
def unzip2(xys: Iterable[tuple[T1, T2]]
|
|
) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
|
|
"""Unzip sequence of length-2 tuples into two tuples."""
|
|
# Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
|
|
# is too permissive about inputs, and does not guarantee a length-2 output.
|
|
xs: list[T1] = []
|
|
ys: list[T2] = []
|
|
for x, y in xys:
|
|
xs.append(x)
|
|
ys.append(y)
|
|
return tuple(xs), tuple(ys)
|
|
|
|
def unzip3(xyzs: Iterable[tuple[T1, T2, T3]]
|
|
) -> tuple[tuple[T1, ...], tuple[T2, ...], tuple[T3, ...]]:
|
|
"""Unzip sequence of length-3 tuples into three tuples."""
|
|
# Note: we deliberately don't use zip(*xyzs) because it is lazily evaluated,
|
|
# is too permissive about inputs, and does not guarantee a length-3 output.
|
|
xs: list[T1] = []
|
|
ys: list[T2] = []
|
|
zs: list[T3] = []
|
|
for x, y, z in xyzs:
|
|
xs.append(x)
|
|
ys.append(y)
|
|
zs.append(z)
|
|
return tuple(xs), tuple(ys), tuple(zs)
|
|
|
|
def subvals(lst: Sequence[T], replace: Iterable[tuple[int, T]]) -> tuple[T, ...]:
|
|
"""Substitute values within a list."""
|
|
lst = list(lst)
|
|
for i, v in replace:
|
|
lst[i] = v
|
|
return tuple(lst)
|
|
|
|
def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]:
|
|
"""Split list into sublists of the specified sizes."""
|
|
args = list(args)
|
|
lists = []
|
|
for n in ns:
|
|
lists.append(args[:n])
|
|
args = args[n:]
|
|
lists.append(args)
|
|
return lists
|
|
|
|
def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]:
|
|
"""Split list into sublists of the specified sizes."""
|
|
args = list(args)
|
|
assert sum(ns) == len(args) and all(n >= 0 for n in ns)
|
|
lists = []
|
|
for n in ns:
|
|
lists.append(args[:n])
|
|
args = args[n:]
|
|
return lists
|
|
|
|
def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]]:
|
|
"""Partition a list into two based on a mask."""
|
|
assert len(bs) == len(l)
|
|
lists: tuple[list[T], list[T]] = ([], [])
|
|
for b, x in zip(bs, l):
|
|
lists[b].append(x)
|
|
return lists
|
|
|
|
def merge_lists(bs: Sequence[bool], l0: Sequence[T1], l1: Sequence[T2]
|
|
) -> list[T1 | T2]:
|
|
"""Merge the elements of two lists based on a mask."""
|
|
assert sum(bs) == len(l1) and len(bs) - sum(bs) == len(l0)
|
|
i0, i1 = iter(l0), iter(l1)
|
|
out: list[T1 | T2] = [next(i1) if b else next(i0) for b in bs]
|
|
sentinel = object()
|
|
assert next(i0, sentinel) is next(i1, sentinel) is sentinel
|
|
return out
|
|
|
|
def subs_list(
|
|
subs: Sequence[int | None], src: Sequence[T], base: Sequence[T],
|
|
) -> list[T]:
|
|
base_ = iter(base)
|
|
out = [src[i] if i is not None else next(base_) for i in subs]
|
|
sentinel = object()
|
|
assert next(base_, sentinel) is sentinel
|
|
return out
|
|
|
|
def subs_list2(
|
|
subs1: Sequence[int | None], subs2: Sequence[int | None],
|
|
src1: Sequence[T], src2: Sequence[T], base: Sequence[T],
|
|
) -> list[T]:
|
|
assert len(subs1) == len(subs2)
|
|
base_ = iter(base)
|
|
out = [src1[f1] if f1 is not None else src2[f2] if f2 is not None else
|
|
next(base_) for f1, f2, in zip(subs1, subs2)]
|
|
sentinel = object()
|
|
assert next(base_, sentinel) is sentinel
|
|
return out
|
|
|
|
def concatenate(xs: Iterable[Sequence[T]]) -> list[T]:
|
|
"""Concatenates/flattens a list of lists."""
|
|
return list(it.chain.from_iterable(xs))
|
|
|
|
flatten = concatenate
|
|
|
|
_unflatten_done = object()
|
|
|
|
def unflatten(xs: Iterable[T], ns: Sequence[int]) -> list[list[T]]:
|
|
"""Splits `xs` into subsequences of lengths `ns`.
|
|
|
|
Unlike `split_list`, the `sum(ns)` must be equal to `len(xs)`."""
|
|
xs_iter = iter(xs)
|
|
unflattened = [[next(xs_iter) for _ in range(n)] for n in ns]
|
|
assert next(xs_iter, _unflatten_done) is _unflatten_done
|
|
return unflattened
|
|
|
|
|
|
def curry(f):
|
|
"""Curries arguments of f, returning a function on any remaining arguments.
|
|
|
|
For example:
|
|
>>> f = lambda x, y, z, w: x * y + z * w
|
|
>>> f(2,3,4,5)
|
|
26
|
|
>>> curry(f)(2)(3, 4, 5)
|
|
26
|
|
>>> curry(f)(2, 3)(4, 5)
|
|
26
|
|
>>> curry(f)(2, 3, 4, 5)()
|
|
26
|
|
"""
|
|
return wraps(f)(partial(partial, f))
|
|
|
|
toposort: Callable[[Iterable[Any]], list[Any]]
|
|
toposort = partial(jaxlib_utils.topological_sort, "parents")
|
|
|
|
|
|
def cache(max_size=4096, trace_context_in_key: bool | Callable = True):
|
|
if trace_context_in_key:
|
|
trace_context = (trace_context_in_key if callable(trace_context_in_key)
|
|
else config.trace_context)
|
|
def wrap(f):
|
|
@functools.lru_cache(max_size)
|
|
def cached(_, *args, **kwargs):
|
|
return f(*args, **kwargs)
|
|
|
|
@functools.wraps(f)
|
|
def wrapper(*args, **kwargs):
|
|
if config.check_tracer_leaks.value:
|
|
return f(*args, **kwargs)
|
|
return cached(trace_context(), *args, **kwargs)
|
|
|
|
wrapper = cast(Any, wrapper) # avoids missing-attribute typing errors
|
|
wrapper.cache_clear = cached.cache_clear
|
|
wrapper.cache_info = cached.cache_info
|
|
register_cache(wrapper, str(f))
|
|
return wrapper
|
|
else:
|
|
def wrap(f):
|
|
wrapper = functools.lru_cache(max_size)(f)
|
|
register_cache(wrapper, str(f))
|
|
return wrapper
|
|
return wrap
|
|
|
|
# Maps caches to the name of the callable they apply to. All caches in
|
|
# this dictionary support `cache_clear()`.
|
|
_caches: weakref.WeakKeyDictionary[Any, str] = weakref.WeakKeyDictionary()
|
|
|
|
def register_cache(cache: Any, for_what: str):
|
|
"""Registers a cache with JAX's cache management.
|
|
|
|
Args:
|
|
cache: an object supporting `cache_clear()`, `cache_info()`, and
|
|
`cache_keys()`, like the result of `functools.lru_cache()`.
|
|
for_what: a string to identify what this cache is used for. This is
|
|
used for debugging.
|
|
"""
|
|
_caches[cache] = for_what
|
|
|
|
def clear_all_caches():
|
|
for cache in list(_caches.keys()):
|
|
cache.cache_clear()
|
|
|
|
memoize = cache(max_size=None)
|
|
|
|
def _ignore(): return None
|
|
|
|
P = ParamSpec("P")
|
|
R = TypeVar("R", covariant=True)
|
|
|
|
class WeakrefCachedFunc(Protocol, Generic[P, R]):
|
|
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
|
|
def cache_clear(self) -> None: ...
|
|
def cache_info(self) -> lib_weakref_lru_cache.WeakrefLRUCache.WeakrefLRUCacheInfo: ...
|
|
def cache_keys(self) -> list[Any]: ...
|
|
def evict_weakref(self, arg0: Any) -> None: ...
|
|
|
|
@overload
|
|
def weakref_lru_cache(
|
|
f: Callable[P, R], /, *, maxsize: int | None = 2048,
|
|
trace_context_in_key: bool = True, explain: Callable | None = None
|
|
) -> WeakrefCachedFunc[P, R]: ...
|
|
|
|
@overload
|
|
def weakref_lru_cache(
|
|
f: None = None, /, *, maxsize: int | None = 2048,
|
|
trace_context_in_key: bool = True, explain: Callable | None = None
|
|
) -> Callable[[Callable[P, R]], WeakrefCachedFunc[P, R]]: ...
|
|
|
|
def weakref_lru_cache(
|
|
f: Callable[P, R] | None = None, *, maxsize: int | None = 2048,
|
|
trace_context_in_key: bool = True, explain: Callable | None = None
|
|
):
|
|
"""
|
|
Least recently used cache decorator with weakref support.
|
|
|
|
The cache will take a weakref to the first argument of the wrapped function
|
|
and strong refs to all other arguments. In all other respects it should
|
|
behave similar to `functools.lru_cache`. The cache is thread local.
|
|
"""
|
|
kwargs = dict(maxsize=maxsize, trace_context_in_key=trace_context_in_key,
|
|
explain=explain)
|
|
if f is None:
|
|
return lambda g: _weakref_lru_cache(g, **kwargs)
|
|
return _weakref_lru_cache(f, **kwargs)
|
|
|
|
def _weakref_lru_cache(f, maxsize, trace_context_in_key, explain):
|
|
cached_f = lib_weakref_lru_cache.weakref_lru_cache(
|
|
config.trace_context if trace_context_in_key else _ignore, f, maxsize,
|
|
explain = lambda: explain if config.explain_cache_misses.value else None)
|
|
register_cache(cached_f, str(f))
|
|
return cached_f
|
|
|
|
|
|
# Interner from strong keys to weak values, intended for us to intern object
|
|
# construction, thereby making subsequent __eq__ and __hash__ calls cheap and
|
|
# based on object identity.
|
|
#
|
|
# Caution: The interner does not know about the *signature* of the cached
|
|
# function. In particular, if the same argument value can be passed as either
|
|
# an arg or a kwarg, then the interner may store multiple entries for the same
|
|
# logical call. If this troubles you canonicalize the arguments first, e.g.
|
|
# via a wrapper function.
|
|
if jaxlib_extension_version >= 433:
|
|
weak_value_interner = lib_weakref_lru_cache.weak_value_interner
|
|
else:
|
|
def weak_value_interner(f):
|
|
cache = weakref.WeakValueDictionary()
|
|
lock = threading.Lock()
|
|
@functools.wraps(f)
|
|
def wrapper(*args, **kwargs):
|
|
key = (args, frozenset(kwargs.items()))
|
|
with lock:
|
|
result = cache.get(key)
|
|
if result is not None:
|
|
return result
|
|
result = f(*args, **kwargs)
|
|
cache[key] = result
|
|
return result
|
|
return wrapper
|
|
|
|
|
|
def immutable(cls):
|
|
"""Decorator to avoid boilerplate for immutable interned classes."""
|
|
def __deepcopy__(self, memo):
|
|
# Deep copy of a singleton interned object is the identity.
|
|
return self
|
|
cls.__deepcopy__ = __deepcopy__
|
|
|
|
# Pickling calls __getstate__ and __setstate__, but we're assuming the
|
|
# caller will implement __getnewargs_ex__.
|
|
def __getstate__(self):
|
|
return None
|
|
def __setstate__(self, state):
|
|
pass
|
|
cls.__getstate__ = __getstate__
|
|
cls.__setstate__ = __setstate__
|
|
|
|
# Discourage mutation after construction.
|
|
def __setattr__(self, name, value):
|
|
raise AttributeError(f"cannot assign to field {name!r}")
|
|
def __delattr__(self, name):
|
|
raise AttributeError(f"cannot delete field {name!r}")
|
|
cls.__setattr__ = __setattr__
|
|
cls.__delattr__ = __delattr__
|
|
|
|
return cls
|
|
|
|
|
|
# The types of arguments for which `multi_weakref_lru_cache` should keep
|
|
# weak references.
|
|
weakref_cache_key_types: set[Type] = set()
|
|
|
|
|
|
_multi_weakref_registry = lib_pytree.PyTreeRegistry(
|
|
enable_none=False,
|
|
enable_tuple=True,
|
|
enable_namedtuple=False,
|
|
enable_list=False,
|
|
enable_dict=True,
|
|
)
|
|
|
|
|
|
def multi_weakref_lru_cache(
|
|
call: Callable,
|
|
*,
|
|
maxsize: int | None = 2048,
|
|
trace_context_in_key: bool = True,
|
|
):
|
|
"""Least recently used cache decorator with weakref support.
|
|
|
|
Similar to `weakref_lru_cache`, except that it keeps weak references
|
|
to all positional and keyword arguments for which
|
|
`is_weakref_cache_key_type()` is true, and strong references to
|
|
other arguments. The cache entry is removed if any of the weakref
|
|
arguments dies.
|
|
"""
|
|
cached_call = lib_weakref_lru_cache.multi_weakref_lru_cache(
|
|
config.trace_context if trace_context_in_key else _ignore,
|
|
call,
|
|
maxsize=maxsize,
|
|
explain=None,
|
|
registry=_multi_weakref_registry,
|
|
weak_types=weakref_cache_key_types,
|
|
)
|
|
register_cache(cached_call, str(call))
|
|
return cached_call
|
|
|
|
|
|
class Unhashable:
|
|
__slots__ = ["val"]
|
|
|
|
def __init__(self, val):
|
|
self.val = val
|
|
|
|
def __eq__(self, other):
|
|
return self.val == other.val
|
|
|
|
class Hashable:
|
|
__slots__ = ["val"]
|
|
|
|
def __init__(self, val):
|
|
self.val = val
|
|
|
|
def __hash__(self):
|
|
return hash(self.val)
|
|
|
|
def __eq__(self, other):
|
|
return self.val == other.val
|
|
|
|
def wrap_name(transform_name: str, name: str) -> str:
|
|
return f"{transform_name}({name})"
|
|
|
|
|
|
def fun_name(fun: Callable, default_name: str = "<unnamed function>") -> str:
|
|
name = getattr(fun, "__name__", None)
|
|
if name is not None:
|
|
return name
|
|
if isinstance(fun, partial):
|
|
return fun_name(fun.func)
|
|
else:
|
|
return default_name
|
|
|
|
|
|
def fun_qual_name(fun: Callable) -> str:
|
|
qual_name = getattr(fun, "__qualname__", None)
|
|
if qual_name is not None:
|
|
return qual_name
|
|
if isinstance(fun, partial):
|
|
return fun_qual_name(fun.func)
|
|
return fun_name(fun)
|
|
|
|
def canonicalize_axis(axis: SupportsIndex, num_dims: int) -> int:
|
|
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
|
|
axis = operator.index(axis)
|
|
if not -num_dims <= axis < num_dims:
|
|
raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
|
|
if axis < 0:
|
|
axis = axis + num_dims
|
|
return axis
|
|
|
|
def canonicalize_axis_tuple(axis: int | Sequence[int] | None, ndim: int, allow_duplicate: bool = False) -> tuple[int, ...]:
|
|
if axis is None:
|
|
return tuple(range(ndim))
|
|
if isinstance(axis, Sequence):
|
|
axis = tuple(canonicalize_axis(i, ndim) for i in axis)
|
|
if not allow_duplicate and len(set(axis)) != len(axis):
|
|
raise ValueError(f"repeated axis: {axis}")
|
|
return axis
|
|
else:
|
|
return (canonicalize_axis(axis, ndim),)
|
|
|
|
def moveaxis(x: Array, src: int | Sequence[int], dst: int | Sequence[int]) -> Array:
|
|
if src == dst:
|
|
return x
|
|
if isinstance(src, int):
|
|
src = (src,)
|
|
if isinstance(dst, int):
|
|
dst = (dst,)
|
|
src = [canonicalize_axis(a, x.ndim) for a in src]
|
|
dst = [canonicalize_axis(a, x.ndim) for a in dst]
|
|
perm = [i for i in range(np.ndim(x)) if i not in src]
|
|
for d, s in sorted(zip(dst, src)):
|
|
perm.insert(d, s)
|
|
return x.transpose(perm)
|
|
|
|
def ceil_of_ratio(x: int, y: int) -> int:
|
|
return -(-x // y)
|
|
|
|
|
|
def wraps(
|
|
wrapped: Callable,
|
|
namestr: str | None = None,
|
|
docstr: str | None = None,
|
|
**kwargs,
|
|
) -> Callable[[T], T]:
|
|
"""
|
|
Like functools.wraps, but with finer-grained control over the name and docstring
|
|
of the resulting function.
|
|
"""
|
|
def wrapper(fun: T) -> T:
|
|
try:
|
|
name = fun_name(wrapped)
|
|
doc = getattr(wrapped, "__doc__", "") or ""
|
|
fun.__dict__.update(getattr(wrapped, "__dict__", {}))
|
|
fun.__annotations__ = getattr(wrapped, "__annotations__", {})
|
|
fun.__name__ = name if namestr is None else namestr.format(fun=name) # pyrefly: ignore[missing-attribute]
|
|
fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
|
|
fun.__doc__ = (doc if docstr is None
|
|
else docstr.format(fun=name, doc=doc, **kwargs))
|
|
fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__) # pyrefly: ignore[missing-attribute]
|
|
fun.__wrapped__ = wrapped # pyrefly: ignore[missing-attribute]
|
|
except Exception:
|
|
pass
|
|
return fun
|
|
return wrapper
|
|
|
|
def tuple_insert(t: tuple[T, ...], idx: int, val: T) -> tuple[T, ...]:
|
|
assert 0 <= idx <= len(t), (idx, len(t))
|
|
return t[:idx] + (val,) + t[idx:]
|
|
|
|
def tuple_delete(t: tuple[T, ...], idx: int) -> tuple[T, ...]:
|
|
assert 0 <= idx < len(t), (idx, len(t))
|
|
return t[:idx] + t[idx + 1:]
|
|
|
|
def tuple_update(t: tuple[T, ...], idx: int, val: T) -> tuple[T, ...]:
|
|
assert 0 <= idx < len(t), (idx, len(t))
|
|
return t[:idx] + (val,) + t[idx+1:]
|
|
|
|
class HashableFunction:
|
|
"""Decouples function equality and hash from its identity.
|
|
|
|
Local lambdas and function defs are reallocated on each function call, making
|
|
the functions created on different calls compare as unequal. This breaks our
|
|
caching logic, which should really only care about comparing the semantics and
|
|
not actual identity.
|
|
|
|
This class makes it possible to compare different functions based on their
|
|
semantics. The parts that are taken into account are: the bytecode of the
|
|
wrapped function (which is cached by the CPython interpreter and is stable
|
|
across the invocations of the surrounding function), and `closure` which
|
|
should contain all values in scope that affect the function semantics. In
|
|
particular `closure` should contain all elements of the function closure, or
|
|
it should be possible to derive the relevant elements of the true function
|
|
closure based solely on the contents of the `closure` argument (e.g. in case
|
|
some closed-over values are not hashable, but are entirely determined by
|
|
hashable locals).
|
|
"""
|
|
|
|
def __init__(self, f, closure):
|
|
self.f = f
|
|
self.closure = closure
|
|
|
|
def __eq__(self, other):
|
|
return (type(other) is HashableFunction and
|
|
self.f.__code__ == other.f.__code__ and
|
|
self.closure == other.closure)
|
|
|
|
def __hash__(self):
|
|
return hash((self.f.__code__, self.closure))
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
return self.f(*args, **kwargs)
|
|
|
|
def __repr__(self):
|
|
return f'<hashable {self.f.__name__} with closure={self.closure}>'
|
|
|
|
|
|
class HashablePartial:
|
|
def __init__(self, f, *args, **kwargs):
|
|
self.f = f
|
|
self.args = args
|
|
self.kwargs = kwargs
|
|
|
|
def __eq__(self, other):
|
|
return (type(other) is HashablePartial and
|
|
self.f.__code__ == other.f.__code__ and
|
|
self.args == other.args and self.kwargs == other.kwargs)
|
|
|
|
def __hash__(self):
|
|
kwargs = tuple(sorted(self.kwargs.items(), key=lambda kv: kv[0]))
|
|
return hash((self.f.__code__, self.args, kwargs))
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
return self.f(*self.args, *args, **self.kwargs, **kwargs)
|
|
|
|
def maybe_named_axis(axis, if_pos, if_named):
|
|
try:
|
|
pos = operator.index(axis)
|
|
except TypeError:
|
|
return if_named(axis)
|
|
else:
|
|
return if_pos(pos)
|
|
|
|
def distributed_debug_log(*pairs):
|
|
"""Format and log `pairs` if config.jax_distributed_debug is enabled.
|
|
|
|
Args:
|
|
pairs: A sequence of label/value pairs to log. The first pair is treated as
|
|
a heading for subsequent pairs.
|
|
"""
|
|
if config.distributed_debug.value:
|
|
lines = ["\nDISTRIBUTED_DEBUG_BEGIN"]
|
|
try:
|
|
lines.append(f"{pairs[0][0]}: {pairs[0][1]}")
|
|
for label, value in pairs[1:]:
|
|
lines.append(f" {label}: {value}")
|
|
except Exception as e:
|
|
lines.append("DISTRIBUTED_DEBUG logging failed!")
|
|
lines.append(f"{e}")
|
|
lines.append("DISTRIBUTED_DEBUG_END")
|
|
logger.warning("\n".join(lines))
|
|
|
|
|
|
def stable_unique(it: Iterable[T]) -> Iterable[T]:
|
|
"""Returns unique elements from `it` in the order of occurrence.
|
|
|
|
The elements must be hashable.
|
|
"""
|
|
return dict.fromkeys(it).keys()
|
|
|
|
|
|
class OrderedSet(Generic[T]):
|
|
elts_set: set[T]
|
|
elts_list: list[T]
|
|
|
|
def __init__(self):
|
|
self.elts_set = set()
|
|
self.elts_list = []
|
|
|
|
def add(self, elt: T) -> None:
|
|
if elt not in self.elts_set:
|
|
self.elts_set.add(elt)
|
|
self.elts_list.append(elt)
|
|
|
|
def update(self, elts: Seq[T]) -> None:
|
|
for e in elts:
|
|
self.add(e)
|
|
|
|
def __iter__(self) -> Iterator[T]:
|
|
return iter(self.elts_list)
|
|
|
|
def __len__(self) -> int:
|
|
return len(self.elts_list)
|
|
|
|
def __contains__(self, elt: T) -> bool:
|
|
return elt in self.elts_set
|
|
|
|
|
|
class HashableWrapper:
|
|
x: Any
|
|
hash: int | None
|
|
def __init__(self, x):
|
|
self.x = x
|
|
try: self.hash = hash(x)
|
|
except: self.hash = None
|
|
def __hash__(self):
|
|
return self.hash if self.hash is not None else id(self.x)
|
|
def __eq__(self, other):
|
|
if not isinstance(other, HashableWrapper):
|
|
return False
|
|
return self.x == other.x if self.hash is not None else self.x is other.x
|
|
|
|
|
|
def _original_func(f: Callable) -> Callable:
|
|
if isinstance(f, property):
|
|
fget = cast(property, f).fget
|
|
assert fget is not None
|
|
return fget
|
|
elif isinstance(f, functools.cached_property):
|
|
return f.func
|
|
return f
|
|
|
|
|
|
def set_module(module: str) -> Callable[[T], T]:
|
|
def wrapper(func: T) -> T:
|
|
if module is not None:
|
|
func.__module__ = module
|
|
return func
|
|
return wrapper
|
|
|
|
|
|
def use_cpp_class(cpp_cls: type[Any]) -> Callable[[type[T]], type[T]]:
|
|
"""A decorator replacing a Python class with its C++ version at runtime."""
|
|
|
|
def wrapper(cls):
|
|
if cpp_cls is None:
|
|
return cls
|
|
|
|
exclude_methods = {'__module__', '__dict__', '__doc__'}
|
|
|
|
for attr_name, attr in cls.__dict__.items():
|
|
if attr_name not in exclude_methods:
|
|
if not hasattr(_original_func(attr), "_use_cpp"):
|
|
setattr(cpp_cls, attr_name, attr)
|
|
|
|
cpp_cls.__doc__ = cls.__doc__
|
|
return cpp_cls
|
|
|
|
return wrapper
|
|
|
|
def use_cpp_method(is_enabled: bool = True) -> Callable[[T], T]:
|
|
"""A decorator excluding methods from the set that are forwarded to C++ class."""
|
|
if not isinstance(is_enabled, bool):
|
|
raise TypeError("``is_enabled`` must be a bool")
|
|
def decorator(f):
|
|
if is_enabled:
|
|
original_func = _original_func(f)
|
|
original_func._use_cpp = True # pyrefly: ignore[missing-attribute]
|
|
return f
|
|
return decorator
|
|
|
|
|
|
class StrictABCMeta(abc.ABCMeta):
|
|
"""A variant of `abc.ABCMeta` which does not allow virtual subclasses.
|
|
|
|
Virtual subclasses support require `abc.ABCMeta` to roundtrip through
|
|
pure Python when doing instance/subclass checking. This if fine for ABCs
|
|
which need virtual subclasses, but is wasteful for the ones which don't.
|
|
"""
|
|
def register(cls, subclass):
|
|
del subclass # Unused.
|
|
raise NotImplementedError(f"{cls} does not support virtual subclasses")
|
|
|
|
__instancecheck__ = type.__instancecheck__
|
|
__subclasscheck__ = type.__subclasscheck__
|
|
|
|
|
|
class StrictABC(metaclass=StrictABCMeta):
|
|
__slots__ = ()
|
|
|
|
|
|
test_event_listener: Callable | None = None
|
|
|
|
def test_event(name: str, *args) -> None:
|
|
if not test_event_listener:
|
|
return
|
|
test_event_listener(name, *args)
|
|
|
|
Mutex = jaxlib_utils.Mutex
|
|
|
|
|
|
def pprint_bytes(num_bytes: int | float) -> str:
|
|
prefixes = ("", "K", "M", "G", "T")
|
|
if num_bytes <= 0:
|
|
return "0.00B"
|
|
exponent = min(math.floor(math.log(num_bytes, 1000)), len(prefixes) - 1)
|
|
scaled_value = num_bytes / (1000**exponent)
|
|
return f"{scaled_value:.2f}{prefixes[exponent]}B"
|
|
|
|
install_failure_signal_handler = jaxlib_utils.install_failure_signal_handler
|