2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,29 @@
"""Main init function for opt_einsum."""
from opt_einsum import blas, helpers, path_random, paths
from opt_einsum._version import __version__
from opt_einsum.contract import contract, contract_expression, contract_path
from opt_einsum.parser import get_symbol
from opt_einsum.path_random import RandomGreedy
from opt_einsum.paths import BranchBound, DynamicProgramming
from opt_einsum.sharing import shared_intermediates
__all__ = [
"__version__",
"blas",
"helpers",
"path_random",
"paths",
"contract",
"contract_expression",
"contract_path",
"get_symbol",
"RandomGreedy",
"BranchBound",
"DynamicProgramming",
"shared_intermediates",
]
paths.register_path_fn("random-greedy", path_random.random_greedy)
paths.register_path_fn("random-greedy-128", path_random.random_greedy_128)
@@ -0,0 +1,16 @@
# file generated by setuptools_scm
# don't change, don't track in version control
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple, Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
VERSION_TUPLE = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
__version__ = version = '3.4.0'
__version_tuple__ = version_tuple = (3, 4, 0)
@@ -0,0 +1,28 @@
"""Compute backends for opt_einsum."""
# Backends
from opt_einsum.backends.cupy import to_cupy
from opt_einsum.backends.dispatch import (
build_expression,
evaluate_constants,
get_func,
has_backend,
has_einsum,
has_tensordot,
)
from opt_einsum.backends.tensorflow import to_tensorflow
from opt_einsum.backends.theano import to_theano
from opt_einsum.backends.torch import to_torch
__all__ = [
"get_func",
"has_einsum",
"has_tensordot",
"build_expression",
"evaluate_constants",
"has_backend",
"to_tensorflow",
"to_theano",
"to_cupy",
"to_torch",
]
@@ -0,0 +1,32 @@
"""Required functions for optimized contractions of numpy arrays using cupy."""
from opt_einsum.helpers import has_array_interface
from opt_einsum.sharing import to_backend_cache_wrap
__all__ = ["to_cupy", "build_expression", "evaluate_constants"]
@to_backend_cache_wrap
def to_cupy(array): # pragma: no cover
import cupy
if has_array_interface(array):
return cupy.asarray(array)
return array
def build_expression(_, expr): # pragma: no cover
"""Build a cupy function based on ``arrays`` and ``expr``."""
def cupy_contract(*arrays):
return expr._contract([to_cupy(x) for x in arrays], backend="cupy").get()
return cupy_contract
def evaluate_constants(const_arrays, expr): # pragma: no cover
"""Convert constant arguments to cupy arrays, and perform any possible
constant contractions.
"""
return expr(*[to_cupy(x) for x in const_arrays], backend="cupy", evaluate_constants=True)
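For context, a hedged sketch of how this backend is typically exercised (assumes `cupy` and a CUDA device are available; `contract_expression` is from the top-level API):

```python
import numpy as np
import opt_einsum as oe

a, b = np.random.rand(4, 5), np.random.rand(5, 6)
expr = oe.contract_expression("ij,jk->ik", (4, 5), (5, 6))

# inputs are moved to the GPU by `to_cupy` and the result is pulled back
# to numpy with `.get()`, as in `cupy_contract` above
out = expr(a, b, backend="cupy")
```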
@@ -0,0 +1,156 @@
"""Handles dispatching array operations to the correct backend library, as well
as converting arrays to backend formats and then potentially storing them as
constants.
"""
import importlib
from typing import Any, Dict, Tuple
from opt_einsum.backends import cupy as _cupy
from opt_einsum.backends import jax as _jax
from opt_einsum.backends import object_arrays
from opt_einsum.backends import tensorflow as _tensorflow
from opt_einsum.backends import theano as _theano
from opt_einsum.backends import torch as _torch
__all__ = [
"get_func",
"has_einsum",
"has_tensordot",
"build_expression",
"evaluate_constants",
"has_backend",
]
# known non top-level imports
_aliases = {
"dask": "dask.array",
"theano": "theano.tensor",
"torch": "opt_einsum.backends.torch",
"jax": "jax.numpy",
"jaxlib": "jax.numpy",
"autograd": "autograd.numpy",
"mars": "mars.tensor",
}
def _import_func(func: str, backend: str, default: Any = None) -> Any:
"""Try and import ``{backend}.{func}``.
If library is installed and func is found, return the func;
otherwise if default is provided, return default;
otherwise raise an error.
"""
try:
lib = importlib.import_module(_aliases.get(backend, backend))
return getattr(lib, func) if default is None else getattr(lib, func, default)
except AttributeError:
error_msg = (
"{} doesn't seem to provide the function {} - see "
"https://optimized-einsum.readthedocs.io/en/latest/backends.html "
"for details on which functions are required for which contractions."
)
raise AttributeError(error_msg.format(backend, func))
# manually cache functions as python2 doesn't support functools.lru_cache
# other libs will be added to this if needed, but pre-populate with numpy
_cached_funcs: Dict[Tuple[str, str], Any] = {
("einsum", "object"): object_arrays.object_einsum,
}
try:
import numpy as np # type: ignore
_cached_funcs[("tensordot", "numpy")] = np.tensordot
_cached_funcs[("transpose", "numpy")] = np.transpose
_cached_funcs[("einsum", "numpy")] = np.einsum
# also pre-populate with the arbitrary object backend
_cached_funcs[("tensordot", "object")] = np.tensordot
_cached_funcs[("transpose", "object")] = np.transpose
except ModuleNotFoundError:
pass
def get_func(func: str, backend: str = "numpy", default: Any = None) -> Any:
"""Return ``{backend}.{func}``, e.g. ``numpy.einsum``,
or a default func if provided. Cache result.
"""
try:
return _cached_funcs[func, backend]
except KeyError:
fn = _import_func(func, backend, default)
_cached_funcs[func, backend] = fn
return fn
# mark libs with einsum, else try to use tensordot/transpose as much as possible
_has_einsum: Dict[str, bool] = {}
def has_einsum(backend: str) -> bool:
"""Check if ``{backend}.einsum`` exists, cache result for performance."""
try:
return _has_einsum[backend]
except KeyError:
try:
get_func("einsum", backend)
_has_einsum[backend] = True
except AttributeError:
_has_einsum[backend] = False
return _has_einsum[backend]
_has_tensordot: Dict[str, bool] = {}
def has_tensordot(backend: str) -> bool:
"""Check if ``{backend}.tensordot`` exists, cache result for performance."""
try:
return _has_tensordot[backend]
except KeyError:
try:
get_func("tensordot", backend)
_has_tensordot[backend] = True
except AttributeError:
_has_tensordot[backend] = False
return _has_tensordot[backend]
# Dispatch to correct expression backend
# these are the backends which support explicit to-and-from numpy conversion
CONVERT_BACKENDS = {
"tensorflow": _tensorflow.build_expression,
"theano": _theano.build_expression,
"cupy": _cupy.build_expression,
"torch": _torch.build_expression,
"jax": _jax.build_expression,
}
EVAL_CONSTS_BACKENDS = {
"tensorflow": _tensorflow.evaluate_constants,
"theano": _theano.evaluate_constants,
"cupy": _cupy.evaluate_constants,
"torch": _torch.evaluate_constants,
"jax": _jax.evaluate_constants,
}
def build_expression(backend, arrays, expr):
"""Build an expression, based on ``expr`` and initial arrays ``arrays``,
that evaluates using backend ``backend``.
"""
return CONVERT_BACKENDS[backend](arrays, expr)
def evaluate_constants(backend, arrays, expr):
"""Convert constant arrays to the correct backend, and perform as much of
the contraction of ``expr`` with these as possible.
"""
return EVAL_CONSTS_BACKENDS[backend](arrays, expr)
def has_backend(backend: str) -> bool:
"""Checks if the backend is known."""
return backend.lower() in CONVERT_BACKENDS
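A short numpy-only sketch of the dispatch behaviour above (runs without any optional backends installed; the missing attribute name is deliberately fictitious):

```python
# fetch (and cache) numpy.einsum through the dispatch layer
einsum = get_func("einsum", "numpy")

# a provided default is returned instead of raising when the attribute is missing
fn = get_func("definitely_not_a_numpy_function", "numpy", default=sum)
assert fn is sum

# only backends with an explicit to/from-numpy conversion step are "known"
assert has_backend("torch") and not has_backend("numpy")
```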
@@ -0,0 +1,45 @@
"""Required functions for optimized contractions of numpy arrays using jax."""
from opt_einsum.sharing import to_backend_cache_wrap
__all__ = ["build_expression", "evaluate_constants"]
_JAX = None
def _get_jax_and_to_jax():
global _JAX
if _JAX is None:
import jax # type: ignore
@to_backend_cache_wrap
@jax.jit
def to_jax(x):
return x
_JAX = jax, to_jax
return _JAX
def build_expression(_, expr): # pragma: no cover
"""Build a jax function based on ``arrays`` and ``expr``."""
jax, _ = _get_jax_and_to_jax()
jax_expr = jax.jit(expr._contract)
def jax_contract(*arrays):
import numpy as np # type: ignore
return np.asarray(jax_expr(arrays))
return jax_contract
def evaluate_constants(const_arrays, expr): # pragma: no cover
"""Convert constant arguments to jax arrays, and perform any possible
constant contractions.
"""
jax, to_jax = _get_jax_and_to_jax()
return expr(*[to_jax(x) for x in const_arrays], backend="jax", evaluate_constants=True)
@@ -0,0 +1,59 @@
"""Functions for performing contractions with array elements which are objects."""
import functools
import operator
from opt_einsum.typing import ArrayType
def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType:
"""A ``einsum`` implementation for ``numpy`` arrays with object dtype.
The loop is performed in python, meaning the objects themselves need
only to implement ``__mul__`` and ``__add__`` for the contraction to be
computed. This may be useful when, for example, computing expressions of
tensors with symbolic elements, but note it will be very slow when compared
to ``numpy.einsum`` and numeric data types!
Parameters
----------
eq : str
The contraction string, should specify output.
arrays : sequence of arrays
These can be any indexable arrays as long as addition and
multiplication is defined on the elements.
Returns:
-------
out : numpy.ndarray
The output tensor, with ``dtype=object``.
"""
import numpy as np # type: ignore
# when called by ``opt_einsum`` we will always be given a full eq
lhs, output = eq.split("->")
inputs = lhs.split(",")
sizes = {}
for term, array in zip(inputs, arrays):
for k, d in zip(term, array.shape):
sizes[k] = d
out_size = tuple(sizes[k] for k in output)
out = np.empty(out_size, dtype=object)
inner = tuple(k for k in sizes if k not in output)
inner_size = tuple(sizes[k] for k in inner)
for coo_o in np.ndindex(*out_size):
coord = dict(zip(output, coo_o))
def gen_inner_sum():
for coo_i in np.ndindex(*inner_size):
coord.update(dict(zip(inner, coo_i)))
locs = (tuple(coord[k] for k in term) for term in inputs)
elements = (array[loc] for array, loc in zip(arrays, locs))
yield functools.reduce(operator.mul, elements)
out[coo_o] = functools.reduce(operator.add, gen_inner_sum())
return out
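A small self-contained sketch of the pure-python loop above, using exact `Fraction` elements (any objects supporting `*` and `+` would do):

```python
from fractions import Fraction

import numpy as np

a = np.empty((2, 2), dtype=object)
b = np.empty((2, 2), dtype=object)
for i in range(2):
    for j in range(2):
        a[i, j] = Fraction(1, i + j + 1)  # exact rationals, no rounding
        b[i, j] = Fraction(i + 1, j + 1)

# exact rational matrix multiplication via the object loop
c = object_einsum("ab,bc->ac", a, b)
assert c[0, 0] == a[0, 0] * b[0, 0] + a[0, 1] * b[1, 0]
```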
@@ -0,0 +1,123 @@
"""Required functions for optimized contractions of numpy arrays using tensorflow."""
from opt_einsum.helpers import has_array_interface
from opt_einsum.sharing import to_backend_cache_wrap
__all__ = ["to_tensorflow", "build_expression", "evaluate_constants"]
_CACHED_TF_DEVICE = None
def _get_tensorflow_and_device():
global _CACHED_TF_DEVICE
if _CACHED_TF_DEVICE is None:
import tensorflow as tf # type: ignore
try:
eager = tf.executing_eagerly()
except AttributeError:
try:
eager = tf.contrib.eager.in_eager_mode()
except AttributeError:
eager = False
device = tf.test.gpu_device_name()
if not device:
device = "cpu"
_CACHED_TF_DEVICE = tf, device, eager
return _CACHED_TF_DEVICE
@to_backend_cache_wrap(constants=True)
def to_tensorflow(array, constant=False):
"""Convert a numpy array to a ``tensorflow.placeholder`` instance."""
tf, device, eager = _get_tensorflow_and_device()
if eager:
if has_array_interface(array):
with tf.device(device):
return tf.convert_to_tensor(array)
return array
if has_array_interface(array):
if constant:
return tf.convert_to_tensor(array)
return tf.placeholder(array.dtype, array.shape)
return array
# Standard graph mode
def build_expression_graph(arrays, expr):
"""Build a tensorflow function based on ``arrays`` and ``expr``."""
tf, _, _ = _get_tensorflow_and_device()
placeholders = [to_tensorflow(array) for array in arrays]
graph = expr._contract(placeholders, backend="tensorflow")
def tensorflow_contract(*arrays):
session = tf.get_default_session()
# only want to feed placeholders - constant tensors already have values
feed_dict = {p: a for p, a in zip(placeholders, arrays) if p.op.type == "Placeholder"}
return session.run(graph, feed_dict=feed_dict)
return tensorflow_contract
def evaluate_constants_graph(const_arrays, expr):
"""Convert constant arguments to tensorflow constants, and perform any
possible constant contractions. Requires evaluating a tensorflow graph.
"""
tf, _, _ = _get_tensorflow_and_device()
# compute the partial graph of new inputs
const_arrays = [to_tensorflow(x, constant=True) for x in const_arrays]
new_ops, new_contraction_list = expr(*const_arrays, backend="tensorflow", evaluate_constants=True)
# evaluate the new inputs and convert back to tensorflow, maintaining None as non-consts
session = tf.get_default_session()
new_consts = iter(session.run([x for x in new_ops if x is not None]))
new_ops = [None if x is None else to_tensorflow(next(new_consts), constant=True) for x in new_ops]
return new_ops, new_contraction_list
# Eager execution mode
def build_expression_eager(_, expr):
"""Build a eager tensorflow function based on ``arrays`` and ``expr``."""
def tensorflow_eager_contract(*arrays):
return expr._contract([to_tensorflow(x) for x in arrays], backend="tensorflow").numpy()
return tensorflow_eager_contract
def evaluate_constants_eager(const_arrays, expr):
"""Convert constant arguments to tensorflow_eager arrays, and perform any
possible constant contractions.
"""
return expr(*[to_tensorflow(x) for x in const_arrays], backend="tensorflow", evaluate_constants=True)
# Dispatch to eager or graph mode
def build_expression(arrays, expr):
_, _, eager = _get_tensorflow_and_device()
fn = build_expression_eager if eager else build_expression_graph
return fn(arrays, expr)
def evaluate_constants(const_arrays, expr):
_, _, eager = _get_tensorflow_and_device()
fn = evaluate_constants_eager if eager else evaluate_constants_graph
return fn(const_arrays, expr)
@@ -0,0 +1,48 @@
"""Required functions for optimized contractions of numpy arrays using theano."""
from opt_einsum.helpers import has_array_interface
from opt_einsum.sharing import to_backend_cache_wrap
__all__ = ["to_theano", "build_expression", "evaluate_constants"]
@to_backend_cache_wrap(constants=True)
def to_theano(array, constant=False):
"""Convert a numpy array to ``theano.tensor.TensorType`` instance."""
import theano # type: ignore
if has_array_interface(array):
if constant:
return theano.tensor.constant(array)
return theano.tensor.TensorType(dtype=array.dtype, broadcastable=[False] * len(array.shape))()
return array
def build_expression(arrays, expr):
"""Build a theano function based on ``arrays`` and ``expr``."""
import theano
in_vars = [to_theano(array) for array in arrays]
out_var = expr._contract(in_vars, backend="theano")
# don't supply constants to graph
graph_ins = [x for x in in_vars if not isinstance(x, theano.tensor.TensorConstant)]
graph = theano.function(graph_ins, out_var)
def theano_contract(*arrays):
return graph(*[x for x in arrays if not isinstance(x, theano.tensor.TensorConstant)])
return theano_contract
def evaluate_constants(const_arrays, expr):
# compute the partial graph of new inputs
const_arrays = [to_theano(x, constant=True) for x in const_arrays]
new_ops, new_contraction_list = expr(*const_arrays, backend="theano", evaluate_constants=True)
# evaluate the new inputs and convert to theano shared tensors
new_ops = [None if x is None else to_theano(x.eval(), constant=True) for x in new_ops]
return new_ops, new_contraction_list
@@ -0,0 +1,130 @@
"""Required functions for optimized contractions of numpy arrays using pytorch."""
from opt_einsum.helpers import has_array_interface
from opt_einsum.parser import convert_to_valid_einsum_chars
from opt_einsum.sharing import to_backend_cache_wrap
__all__ = [
"transpose",
"einsum",
"tensordot",
"to_torch",
"build_expression",
"evaluate_constants",
]
_TORCH_DEVICE = None
_TORCH_HAS_TENSORDOT = None
_torch_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
def _get_torch_and_device():
global _TORCH_DEVICE
global _TORCH_HAS_TENSORDOT
if _TORCH_DEVICE is None:
import torch # type: ignore
device = "cuda" if torch.cuda.is_available() else "cpu"
_TORCH_DEVICE = torch, device
_TORCH_HAS_TENSORDOT = hasattr(torch, "tensordot")
return _TORCH_DEVICE
def transpose(a, axes):
"""Normal torch transpose is only valid for 2D matrices."""
return a.permute(*axes)
def einsum(equation, *operands, **kwargs):
"""Variadic version of torch.einsum to match numpy api."""
# rename symbols to support PyTorch 0.4.1 and earlier,
# which allow only symbols a-z.
equation = convert_to_valid_einsum_chars(equation)
torch, _ = _get_torch_and_device()
return torch.einsum(equation, operands)
def tensordot(x, y, axes=2):
"""Simple translation of tensordot syntax to einsum."""
torch, _ = _get_torch_and_device()
if _TORCH_HAS_TENSORDOT:
return torch.tensordot(x, y, dims=axes)
xnd = x.ndimension()
ynd = y.ndimension()
# convert int argument to (list[int], list[int])
if isinstance(axes, int):
axes = range(xnd - axes, xnd), range(axes)
# convert (int, int) to (list[int], list[int])
if isinstance(axes[0], int):
axes = (axes[0],), axes[1]
if isinstance(axes[1], int):
axes = axes[0], (axes[1],)
# initialize empty indices
x_ix = [None] * xnd
y_ix = [None] * ynd
out_ix = []
# fill in repeated indices
available_ix = iter(_torch_symbols_base)
for ax1, ax2 in zip(*axes):
repeat = next(available_ix)
x_ix[ax1] = repeat
y_ix[ax2] = repeat
# fill in the rest, and maintain output order
for i in range(xnd):
if x_ix[i] is None:
leave = next(available_ix)
x_ix[i] = leave
out_ix.append(leave)
for i in range(ynd):
if y_ix[i] is None:
leave = next(available_ix)
y_ix[i] = leave
out_ix.append(leave)
# form full string and contract!
einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
return einsum(einsum_str, x, y)
@to_backend_cache_wrap
def to_torch(array):
torch, device = _get_torch_and_device()
if has_array_interface(array):
return torch.from_numpy(array).to(device)
return array
def build_expression(_, expr): # pragma: no cover
"""Build a torch function based on ``arrays`` and ``expr``."""
def torch_contract(*arrays):
torch_arrays = [to_torch(x) for x in arrays]
torch_out = expr._contract(torch_arrays, backend="torch")
if torch_out.device.type == "cpu":
return torch_out.numpy()
return torch_out.cpu().numpy()
return torch_contract
def evaluate_constants(const_arrays, expr):
"""Convert constant arguments to torch, and perform any possible constant
contractions.
"""
const_arrays = [to_torch(x) for x in const_arrays]
return expr(*const_arrays, backend="torch", evaluate_constants=True)
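As a sanity check on the fallback path: with two 3-d operands and ``axes=2``, the index-filling above builds the string ``"cab,abd->cd"`` before delegating to ``einsum``. A hedged end-to-end sketch (assumes `torch` is installed):

```python
import numpy as np
import opt_einsum as oe

a, b = np.random.rand(4, 5), np.random.rand(5, 6)
expr = oe.contract_expression("ij,jk->ik", (4, 5), (5, 6))

# arrays are moved to the default torch device by `to_torch`, and
# `torch_contract` above converts the result back to numpy
out = expr(a, b, backend="torch")
```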
@@ -0,0 +1,122 @@
"""Determines if a contraction can use BLAS or not."""
from typing import List, Sequence, Tuple, Union
from opt_einsum.typing import ArrayIndexType
__all__ = ["can_blas"]
def can_blas(
inputs: List[str],
result: str,
idx_removed: ArrayIndexType,
shapes: Union[Sequence[Tuple[int, ...]], None] = None,
) -> Union[str, bool]:
"""Checks if we can use a BLAS call.
Parameters
----------
inputs : list of str
Specifies the subscripts for summation.
result : str
Resulting summation.
idx_removed : set
Indices that are removed in the summation
shapes : sequence of tuple[int], optional
If given, check also that none of the indices are broadcast dimensions.
Returns:
-------
type : str or bool
The type of BLAS call to be used or False if none.
Notes:
-----
We assume several operations are not efficient such as a transposed
DDOT, therefore 'ijk,jki->' should prefer einsum. These return the blas
type appended with "/EINSUM" to differentiate when they can still be done
with tensordot if required, e.g. when a backend has no einsum.
Examples:
--------
>>> can_blas(['ij', 'jk'], 'ik', set('j'))
'GEMM'
>>> can_blas(['ijj', 'jk'], 'ik', set('j'))
False
>>> can_blas(['ab', 'cd'], 'abcd', set())
'OUTER/EINSUM'
>>> # looks like GEMM but actually 'j' is broadcast:
>>> can_blas(['ij', 'jk'], 'ik', set('j'), shapes=[(4, 1), (5, 6)])
False
"""
# Can only handle two operands
if len(inputs) != 2:
return False
input_left, input_right = inputs
for c in set(input_left + input_right):
# can't deal with repeated indices on same input or more than 2 total
nl, nr = input_left.count(c), input_right.count(c)
if (nl > 1) or (nr > 1) or (nl + nr > 2):
return False
# can't do implicit summation or dimension collapse e.g.
# "ab,bc->c" (implicitly sum over 'a')
# "ab,ca->ca" (take diagonal of 'a')
if nl + nr - 1 == int(c in result):
return False
# check for broadcast indices e.g:
# "ij,jk->ik" (but one of the 'j' dimensions is broadcast up)
if shapes is not None:
for c in idx_removed:
if shapes[0][input_left.find(c)] != shapes[1][input_right.find(c)]:
return False
# Prefer einsum if not removing indices
# (N.B. tensordot outer faster for large arrays?)
if len(idx_removed) == 0:
return "OUTER/EINSUM"
# Build a few temporaries
sets = [set(x) for x in inputs]
keep_left = sets[0] - idx_removed
keep_right = sets[1] - idx_removed
rs = len(idx_removed)
# DDOT
if inputs[0] == inputs[1]:
return "DOT"
# DDOT does not make sense if you have to transpose - prefer einsum
elif sets[0] == sets[1]:
return "DOT/EINSUM"
# GEMM no transpose
if input_left[-rs:] == input_right[:rs]:
return "GEMM"
# GEMM transpose both
elif input_left[:rs] == input_right[-rs:]:
return "GEMM"
# GEMM transpose right
elif input_left[-rs:] == input_right[-rs:]:
return "GEMM"
# GEMM transpose left
elif input_left[:rs] == input_right[:rs]:
return "GEMM"
# Einsum is faster than GEMV if we have to copy data
elif (len(keep_left) == 0) or (len(keep_right) == 0):
return "GEMV/EINSUM"
# Conventional tensordot
else:
return "TDOT"
File diff suppressed because it is too large
@@ -0,0 +1,151 @@
"""Contains helper functions for opt_einsum testing scripts."""
from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Tuple, overload
from opt_einsum.typing import ArrayIndexType, ArrayType
__all__ = ["compute_size_by_dict", "find_contraction", "flop_count"]
_valid_chars = "abcdefghijklmopqABC"
_sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]
_default_dim_dict = dict(zip(_valid_chars, _sizes))
@overload
def compute_size_by_dict(indices: Iterable[int], idx_dict: List[int]) -> int: ...
@overload
def compute_size_by_dict(indices: Collection[str], idx_dict: Dict[str, int]) -> int: ...
def compute_size_by_dict(indices: Any, idx_dict: Any) -> int:
"""Computes the product of the elements in indices based on the dictionary
idx_dict.
Parameters
----------
indices : iterable
Indices to base the product on.
idx_dict : dictionary
Dictionary of index sizes.
Returns:
-------
ret : int
The resulting product.
Examples:
--------
>>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
90
"""
ret = 1
for i in indices: # lgtm [py/iteration-string-and-sequence]
ret *= idx_dict[i]
return ret
def find_contraction(
positions: Collection[int],
input_sets: List[ArrayIndexType],
output_set: ArrayIndexType,
) -> Tuple[FrozenSet[str], List[ArrayIndexType], ArrayIndexType, ArrayIndexType]:
"""Finds the contraction for a given set of input and output sets.
Parameters
----------
positions : iterable
Integer positions of terms used in the contraction.
input_sets : list
List of sets that represent the lhs side of the einsum subscript
output_set : set
Set that represents the rhs side of the overall einsum subscript
Returns:
-------
new_result : set
The indices of the resulting contraction
remaining : list
List of sets that have not been contracted, the new set is appended to
the end of this list
idx_removed : set
Indices removed from the entire contraction
idx_contraction : set
The indices used in the current contraction
Examples:
--------
# A simple dot product test case
>>> pos = (0, 1)
>>> isets = [set('ab'), set('bc')]
>>> oset = set('ac')
>>> find_contraction(pos, isets, oset)
({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
# A more complex case with additional terms in the contraction
>>> pos = (0, 2)
>>> isets = [set('abd'), set('ac'), set('bdc')]
>>> oset = set('ac')
>>> find_contraction(pos, isets, oset)
({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
"""
remaining = list(input_sets)
inputs = (remaining.pop(i) for i in sorted(positions, reverse=True))
idx_contract = frozenset.union(*inputs)
idx_remain = output_set.union(*remaining)
new_result = idx_remain & idx_contract
idx_removed = idx_contract - new_result
remaining.append(new_result)
return new_result, remaining, idx_removed, idx_contract
def flop_count(
idx_contraction: Collection[str],
inner: bool,
num_terms: int,
size_dictionary: Dict[str, int],
) -> int:
"""Computes the number of FLOPS in the contraction.
Parameters
----------
idx_contraction : iterable
The indices involved in the contraction
inner : bool
Does this contraction require an inner product?
num_terms : int
The number of terms in a contraction
size_dictionary : dict
The size of each of the indices in idx_contraction
Returns:
-------
flop_count : int
The total number of FLOPS required for the contraction.
Examples:
--------
>>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
30
>>> flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
60
"""
overall_size = compute_size_by_dict(idx_contraction, size_dictionary)
op_factor = max(1, num_terms - 1)
if inner:
op_factor += 1
return overall_size * op_factor
def has_array_interface(array: ArrayType) -> bool:
"""Check whether ``array`` exposes the numpy ``__array_interface__``."""
return hasattr(array, "__array_interface__")
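A brief sketch tying these helpers together the way a path optimizer would (sizes are illustrative; note `find_contraction` expects frozensets):

```python
sizes = {"a": 2, "b": 3, "c": 5}

# contract operands 0 and 1 of "ab,bc->ac"
new_result, remaining, idx_removed, idx_contract = find_contraction(
    (0, 1), [frozenset("ab"), frozenset("bc")], frozenset("ac")
)
assert idx_removed == frozenset("b")

# 'b' is summed over (an inner product), so the op factor is 2:
# 2 * prod of sizes over {'a', 'b', 'c'} = 2 * 30 = 60
assert flop_count(idx_contract, bool(idx_removed), 2, sizes) == 60
assert compute_size_by_dict(new_result, sizes) == 10
```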
@@ -0,0 +1,415 @@
"""A functionally equivalent parser of the numpy.einsum input parser."""
import itertools
from typing import Any, Dict, Iterator, List, Sequence, Tuple
from opt_einsum.typing import ArrayType, TensorShapeType
__all__ = [
"is_valid_einsum_char",
"has_valid_einsum_chars_only",
"get_symbol",
"get_shape",
"gen_unused_symbols",
"convert_to_valid_einsum_chars",
"alpha_canonicalize",
"find_output_str",
"find_output_shape",
"possibly_convert_to_numpy",
"parse_einsum_input",
]
_einsum_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
def is_valid_einsum_char(x: str) -> bool:
"""Check if the character ``x`` is valid for numpy einsum.
**Examples:**
```python
is_valid_einsum_char("a")
#> True
is_valid_einsum_char("Ǵ")
#> False
```
"""
return (x in _einsum_symbols_base) or (x in ",->.")
def has_valid_einsum_chars_only(einsum_str: str) -> bool:
"""Check if ``einsum_str`` contains only valid characters for numpy einsum.
**Examples:**
```python
has_valid_einsum_chars_only("abAZ")
#> True
has_valid_einsum_chars_only("Över")
#> False
```
"""
return all(map(is_valid_einsum_char, einsum_str))
def get_symbol(i: int) -> str:
"""Get the symbol corresponding to int ``i`` - runs through the usual 52
letters before resorting to unicode characters, starting at ``chr(192)`` and skipping surrogates.
**Examples:**
```python
get_symbol(2)
#> 'c'
get_symbol(200)
#> 'Ŕ'
get_symbol(20000)
#> '京'
```
"""
if i < 52:
return _einsum_symbols_base[i]
elif i >= 55296:
# skip the surrogate range chr(55296) - chr(57343)
return chr(i + 2048)
else:
return chr(i + 140)
def gen_unused_symbols(used: str, n: int) -> Iterator[str]:
"""Generate ``n`` symbols that are not already in ``used``.
**Examples:**
```python
list(oe.parser.gen_unused_symbols("abd", 2))
#> ['c', 'e']
```
"""
i = cnt = 0
while cnt < n:
s = get_symbol(i)
i += 1
if s in used:
continue
yield s
cnt += 1
def convert_to_valid_einsum_chars(einsum_str: str) -> str:
"""Convert the str ``einsum_str`` to contain only the alphabetic characters
valid for numpy einsum. If there are too many symbols, let the backend
throw an error.
Examples:
--------
>>> oe.parser.convert_to_valid_einsum_chars("Ĥěļļö")
'cbdda'
"""
symbols = sorted(set(einsum_str) - set(",->"))
replacer = {x: get_symbol(i) for i, x in enumerate(symbols)}
return "".join(replacer.get(x, x) for x in einsum_str)
def alpha_canonicalize(equation: str) -> str:
"""Alpha convert an equation in an order-independent canonical way.
Examples:
--------
>>> oe.parser.alpha_canonicalize("dcba")
'abcd'
>>> oe.parser.alpha_canonicalize("Ĥěļļö")
'abccd'
"""
rename: Dict[str, str] = {}
for name in equation:
if name in ".,->":
continue
if name not in rename:
rename[name] = get_symbol(len(rename))
return "".join(rename.get(x, x) for x in equation)
def find_output_str(subscripts: str) -> str:
"""Find the output string for the inputs ``subscripts`` under canonical einstein summation rules.
That is, repeated indices are summed over by default.
Examples:
--------
>>> oe.parser.find_output_str("ab,bc")
'ac'
>>> oe.parser.find_output_str("a,b")
'ab'
>>> oe.parser.find_output_str("a,a,b,b")
''
"""
tmp_subscripts = subscripts.replace(",", "")
return "".join(s for s in sorted(set(tmp_subscripts)) if tmp_subscripts.count(s) == 1)
def find_output_shape(inputs: List[str], shapes: List[TensorShapeType], output: str) -> TensorShapeType:
"""Find the output shape for given inputs, shapes and output string, taking
into account broadcasting.
Examples:
--------
>>> oe.parser.find_output_shape(["ab", "bc"], [(2, 3), (3, 4)], "ac")
(2, 4)
# Broadcasting is accounted for
>>> oe.parser.find_output_shape(["a", "a"], [(4, ), (1, )], "a")
(4,)
"""
return tuple(max(shape[loc] for shape, loc in zip(shapes, [x.find(c) for x in inputs]) if loc >= 0) for c in output)
_BaseTypes = (bool, int, float, complex, str, bytes)
def get_shape(x: Any) -> TensorShapeType:
"""Get the shape of the array-like object `x`. If `x` is not array-like, raise an error.
Array-like objects are those that have a `shape` attribute, are sequences of BaseTypes, or are BaseTypes.
BaseTypes are defined as `bool`, `int`, `float`, `complex`, `str`, and `bytes`.
"""
if hasattr(x, "shape"):
return x.shape
elif isinstance(x, _BaseTypes):
return ()
elif isinstance(x, Sequence):
shape = []
while isinstance(x, Sequence) and not isinstance(x, _BaseTypes):
shape.append(len(x))
x = x[0]
return tuple(shape)
else:
raise ValueError(f"Cannot determine the shape of {x}, can only determine the shape of array-like objects.")
def possibly_convert_to_numpy(x: Any) -> Any:
"""Convert things without a 'shape' to ndarrays, but leave everything else.
Examples:
--------
>>> oe.parser.possibly_convert_to_numpy(5)
array(5)
>>> oe.parser.possibly_convert_to_numpy([5, 3])
array([5, 3])
>>> oe.parser.possibly_convert_to_numpy(np.array([5, 3]))
array([5, 3])
# Any class with a shape is passed through
>>> class Shape:
... def __init__(self, shape):
... self.shape = shape
...
>>> myshape = Shape((5, 5))
>>> oe.parser.possibly_convert_to_numpy(myshape)
<__main__.Shape object at 0x10f850710>
"""
if not hasattr(x, "shape"):
try:
import numpy as np # type: ignore
except ModuleNotFoundError:
raise ModuleNotFoundError(
"numpy is required to convert non-array objects to arrays. This function will be deprecated in the future."
)
return np.asanyarray(x)
else:
return x
def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str:
"""Convert user custom subscripts list to subscript string according to `symbol_map`.
Examples:
--------
>>> oe.parser.convert_subscripts(['abc', 'def'], {'abc':'a', 'def':'b'})
'ab'
>>> oe.parser.convert_subscripts([Ellipsis, object], {object:'a'})
'...a'
"""
new_sub = ""
for s in old_sub:
if s is Ellipsis:
new_sub += "..."
else:
# no need to try/except here because symbol_map has already been checked
new_sub += symbol_map[s]
return new_sub
def convert_interleaved_input(operands: Sequence[Any]) -> Tuple[str, Tuple[Any, ...]]:
"""Convert 'interleaved' input to standard einsum input."""
tmp_operands = list(operands)
operand_list = []
subscript_list = []
for _ in range(len(operands) // 2):
operand_list.append(tmp_operands.pop(0))
subscript_list.append(tmp_operands.pop(0))
output_list = tmp_operands[-1] if len(tmp_operands) else None
# build a map from user symbols to single-character symbols based on `get_symbol`
# The map retains the intrinsic order of user symbols
try:
# collect all user symbols
symbol_set = set(itertools.chain.from_iterable(subscript_list))
# remove Ellipsis because it can not be compared with other objects
symbol_set.discard(Ellipsis)
# build the map based on sorted user symbols, retaining the order we lost in the `set`
symbol_map = {symbol: get_symbol(idx) for idx, symbol in enumerate(sorted(symbol_set))}
except TypeError: # unhashable or uncomparable object
raise TypeError(
"For this input type lists must contain either Ellipsis "
"or hashable and comparable object (e.g. int, str)."
)
subscripts = ",".join(convert_subscripts(sub, symbol_map) for sub in subscript_list)
if output_list is not None:
subscripts += "->"
subscripts += convert_subscripts(output_list, symbol_map)
return subscripts, tuple(operand_list)
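A quick sketch of the interleaved-to-subscript conversion above (the axis labels just need to be hashable and comparable; `a` and `b` stand in for arrays):

```python
a, b = object(), object()

# einsum("ab,bc->ac") in interleaved form: operand, labels, operand, labels, output
subscripts, ops = convert_interleaved_input((a, [0, 1], b, [1, 2], [0, 2]))
assert subscripts == "ab,bc->ac"
assert ops == (a, b)
```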
def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]:
"""A reproduction of einsum c side einsum parsing in python.
Parameters:
operands: Intakes the same inputs as `contract_path`, but NOT the keyword args. The only
supported keyword argument is:
shapes: Whether ``parse_einsum_input`` should assume arrays (the default) or
array shapes have been supplied.
Returns:
input_strings: Parsed input strings
output_string: Parsed output string
operands: The operands to use in the numpy contraction
Examples:
The operand list is simplified to reduce printing:
```python
>>> a = np.random.rand(4, 4)
>>> b = np.random.rand(4, 4, 4)
>>> parse_einsum_input(('...a,...a->...', a, b))
('za,xza', 'xz', [a, b])
>>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
('za,xza', 'xz', [a, b])
```
"""
if len(operands) == 0:
raise ValueError("No input operands")
if isinstance(operands[0], str):
subscripts = operands[0].replace(" ", "")
if shapes:
if any(hasattr(o, "shape") for o in operands[1:]):
raise ValueError(
"shapes is set to True but given at least one operand looks like an array"
" (at least one operand has a shape attribute). "
)
operands = operands[1:]
else:
subscripts, operands = convert_interleaved_input(operands)
if shapes:
operand_shapes = operands
else:
operand_shapes = [get_shape(o) for o in operands]
# Check for proper "->"
if ("-" in subscripts) or (">" in subscripts):
invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
if invalid or (subscripts.count("->") != 1):
raise ValueError("Subscripts can only contain one '->'.")
# Parse ellipses
if "." in subscripts:
used = subscripts.replace(".", "").replace(",", "").replace("->", "")
ellipse_inds = "".join(gen_unused_symbols(used, max(len(x) for x in operand_shapes)))
longest = 0
# Do we have an output to account for?
if "->" in subscripts:
input_tmp, output_sub = subscripts.split("->")
split_subscripts = input_tmp.split(",")
out_sub = True
else:
split_subscripts = subscripts.split(",")
out_sub = False
for num, sub in enumerate(split_subscripts):
if "." in sub:
if (sub.count(".") != 3) or (sub.count("...") != 1):
raise ValueError("Invalid Ellipses.")
# Take into account numerical values
if operand_shapes[num] == ():
ellipse_count = 0
else:
ellipse_count = max(len(operand_shapes[num]), 1) - (len(sub) - 3)
if ellipse_count > longest:
longest = ellipse_count
if ellipse_count < 0:
raise ValueError("Ellipses lengths do not match.")
elif ellipse_count == 0:
split_subscripts[num] = sub.replace("...", "")
else:
split_subscripts[num] = sub.replace("...", ellipse_inds[-ellipse_count:])
subscripts = ",".join(split_subscripts)
# Figure out output ellipses
if longest == 0:
out_ellipse = ""
else:
out_ellipse = ellipse_inds[-longest:]
if out_sub:
subscripts += "->" + output_sub.replace("...", out_ellipse)
else:
# Special care for outputless ellipses
output_subscript = find_output_str(subscripts)
normal_inds = "".join(sorted(set(output_subscript) - set(out_ellipse)))
subscripts += "->" + out_ellipse + normal_inds
# Build output string if does not exist
if "->" in subscripts:
input_subscripts, output_subscript = subscripts.split("->")
else:
input_subscripts, output_subscript = subscripts, find_output_str(subscripts)
# Make sure output subscripts are unique and in the input
for char in output_subscript:
if output_subscript.count(char) != 1:
raise ValueError(f"Output character '{char}' appeared more than once in the output.")
if char not in input_subscripts:
raise ValueError(f"Output character '{char}' did not appear in the input")
# Make sure number operands is equivalent to the number of terms
if len(input_subscripts.split(",")) != len(operands):
raise ValueError(
f"Number of einsum subscripts, {len(input_subscripts.split(','))}, must be equal to the "
f"number of operands, {len(operands)}."
)
return input_subscripts, output_subscript, operands
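A hedged sketch of the parser's two entry modes (shape-only mode avoids the need for real arrays):

```python
# shapes=True: operands are shape tuples rather than arrays
inp, out, ops = parse_einsum_input(("ab,bc", (2, 3), (3, 4)), shapes=True)
assert (inp, out) == ("ab,bc", "ac")  # implicit output inferred

# an ellipsis is expanded into unused symbols, here one leading axis
inp, out, ops = parse_einsum_input(("...a,a", (5, 3), (3,)), shapes=True)
assert (inp, out) == ("ca,a", "c")
```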
@@ -0,0 +1,398 @@
"""Support for random optimizers, including the random-greedy path."""
import functools
import heapq
import math
import time
from collections import deque
from decimal import Decimal
from random import choices as random_choices
from random import seed as random_seed
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
from opt_einsum import helpers, paths
from opt_einsum.typing import ArrayIndexType, ArrayType, PathType
__all__ = ["RandomGreedy", "random_greedy", "random_greedy_128"]
class RandomOptimizer(paths.PathOptimizer):
"""Base class for running any random path finder that benefits
from repeated calling, possibly in a parallel fashion. Custom random
optimizers should subclass this, and the `setup` method should be
implemented with the following signature:
```python
def setup(self, inputs, output, size_dict):
# custom preparation here ...
return trial_fn, trial_args
```
Where `trial_fn` itself should have the signature:
```python
def trial_fn(r, *trial_args):
# custom computation of path here
return ssa_path, cost, size
```
Where `r` is the run number and could for example be used to seed a
random number generator. See `RandomGreedy` for an example.
Parameters:
max_repeats: The maximum number of repeat trials to have.
max_time: The maximum amount of time to run the algorithm for.
minimize: Whether to favour paths that minimize the total estimated flop-count or
the size of the largest intermediate created.
parallel: Whether to parallelize the random trials, by default `False`. If
`True`, use a `concurrent.futures.ProcessPoolExecutor` with the same
number of processes as cores. If an integer is specified, use that many
processes instead. Finally, you can supply a custom executor-pool which
should have an API matching that of the python 3 standard library
module `concurrent.futures`. Namely, a `submit` method that returns
`Future` objects, themselves with `result` and `cancel` methods.
pre_dispatch: If running in parallel, how many jobs to pre-dispatch so as to avoid
submitting all jobs at once. Should also be more than twice the number
of workers to avoid under-subscription. Default: 128.
Attributes:
path: The best path found so far.
costs: The list of each trial's costs found so far.
sizes: The list of each trial's largest intermediate size so far.
"""
def __init__(
self,
max_repeats: int = 32,
max_time: Optional[float] = None,
minimize: str = "flops",
parallel: Union[bool, Decimal, int] = False,
pre_dispatch: int = 128,
):
if minimize not in ("flops", "size"):
raise ValueError("`minimize` should be one of {'flops', 'size'}.")
self.max_repeats = max_repeats
self.max_time = max_time
self.minimize = minimize
self.better = paths.get_better_fn(minimize)
self._parallel: Union[bool, Decimal, int] = False
self.parallel = parallel
self.pre_dispatch = pre_dispatch
self.costs: List[int] = []
self.sizes: List[int] = []
self.best: Dict[str, Any] = {"flops": float("inf"), "size": float("inf")}
self._repeats_start = 0
self._executor: Any
self._futures: Any
@property
def path(self) -> PathType:
"""The best path found so far."""
return paths.ssa_to_linear(self.best["ssa_path"])
@property
def parallel(self) -> Union[bool, Decimal, int]:
return self._parallel
@parallel.setter
def parallel(self, parallel: Union[bool, Decimal, int]) -> None:
# shutdown any previous executor if we are managing it
if getattr(self, "_managing_executor", False):
self._executor.shutdown()
self._parallel = parallel
self._managing_executor = False
if parallel is False:
self._executor = None
return
if parallel is True:
from concurrent.futures import ProcessPoolExecutor
self._executor = ProcessPoolExecutor()
self._managing_executor = True
return
if isinstance(parallel, (int, Decimal)):
from concurrent.futures import ProcessPoolExecutor
self._executor = ProcessPoolExecutor(int(parallel))
self._managing_executor = True
return
# assume a pool-executor has been supplied
self._executor = parallel
def _gen_results_parallel(self, repeats: Iterable[int], trial_fn: Any, args: Any) -> Generator[Any, None, None]:
"""Lazily generate results from an executor without submitting all jobs at once."""
self._futures = deque()
# the idea here is to submit at least ``pre_dispatch`` jobs *before* we
# yield any results, then do both in tandem, before draining the queue
for r in repeats:
if len(self._futures) < self.pre_dispatch:
self._futures.append(self._executor.submit(trial_fn, r, *args))
continue
yield self._futures.popleft().result()
while self._futures:
yield self._futures.popleft().result()
def _cancel_futures(self) -> None:
if self._executor is not None:
for f in self._futures:
f.cancel()
def setup(
self,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
) -> Tuple[Any, Any]:
raise NotImplementedError
def __call__(
self,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
memory_limit: Optional[int] = None,
) -> PathType:
self._check_args_against_first_call(inputs, output, size_dict)
# start a timer?
if self.max_time is not None:
t0 = time.time()
trial_fn, trial_args = self.setup(inputs, output, size_dict)
r_start = self._repeats_start + len(self.costs)
r_stop = r_start + self.max_repeats
repeats = range(r_start, r_stop)
# create the trials lazily
if self._executor is not None:
trials = self._gen_results_parallel(repeats, trial_fn, trial_args)
else:
trials = (trial_fn(r, *trial_args) for r in repeats)
# assess the trials
for ssa_path, cost, size in trials:
# keep track of all costs and sizes
self.costs.append(cost)
self.sizes.append(size)
# check if we have found a new best
found_new_best = self.better(cost, size, self.best["flops"], self.best["size"])
if found_new_best:
self.best["flops"] = cost
self.best["size"] = size
self.best["ssa_path"] = ssa_path
# check if we have run out of time
if (self.max_time is not None) and (time.time() > t0 + self.max_time):
break
self._cancel_futures()
return self.path
def __del__(self):
# if we created the parallel pool-executor, shut it down
if getattr(self, "_managing_executor", False):
self._executor.shutdown()
def thermal_chooser(queue, remaining, nbranch=8, temperature=1, rel_temperature=True):
"""A contraction 'chooser' that weights possible contractions using a
Boltzmann distribution. Explicitly, given costs `c_i` (with `c_0` the
smallest), the relative weights, `w_i`, are computed as:
$$w_i = exp( -(c_i - c_0) / temperature)$$
Additionally, if `rel_temperature` is set, scale `temperature` by
`abs(c_0)` to account for likely fluctuating cost magnitudes during the
course of a contraction.
Parameters:
queue: The heapified list of candidate contractions.
remaining: Mapping of remaining inputs' indices to the ssa id.
temperature: When choosing a possible contraction, its relative probability will be
proportional to `exp(-cost / temperature)`. Thus the larger
`temperature` is, the further random paths will stray from the normal
'greedy' path. Conversely, if set to zero, only paths with exactly the
same cost as the best at each step will be explored.
rel_temperature: Whether to normalize the `temperature` at each step to the scale of
the best cost. This is generally beneficial as the magnitude of costs
can vary significantly throughout a contraction.
nbranch: How many potential paths to calculate probability for and choose from at each step.
Returns:
cost
k1
k2
k12
"""
n = 0
choices = []
while queue and n < nbranch:
cost, k1, k2, k12 = heapq.heappop(queue)
if k1 not in remaining or k2 not in remaining:
continue # candidate is obsolete
choices.append((cost, k1, k2, k12))
n += 1
if n == 0:
return None
if n == 1:
return choices[0]
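# NB: each candidate's cost is itself a tie-breaking tuple whose first
# element is the numeric cost, hence the [0][0] below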
costs = [choice[0][0] for choice in choices]
cmin = costs[0]
# adjust by the overall scale to account for fluctuating absolute costs
if rel_temperature:
temperature *= max(1, abs(cmin))
# compute relative probability for each potential contraction
if temperature == 0.0:
energies = [1 if c == cmin else 0 for c in costs]
else:
# shift by cmin for numerical reasons
energies = [math.exp(-(c - cmin) / temperature) for c in costs]
# randomly choose a contraction based on energies
(chosen,) = random_choices(range(n), weights=energies)
cost, k1, k2, k12 = choices.pop(chosen)
# put the other choice back in the heap
for other in choices:
heapq.heappush(queue, other)
return cost, k1, k2, k12
def ssa_path_compute_cost(
ssa_path: PathType,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
) -> Tuple[int, int]:
"""Compute the flops and max size of an ssa path."""
inputs = list(map(frozenset, inputs))
output = frozenset(output)
remaining = set(range(len(inputs)))
total_cost = 0
max_size = 0
for i, j in ssa_path:
k12, flops12 = paths.calc_k12_flops(inputs, output, remaining, i, j, size_dict) # type: ignore
remaining.discard(i)
remaining.discard(j)
remaining.add(len(inputs))
inputs.append(k12)
total_cost += flops12
max_size = max(max_size, helpers.compute_size_by_dict(k12, size_dict))
return total_cost, max_size
def _trial_greedy_ssa_path_and_cost(
r: int,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
choose_fn: Any,
cost_fn: Any,
) -> Tuple[PathType, int, int]:
"""A single, repeatable, greedy trial run. **Returns:** ``ssa_path`` and cost."""
if r == 0:
# always start with the standard greedy approach
choose_fn = None
random_seed(r)
ssa_path = paths.ssa_greedy_optimize(inputs, output, size_dict, choose_fn, cost_fn)
cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
return ssa_path, cost, size
class RandomGreedy(RandomOptimizer):
def __init__(
self,
cost_fn: str = "memory-removed-jitter",
temperature: float = 1.0,
rel_temperature: bool = True,
nbranch: int = 8,
**kwargs: Any,
):
"""Parameters:
cost_fn: A function that returns a heuristic 'cost' of a potential contraction
with which to sort candidates. Should have signature
`cost_fn(size12, size1, size2, k12, k1, k2)`.
temperature: When choosing a possible contraction, its relative probability will be
proportional to `exp(-cost / temperature)`. Thus the larger
`temperature` is, the further random paths will stray from the normal
'greedy' path. Conversely, if set to zero, only paths with exactly the
same cost as the best at each step will be explored.
rel_temperature: Whether to normalize the ``temperature`` at each step to the scale of
the best cost. This is generally beneficial as the magnitude of costs
can vary significantly throughout a contraction. If False, the
algorithm will end up branching when the absolute cost is low, but
stick to the 'greedy' path when the cost is high - this can also be
beneficial.
nbranch: How many potential paths to calculate probability for and choose from at each step.
kwargs: Supplied to RandomOptimizer.
"""
self.cost_fn = cost_fn
self.temperature = temperature
self.rel_temperature = rel_temperature
self.nbranch = nbranch
super().__init__(**kwargs)
@property
def choose_fn(self) -> Any:
"""The function that chooses which contraction to take - make this a
property so that ``temperature`` and ``nbranch`` etc. can be updated
between runs.
"""
if self.nbranch == 1:
return None
return functools.partial(
thermal_chooser,
temperature=self.temperature,
nbranch=self.nbranch,
rel_temperature=self.rel_temperature,
)
def setup(
self,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
) -> Tuple[Any, Any]:
fn = _trial_greedy_ssa_path_and_cost
args = (inputs, output, size_dict, self.choose_fn, self.cost_fn)
return fn, args
def random_greedy(
inputs: List[ArrayIndexType],
output: ArrayIndexType,
idx_dict: Dict[str, int],
memory_limit: Optional[int] = None,
**optimizer_kwargs: Any,
) -> PathType:
"""A simple wrapper around the `RandomGreedy` optimizer."""
optimizer = RandomGreedy(**optimizer_kwargs)
return optimizer(inputs, output, idx_dict, memory_limit)
random_greedy_128 = functools.partial(random_greedy, max_repeats=128)
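A minimal sketch of driving this optimizer directly through `contract_path` (assumes `numpy`; the equation and shapes are arbitrary):

```python
import numpy as np
import opt_einsum as oe

eq = "ab,bc,cd,de,ef,fg->ag"
views = [np.random.rand(8, 8) for _ in range(6)]

opt = oe.RandomGreedy(max_repeats=32, max_time=1.0, minimize="flops")
path, info = oe.contract_path(eq, *views, optimize=opt)

# the optimizer object keeps per-trial statistics and the best result
print(opt.best["flops"], len(opt.costs))
```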
File diff suppressed because it is too large
@@ -0,0 +1,216 @@
"""A module for sharing intermediates between contractions.
Copyright (c) 2018 Uber Technologies
"""
import contextlib
import functools
import numbers
import threading
from collections import Counter, defaultdict
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import Counter as CounterType
from opt_einsum.parser import alpha_canonicalize, parse_einsum_input
from opt_einsum.typing import ArrayType
CacheKeyType = Union[Tuple[str, str, int, Tuple[int, ...]], Tuple[str, int]]
CacheType = Dict[CacheKeyType, ArrayType]
__all__ = [
"currently_sharing",
"get_sharing_cache",
"shared_intermediates",
"count_cached_ops",
"transpose_cache_wrap",
"einsum_cache_wrap",
"to_backend_cache_wrap",
]
_SHARING_STACK: Dict[int, List[CacheType]] = defaultdict(list)
def currently_sharing() -> bool:
"""Check if we are currently sharing a cache -- thread specific."""
return threading.get_ident() in _SHARING_STACK
def get_sharing_cache() -> CacheType:
"""Return the most recent sharing cache -- thread specific."""
return _SHARING_STACK[threading.get_ident()][-1]
def _add_sharing_cache(cache: CacheType) -> None:
_SHARING_STACK[threading.get_ident()].append(cache)
def _remove_sharing_cache() -> None:
tid = threading.get_ident()
_SHARING_STACK[tid].pop()
if not _SHARING_STACK[tid]:
del _SHARING_STACK[tid]
@contextlib.contextmanager
def shared_intermediates(
cache: Optional[CacheType] = None,
) -> Generator[CacheType, None, None]:
"""Context in which contract intermediate results are shared.
Note that intermediate computations will not be garbage collected until
1. this context exits, and
2. the yielded cache is garbage collected (if it was captured).
**Parameters:**
- **cache** - *(dict)* If specified, a user-stored dict in which intermediate results will be stored. This can be used to interleave sharing contexts.
**Returns:**
- **cache** - *(dict)* A dictionary in which sharing results are stored. If ignored,
sharing results will be garbage collected when this context is
exited. This dict can be passed to another context to resume
sharing.
"""
if cache is None:
cache = {}
_add_sharing_cache(cache)
try:
yield cache
finally:
_remove_sharing_cache()
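A short usage sketch of the context manager above (assumes `numpy`): within the context, repeated intermediates are computed once and then served from the cache.

```python
import numpy as np
from opt_einsum import contract

a, b, c, d = (np.random.rand(4, 4) for _ in range(4))

with shared_intermediates() as cache:
    # these two contractions can share the a@b intermediate
    contract("ab,bc,cd->ad", a, b, c)
    contract("ab,bc,ce->ae", a, b, d)
    print(count_cached_ops(cache))  # defined just below; counts op types in the cache
```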
def count_cached_ops(cache: CacheType) -> CounterType[str]:
"""Returns a counter of the types of each op in the cache.
This is useful for profiling to increase sharing.
"""
return Counter(key[0] for key in cache.keys())
def _save_tensors(*tensors: ArrayType) -> None:
"""Save tensors in the cache to prevent their ids from being recycled.
This is needed to prevent false cache lookups.
"""
cache = get_sharing_cache()
for tensor in tensors:
cache["tensor", id(tensor)] = tensor
def _memoize(key: CacheKeyType, fn: Any, *args: Any, **kwargs: Any) -> ArrayType:
"""Memoize ``fn(*args, **kwargs)`` using the given ``key``.
Results will be stored in the innermost ``cache`` yielded by
:func:`shared_intermediates`.
"""
cache = get_sharing_cache()
if key in cache:
return cache[key]
result = fn(*args, **kwargs)
cache[key] = result
return result
def transpose_cache_wrap(transpose: Any) -> Any:
"""Decorates a ``transpose()`` implementation to be memoized inside a
:func:`shared_intermediates` context.
"""
@functools.wraps(transpose)
def cached_transpose(a, axes, backend="numpy"):
if not currently_sharing():
return transpose(a, axes, backend=backend)
# hash by axes
_save_tensors(a)
axes = tuple(axes)
key = "transpose", backend, id(a), axes
return _memoize(key, transpose, a, axes, backend=backend)
return cached_transpose
def tensordot_cache_wrap(tensordot: Any) -> Any:
"""Decorates a ``tensordot()`` implementation to be memoized inside a
:func:`shared_intermediates` context.
"""
@functools.wraps(tensordot)
def cached_tensordot(x, y, axes=2, backend="numpy"):
if not currently_sharing():
return tensordot(x, y, axes, backend=backend)
# hash based on the (axes_x,axes_y) form of axes
_save_tensors(x, y)
if isinstance(axes, numbers.Number):
axes = (
list(range(len(x.shape)))[len(x.shape) - axes :],
list(range(len(y.shape)))[:axes],
)
axes = tuple(axes[0]), tuple(axes[1])
key = "tensordot", backend, id(x), id(y), axes
return _memoize(key, tensordot, x, y, axes, backend=backend)
return cached_tensordot
def einsum_cache_wrap(einsum: Any) -> Any:
"""Decorates an ``einsum()`` implementation to be memoized inside a
:func:`shared_intermediates` context.
"""
@functools.wraps(einsum)
def cached_einsum(*args, **kwargs):
if not currently_sharing():
return einsum(*args, **kwargs)
# hash modulo commutativity by computing a canonical ordering and names
backend = kwargs.pop("backend", "numpy")
equation = args[0]
inputs, output, operands = parse_einsum_input(args)
inputs = inputs.split(",")
_save_tensors(*operands)
# Build canonical key
canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1])
canonical_ids = tuple(id_ for _, id_ in canonical)
canonical_inputs = ",".join(input_ for input_, _ in canonical)
canonical_equation = alpha_canonicalize(canonical_inputs + "->" + output)
key = "einsum", backend, canonical_equation, canonical_ids
return _memoize(key, einsum, equation, *operands, backend=backend)
return cached_einsum
def to_backend_cache_wrap(to_backend: Any = None, constants: bool = False) -> Any:
"""Decorates an ``to_backend()`` implementation to be memoized inside a
:func:`shared_intermediates` context (e.g. ``to_cupy``, ``to_torch``).
"""
# handle the case where the decorator is called with arguments
if to_backend is None:
return functools.partial(to_backend_cache_wrap, constants=constants)
if constants:
@functools.wraps(to_backend)
def cached_to_backend(array, constant=False):
if not currently_sharing():
return to_backend(array, constant=constant)
# hash by id
key = to_backend.__name__, id(array), constant
return _memoize(key, to_backend, array, constant=constant)
else:
@functools.wraps(to_backend)
def cached_to_backend(array):
if not currently_sharing():
return to_backend(array)
# hash by id
key = to_backend.__name__, id(array)
return _memoize(key, to_backend, array)
return cached_to_backend
@@ -0,0 +1,224 @@
"""Testing routines for opt_einsum."""
import random
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload
import pytest
from opt_einsum.parser import get_symbol
from opt_einsum.typing import ArrayType, PathType, TensorShapeType
_valid_chars = "abcdefghijklmopqABC"
_sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]
_default_dim_dict = dict(zip(_valid_chars, _sizes))
def build_shapes(string: str, dimension_dict: Optional[Dict[str, int]] = None) -> Tuple[TensorShapeType, ...]:
"""Builds random tensor shapes for testing.
Parameters:
string: Einsum contraction string specifying the tensors to build
dimension_dict: Dictionary of index sizes, defaults to index sizes of 2-7
Returns:
The resulting shapes.
Examples:
```python
>>> shapes = build_shapes('abbc', {'a': 2, 'b':3, 'c':5})
>>> shapes
((2, 3, 3, 5),)
```
"""
if dimension_dict is None:
dimension_dict = _default_dim_dict
shapes = []
terms = string.split("->")[0].split(",")
for term in terms:
dims = [dimension_dict[x] for x in term]
shapes.append(tuple(dims))
return tuple(shapes)
def build_views(
string: str, dimension_dict: Optional[Dict[str, int]] = None, array_function: Optional[Any] = None
) -> Tuple[ArrayType, ...]:
"""Builds random numpy arrays for testing.
Parameters:
string: Einsum contraction string specifying the arrays to build
dimension_dict: Dictionary of index sizes
array_function: Function to build the arrays, defaults to np.random.rand
Returns:
The resulting views.
Examples:
```python
>>> view = build_views('abbc', {'a': 2, 'b':3, 'c':5})
>>> view[0].shape
(2, 3, 3, 5)
```
"""
if array_function is None:
np = pytest.importorskip("numpy")
array_function = np.random.rand
views = []
for shape in build_shapes(string, dimension_dict=dimension_dict):
if shape:
views.append(array_function(*shape))
else:
views.append(random.random())
return tuple(views)
@overload
def rand_equation(
n: int,
regularity: int,
n_out: int = ...,
d_min: int = ...,
d_max: int = ...,
seed: Optional[int] = ...,
global_dim: bool = ...,
*,
return_size_dict: Literal[True],
) -> Tuple[str, PathType, Dict[str, int]]: ...
@overload
def rand_equation(
n: int,
regularity: int,
n_out: int = ...,
d_min: int = ...,
d_max: int = ...,
seed: Optional[int] = ...,
global_dim: bool = ...,
return_size_dict: Literal[False] = ...,
) -> Tuple[str, PathType]: ...
def rand_equation(
n: int,
regularity: int,
n_out: int = 0,
d_min: int = 2,
d_max: int = 9,
seed: Optional[int] = None,
global_dim: bool = False,
return_size_dict: bool = False,
) -> Union[Tuple[str, PathType, Dict[str, int]], Tuple[str, PathType]]:
"""Generate a random contraction and shapes.
Parameters:
n: Number of array arguments.
regularity: 'Regularity' of the contraction graph. This essentially determines how
many indices each tensor shares with others on average.
n_out: Number of output indices (i.e. the number of non-contracted indices).
Defaults to 0, i.e., a contraction resulting in a scalar.
d_min: Minimum dimension size.
d_max: Maximum dimension size.
seed: If not None, seed numpy's random generator with this.
global_dim: Add a global, 'broadcast', dimension to every operand.
return_size_dict: Return the mapping of indices to sizes.
Returns:
eq: The equation string.
shapes: The array shapes.
size_dict: The dict of index sizes, only returned if ``return_size_dict=True``.
Examples:
```python
>>> eq, shapes = rand_equation(n=10, regularity=4, n_out=5, seed=42)
>>> eq
'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda'
>>> shapes
[(9, 5, 4, 5, 4),
(4, 4, 8, 5),
(9, 4, 6, 9),
(6, 6),
(6, 9, 7, 8),
(4,),
(9, 3, 9, 4, 9),
(6, 8, 4, 6, 8, 6, 3),
(4, 7, 8, 8, 6, 9, 6),
(9, 5, 3, 3, 9, 5)]
```
"""
np = pytest.importorskip("numpy")
if seed is not None:
np.random.seed(seed)
# total number of indices
num_inds = n * regularity // 2 + n_out
inputs = ["" for _ in range(n)]
output = []
size_dict = {get_symbol(i): np.random.randint(d_min, d_max + 1) for i in range(num_inds)}
# generate a list of indices to place either once or twice
def gen():
for i, ix in enumerate(size_dict):
# generate an outer index
if i < n_out:
output.append(ix)
yield ix
# generate a bond
else:
yield ix
yield ix
# add the indices randomly to the inputs
for i, ix in enumerate(np.random.permutation(list(gen()))):
# make sure all inputs have at least one index
if i < n:
inputs[i] += ix
else:
# don't add any traces on same op
where = np.random.randint(0, n)
while ix in inputs[where]:
where = np.random.randint(0, n)
inputs[where] += ix
# possibly add the same global dim to every arg
if global_dim:
gdim = get_symbol(num_inds)
size_dict[gdim] = np.random.randint(d_min, d_max + 1)
for i in range(n):
inputs[i] += gdim
output += gdim
# randomly transpose the output indices and form equation
output = "".join(np.random.permutation(output)) # type: ignore
eq = "{}->{}".format(",".join(inputs), output)
# make the shapes
shapes = [tuple(size_dict[ix] for ix in op) for op in inputs]
ret = (eq, shapes)
if return_size_dict:
return ret + (size_dict,)
else:
return ret
def build_arrays_from_tuples(path: PathType) -> List[Any]:
    """Build random numpy arrays from a list of shape tuples.

    Parameters:
        path: The list of shape tuples to build arrays from.

    Returns:
        The resulting random arrays.
    """
np = pytest.importorskip("numpy")
return [np.random.rand(*x) for x in path]
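# Illustrative usage: each tuple is unpacked into ``np.random.rand``, so
#
#     arrays = build_arrays_from_tuples([(2, 3), (3, 4)])
#
# returns two arrays with shapes (2, 3) and (3, 4) respectively.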
@@ -0,0 +1,462 @@
from typing import Set
import pytest
from opt_einsum import backends, contract, contract_expression, sharing
from opt_einsum.contract import ArrayShaped, infer_backend, parse_backend
from opt_einsum.testing import build_views
try:
# needed so tensorflow doesn't allocate all gpu mem
try:
from tensorflow import ConfigProto # type: ignore
from tensorflow import Session as TFSession
except ImportError:
from tensorflow.compat.v1 import ConfigProto # type: ignore
from tensorflow.compat.v1 import Session as TFSession
_TF_CONFIG = ConfigProto()
_TF_CONFIG.gpu_options.allow_growth = True
except ImportError:
pass
tests = [
"ab,bc->ca",
"abc,bcd,dea",
"abc,def->fedcba",
"abc,bcd,df->fa",
# test 'prefer einsum' ops
"ijk,ikj",
"i,j->ij",
"ijk,k->ij",
"AB,BC->CA",
]
@pytest.mark.parametrize("string", tests)
def test_tensorflow(string: str) -> None:
np = pytest.importorskip("numpy")
pytest.importorskip("tensorflow")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
opt = np.empty_like(ein)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
sess = TFSession(config=_TF_CONFIG)
with sess.as_default():
expr(*views, backend="tensorflow", out=opt)
sess.close()
assert np.allclose(ein, opt)
# test non-conversion mode
tensorflow_views = [backends.to_tensorflow(view) for view in views]
expr(*tensorflow_views)
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_tensorflow_with_constants(constants: Set[int]) -> None:
np = pytest.importorskip("numpy")
tf = pytest.importorskip("tensorflow")
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = np.random.rand(*shapes[non_const])
res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))
expr = contract_expression(eq, *ops, constants=constants)
# check tensorflow
with TFSession(config=_TF_CONFIG).as_default():
res_got = expr(var, backend="tensorflow")
assert all(
array is None or infer_backend(array) == "tensorflow" for array in expr._evaluated_constants["tensorflow"]
)
assert np.allclose(res_exp, res_got)
# check can call with numpy still
res_got2 = expr(var, backend="numpy")
assert np.allclose(res_exp, res_got2)
# check tensorflow call returns tensorflow still
res_got3 = expr(backends.to_tensorflow(var))
assert isinstance(res_got3, tf.Tensor)
@pytest.mark.parametrize("string", tests)
def test_tensorflow_with_sharing(string: str) -> None:
np = pytest.importorskip("numpy")
tf = pytest.importorskip("tensorflow")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
sess = TFSession(config=_TF_CONFIG)
with sess.as_default(), sharing.shared_intermediates() as cache:
tfl1 = expr(*views, backend="tensorflow")
assert sharing.get_sharing_cache() is cache
cache_sz = len(cache)
assert cache_sz > 0
tfl2 = expr(*views, backend="tensorflow")
assert len(cache) == cache_sz
assert all(isinstance(t, tf.Tensor) for t in cache.values())
assert np.allclose(ein, tfl1)
assert np.allclose(ein, tfl2)
@pytest.mark.parametrize("string", tests)
def test_theano(string: str) -> None:
np = pytest.importorskip("numpy")
theano = pytest.importorskip("theano")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
opt = expr(*views, backend="theano")
assert np.allclose(ein, opt)
# test non-conversion mode
theano_views = [backends.to_theano(view) for view in views]
theano_opt = expr(*theano_views)
assert isinstance(theano_opt, theano.tensor.TensorVariable)
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_theano_with_constants(constants: Set[int]) -> None:
np = pytest.importorskip("numpy")
theano = pytest.importorskip("theano")
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = np.random.rand(*shapes[non_const])
res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))
expr = contract_expression(eq, *ops, constants=constants)
# check theano
res_got = expr(var, backend="theano")
assert all(array is None or infer_backend(array) == "theano" for array in expr._evaluated_constants["theano"])
assert np.allclose(res_exp, res_got)
# check can call with numpy still
res_got2 = expr(var, backend="numpy")
assert np.allclose(res_exp, res_got2)
# check theano call returns theano still
res_got3 = expr(backends.to_theano(var))
assert isinstance(res_got3, theano.tensor.TensorVariable)
@pytest.mark.parametrize("string", tests)
def test_theano_with_sharing(string: str) -> None:
np = pytest.importorskip("numpy")
theano = pytest.importorskip("theano")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
with sharing.shared_intermediates() as cache:
thn1 = expr(*views, backend="theano")
assert sharing.get_sharing_cache() is cache
cache_sz = len(cache)
assert cache_sz > 0
thn2 = expr(*views, backend="theano")
assert len(cache) == cache_sz
assert all(isinstance(t, theano.tensor.TensorVariable) for t in cache.values())
assert np.allclose(ein, thn1)
assert np.allclose(ein, thn2)
@pytest.mark.parametrize("string", tests)
def test_cupy(string: str) -> None:
np = pytest.importorskip("numpy") # pragma: no cover
cupy = pytest.importorskip("cupy")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
opt = expr(*views, backend="cupy")
assert np.allclose(ein, opt)
# test non-conversion mode
cupy_views = [backends.to_cupy(view) for view in views]
cupy_opt = expr(*cupy_views)
assert isinstance(cupy_opt, cupy.ndarray)
assert np.allclose(ein, cupy.asnumpy(cupy_opt))
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_cupy_with_constants(constants: Set[int]) -> None:
np = pytest.importorskip("numpy") # pragma: no cover
cupy = pytest.importorskip("cupy")
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = np.random.rand(*shapes[non_const])
res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))
expr = contract_expression(eq, *ops, constants=constants)
# check cupy
res_got = expr(var, backend="cupy")
# check cupy versions of constants exist
assert all(array is None or infer_backend(array) == "cupy" for array in expr._evaluated_constants["cupy"])
assert np.allclose(res_exp, res_got)
# check can call with numpy still
res_got2 = expr(var, backend="numpy")
assert np.allclose(res_exp, res_got2)
# check cupy call returns cupy still
res_got3 = expr(cupy.asarray(var))
assert isinstance(res_got3, cupy.ndarray)
assert np.allclose(res_exp, res_got3.get())
@pytest.mark.parametrize("string", tests)
def test_jax(string: str) -> None:
np = pytest.importorskip("numpy") # pragma: no cover
pytest.importorskip("jax")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
opt = expr(*views, backend="jax")
assert np.allclose(ein, opt)
assert isinstance(opt, np.ndarray)
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_jax_with_constants(constants: Set[int]) -> None:
jax = pytest.importorskip("jax")
key = jax.random.PRNGKey(42)
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
ops = [jax.random.uniform(key, shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = jax.random.uniform(key, shapes[non_const])
res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))
expr = contract_expression(eq, *ops, constants=constants)
# check jax
res_got = expr(var, backend="jax")
# check jax versions of constants exist
assert all(array is None or infer_backend(array).startswith("jax") for array in expr._evaluated_constants["jax"])
assert jax.numpy.sum(jax.numpy.abs(res_exp - res_got)) < 1e-8
def test_jax_jit_gradient() -> None:
jax = pytest.importorskip("jax")
key = jax.random.PRNGKey(42)
eq = "ij,jk,kl->"
shapes = (2, 3), (3, 4), (4, 2)
views = [jax.random.uniform(key, s) for s in shapes]
expr = contract_expression(eq, *shapes)
x0 = expr(*views)
jit_expr = jax.jit(expr)
x1 = jit_expr(*views).item()
assert x1 == pytest.approx(x0, rel=1e-5)
# jax only takes gradient w.r.t first argument
grad_expr = jax.jit(jax.grad(lambda views: expr(*views)))
view_grads = grad_expr(views)
assert all(v1.shape == v2.shape for v1, v2 in zip(views, view_grads))
# taking a step along the gradient should reduce our 'loss'
new_views = [v - 0.001 * dv for v, dv in zip(views, view_grads)]
x2 = jit_expr(*new_views).item()
assert x2 < x1
def test_autograd_gradient() -> None:
np = pytest.importorskip("numpy")
autograd = pytest.importorskip("autograd")
eq = "ij,jk,kl->"
shapes = (2, 3), (3, 4), (4, 2)
views = [np.random.randn(*s) for s in shapes]
expr = contract_expression(eq, *shapes)
x0 = expr(*views)
# autograd only takes gradient w.r.t first argument
grad_expr = autograd.grad(lambda views: expr(*views))
view_grads = grad_expr(views)
assert all(v1.shape == v2.shape for v1, v2 in zip(views, view_grads))
# taking a step along the gradient should reduce our 'loss'
new_views = [v - 0.001 * dv for v, dv in zip(views, view_grads)]
x1 = expr(*new_views)
assert x1 < x0
@pytest.mark.parametrize("string", tests)
def test_dask(string: str) -> None:
np = pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
# test non-conversion mode
da_views = [da.from_array(x, chunks=(2)) for x in views]
da_opt = expr(*da_views)
# check type is maintained when not using numpy arrays
assert isinstance(da_opt, da.Array)
assert np.allclose(ein, np.array(da_opt))
# try raw contract
da_opt = contract(string, *da_views)
assert isinstance(da_opt, da.Array)
assert np.allclose(ein, np.array(da_opt))
@pytest.mark.parametrize("string", tests)
def test_sparse(string: str) -> None:
np = pytest.importorskip("numpy")
sparse = pytest.importorskip("sparse")
views = build_views(string)
# sparsify views so they don't become dense during contraction
for view in views:
np.random.seed(42)
mask = np.random.choice([False, True], view.shape, True, [0.05, 0.95])
view[mask] = 0
ein = contract(string, *views, optimize=False, use_blas=False)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
# test non-conversion mode
sparse_views = [sparse.COO.from_numpy(x) for x in views]
sparse_opt = expr(*sparse_views)
# If the expression returns a float, stop here
if not ein.shape:
assert pytest.approx(ein) == 0.0
return
# check type is maintained when not using numpy arrays
assert isinstance(sparse_opt, sparse.COO)
assert np.allclose(ein, sparse_opt.todense())
# try raw contract
sparse_opt = contract(string, *sparse_views)
assert isinstance(sparse_opt, sparse.COO)
assert np.allclose(ein, sparse_opt.todense())
@pytest.mark.parametrize("string", tests)
def test_torch(string: str) -> None:
torch = pytest.importorskip("torch")
views = build_views(string, array_function=torch.rand)
ein = torch.einsum(string, *views)
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
opt = expr(*views, backend="torch")
torch.testing.assert_close(ein, opt)
# test non-conversion mode
torch_views = [backends.to_torch(view) for view in views]
torch_opt = expr(*torch_views)
assert isinstance(torch_opt, torch.Tensor)
torch.testing.assert_close(ein, torch_opt)
@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_torch_with_constants(constants: Set[int]) -> None:
torch = pytest.importorskip("torch")
eq = "ij,jk,kl->li"
shapes = (2, 3), (3, 4), (4, 5)
(non_const,) = {0, 1, 2} - constants
ops = [torch.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = torch.rand(*shapes[non_const])
res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)), backend="torch")
expr = contract_expression(eq, *ops, constants=constants)
# check torch
res_got = expr(var, backend="torch")
assert all(array is None or infer_backend(array) == "torch" for array in expr._evaluated_constants["torch"])
torch.testing.assert_close(res_exp, res_got)
    # check can call again with torch
res_got2 = expr(var, backend="torch")
torch.testing.assert_close(res_exp, res_got2)
# check torch call returns torch still
res_got3 = expr(backends.to_torch(var))
assert isinstance(res_got3, torch.Tensor)
torch.testing.assert_close(res_exp, res_got3)
def test_auto_backend_custom_array_no_tensordot() -> None:
x = ArrayShaped((1, 2, 3))
    # ArrayShaped is an array-like object defined by opt_einsum which has no tensordot
assert infer_backend(x) == "opt_einsum"
assert parse_backend([x], "auto") == "numpy"
assert parse_backend([x], None) == "numpy"
@pytest.mark.parametrize("string", tests)
def test_object_arrays_backend(string: str) -> None:
np = pytest.importorskip("numpy")
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
assert ein.dtype != object
shps = [v.shape for v in views]
expr = contract_expression(string, *shps, optimize=True)
obj_views = [view.astype(object) for view in views]
# try raw contract
obj_opt = contract(string, *obj_views, backend="object")
assert obj_opt.dtype == object
assert np.allclose(ein, obj_opt.astype(float))
# test expression
obj_opt = expr(*obj_views, backend="object")
assert obj_opt.dtype == object
assert np.allclose(ein, obj_opt.astype(float))
@@ -0,0 +1,81 @@
"""
Tests the BLAS capability for the opt_einsum module.
"""
from typing import Any
import pytest
from opt_einsum import blas, contract
blas_tests = [
# DOT
((["k", "k"], "", set("k")), "DOT"), # DDOT
((["ijk", "ijk"], "", set("ijk")), "DOT"), # DDOT
# GEMV?
# GEMM
((["ij", "jk"], "ik", set("j")), "GEMM"), # GEMM N N
((["ijl", "jlk"], "ik", set("jl")), "GEMM"), # GEMM N N Tensor
((["ij", "kj"], "ik", set("j")), "GEMM"), # GEMM N T
((["ijl", "kjl"], "ik", set("jl")), "GEMM"), # GEMM N T Tensor
((["ji", "jk"], "ik", set("j")), "GEMM"), # GEMM T N
((["jli", "jlk"], "ik", set("jl")), "GEMM"), # GEMM T N Tensor
((["ji", "kj"], "ik", set("j")), "GEMM"), # GEMM T T
((["jli", "kjl"], "ik", set("jl")), "GEMM"), # GEMM T T Tensor
# GEMM with final transpose
((["ij", "jk"], "ki", set("j")), "GEMM"), # GEMM N N
((["ijl", "jlk"], "ki", set("jl")), "GEMM"), # GEMM N N Tensor
((["ij", "kj"], "ki", set("j")), "GEMM"), # GEMM N T
((["ijl", "kjl"], "ki", set("jl")), "GEMM"), # GEMM N T Tensor
((["ji", "jk"], "ki", set("j")), "GEMM"), # GEMM T N
((["jli", "jlk"], "ki", set("jl")), "GEMM"), # GEMM T N Tensor
((["ji", "kj"], "ki", set("j")), "GEMM"), # GEMM T T
((["jli", "kjl"], "ki", set("jl")), "GEMM"), # GEMM T T Tensor
    # Tensor Dot (requires copy), let's not deal with this for now
((["ilj", "jlk"], "ik", set("jl")), "TDOT"), # FT GEMM N N Tensor
((["ijl", "ljk"], "ik", set("jl")), "TDOT"), # ST GEMM N N Tensor
((["ilj", "kjl"], "ik", set("jl")), "TDOT"), # FT GEMM N T Tensor
((["ijl", "klj"], "ik", set("jl")), "TDOT"), # ST GEMM N T Tensor
((["lji", "jlk"], "ik", set("jl")), "TDOT"), # FT GEMM T N Tensor
((["jli", "ljk"], "ik", set("jl")), "TDOT"), # ST GEMM T N Tensor
((["lji", "jlk"], "ik", set("jl")), "TDOT"), # FT GEMM T N Tensor
((["jli", "ljk"], "ik", set("jl")), "TDOT"), # ST GEMM T N Tensor
    # Tensor Dot (requires copy), let's not deal with this for now, with transpose
((["ilj", "jlk"], "ik", set("lj")), "TDOT"), # FT GEMM N N Tensor
((["ijl", "ljk"], "ik", set("lj")), "TDOT"), # ST GEMM N N Tensor
((["ilj", "kjl"], "ik", set("lj")), "TDOT"), # FT GEMM N T Tensor
((["ijl", "klj"], "ik", set("lj")), "TDOT"), # ST GEMM N T Tensor
((["lji", "jlk"], "ik", set("lj")), "TDOT"), # FT GEMM T N Tensor
((["jli", "ljk"], "ik", set("lj")), "TDOT"), # ST GEMM T N Tensor
((["lji", "jlk"], "ik", set("lj")), "TDOT"), # FT GEMM T N Tensor
((["jli", "ljk"], "ik", set("lj")), "TDOT"), # ST GEMM T N Tensor
# Other
((["ijk", "ikj"], "", set("ijk")), "DOT/EINSUM"), # Transpose DOT
((["i", "j"], "ij", set()), "OUTER/EINSUM"), # Outer
((["ijk", "ik"], "j", set("ik")), "GEMV/EINSUM"), # Matrix-vector
((["ijj", "jk"], "ik", set("j")), False), # Double index
((["ijk", "j"], "ij", set()), False), # Index sum 1
((["ij", "ij"], "ij", set()), False), # Index sum 2
]
@pytest.mark.parametrize("inp,benchmark", blas_tests)
def test_can_blas(inp: Any, benchmark: Any) -> None:
result = blas.can_blas(*inp)
assert result == benchmark
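# For instance, drawing on the ``blas_tests`` entries above (illustrative only):
#
#     blas.can_blas(["ij", "jk"], "ik", set("j"))    # -> "GEMM"
#     blas.can_blas(["ijj", "jk"], "ik", set("j"))   # -> False (repeated index)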
def test_blas_out() -> None:
np = pytest.importorskip("numpy")
a = np.random.rand(4, 4)
b = np.random.rand(4, 4)
c = np.random.rand(4, 4)
d = np.empty((4, 4))
contract("ij,jk->ik", a, b, out=d)
np.testing.assert_allclose(d, np.dot(a, b))
assert np.allclose(d, np.dot(a, b))
contract("ij,jk,kl->il", a, b, c, out=d)
np.testing.assert_allclose(d, np.dot(a, b).dot(c))
@@ -0,0 +1,279 @@
"""
Tests a series of opt_einsum contraction paths to ensure the results are the same for different paths.
"""
from typing import Any, List
import pytest
from opt_einsum import contract, contract_expression, contract_path
from opt_einsum.paths import _PATH_OPTIONS, linear_to_ssa, ssa_to_linear
from opt_einsum.testing import build_views, rand_equation
from opt_einsum.typing import OptimizeKind
# NumPy is required for the majority of this file
np = pytest.importorskip("numpy")
tests = [
# Test scalar-like operations
"a,->a",
"ab,->ab",
",ab,->ab",
",,->",
# Test hadamard-like products
"a,ab,abc->abc",
"a,b,ab->ab",
# Test index-transformations
"ea,fb,gc,hd,abcd->efgh",
"ea,fb,abcd,gc,hd->efgh",
"abcd,ea,fb,gc,hd->efgh",
# Test complex contractions
"acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
"acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
"cd,bdhe,aidb,hgca,gc,hgibcd,hgac",
"abhe,hidj,jgba,hiab,gab",
"bde,cdh,agdb,hica,ibd,hgicd,hiac",
"chd,bde,agbc,hiad,hgc,hgi,hiad",
"chd,bde,agbc,hiad,bdi,cgh,agdb",
"bdhe,acad,hiab,agac,hibd",
# Test collapse
"ab,ab,c->",
"ab,ab,c->c",
"ab,ab,cd,cd->",
"ab,ab,cd,cd->ac",
"ab,ab,cd,cd->cd",
"ab,ab,cd,cd,ef,ef->",
    # Test outer products
"ab,cd,ef->abcdef",
"ab,cd,ef->acdf",
"ab,cd,de->abcde",
"ab,cd,de->be",
"ab,bcd,cd->abcd",
"ab,bcd,cd->abd",
# Random test cases that have previously failed
"eb,cb,fb->cef",
"dd,fb,be,cdb->cef",
"bca,cdb,dbf,afc->",
"dcc,fce,ea,dbf->ab",
"fdf,cdd,ccd,afe->ae",
"abcd,ad",
"ed,fcd,ff,bcf->be",
"baa,dcf,af,cde->be",
"bd,db,eac->ace",
"fff,fae,bef,def->abd",
"efc,dbc,acf,fd->abe",
# Inner products
"ab,ab",
"ab,ba",
"abc,abc",
"abc,bac",
"abc,cba",
# GEMM test cases
"ab,bc",
"ab,cb",
"ba,bc",
"ba,cb",
"abcd,cd",
"abcd,ab",
"abcd,cdef",
"abcd,cdef->feba",
"abcd,efdc",
    # Inner then dot
"aab,bc->ac",
"ab,bcc->ac",
"aab,bcc->ac",
"baa,bcc->ac",
"aab,ccb->ac",
    # Randomly built test cases
"aab,fa,df,ecc->bde",
"ecb,fef,bad,ed->ac",
"bcf,bbb,fbf,fc->",
"bb,ff,be->e",
"bcb,bb,fc,fff->",
"fbb,dfd,fc,fc->",
"afd,ba,cc,dc->bf",
"adb,bc,fa,cfc->d",
"bbd,bda,fc,db->acf",
"dba,ead,cad->bce",
"aef,fbc,dca->bde",
]
@pytest.mark.parametrize("optimize", (True, False, None))
def test_contract_plain_types(optimize: OptimizeKind) -> None:
expr = "ij,jk,kl->il"
ops = [np.random.rand(2, 2), np.random.rand(2, 2), np.random.rand(2, 2)]
path = contract_path(expr, *ops, optimize=optimize)
assert len(path) == 2
result = contract(expr, *ops, optimize=optimize)
assert result.shape == (2, 2)
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare(optimize: OptimizeKind, string: str) -> None:
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
opt = contract(string, *views, optimize=optimize, use_blas=False)
assert np.allclose(ein, opt)
@pytest.mark.parametrize("string", tests)
def test_drop_in_replacement(string: str) -> None:
views = build_views(string)
opt = contract(string, *views)
assert np.allclose(opt, np.einsum(string, *views))
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_greek(optimize: OptimizeKind, string: str) -> None:
views = build_views(string)
ein = contract(string, *views, optimize=False, use_blas=False)
# convert to greek
string = "".join(chr(ord(c) + 848) if c not in ",->." else c for c in string)
opt = contract(string, *views, optimize=optimize, use_blas=False)
assert np.allclose(ein, opt)
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_blas(optimize: OptimizeKind, string: str) -> None:
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(string, *views, optimize=optimize)
assert np.allclose(ein, opt)
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_blas_greek(optimize: OptimizeKind, string: str) -> None:
views = build_views(string)
ein = contract(string, *views, optimize=False)
# convert to greek
string = "".join(chr(ord(c) + 848) if c not in ",->." else c for c in string)
opt = contract(string, *views, optimize=optimize)
assert np.allclose(ein, opt)
def test_some_non_alphabet_maintains_order() -> None:
# 'c beta a' should automatically go to -> 'a c beta'
string = "c" + chr(ord("b") + 848) + "a"
# but beta will be temporarily replaced with 'b' for which 'cba->abc'
# so check manual output kicks in:
x = np.random.rand(2, 3, 4)
assert np.allclose(contract(string, x), contract("cxa", x))
def test_printing() -> None:
string = "bbd,bda,fc,db->acf"
views = build_views(string)
ein = contract_path(string, *views)
assert len(str(ein[1])) == 728
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
@pytest.mark.parametrize("use_blas", [False, True])
@pytest.mark.parametrize("out_spec", [False, True])
def test_contract_expressions(string: str, optimize: OptimizeKind, use_blas: bool, out_spec: bool) -> None:
views = build_views(string)
shapes = [view.shape if hasattr(view, "shape") else () for view in views]
expected = contract(string, *views, optimize=False, use_blas=False)
expr = contract_expression(string, *shapes, optimize=optimize, use_blas=use_blas)
if out_spec and ("->" in string) and (string[-2:] != "->"):
(out,) = build_views(string.split("->")[1])
expr(*views, out=out)
else:
out = expr(*views)
assert np.allclose(out, expected)
# check representations
assert string in expr.__repr__()
assert string in expr.__str__()
def test_contract_expression_interleaved_input() -> None:
x, y, z = (np.random.randn(2, 2) for _ in "xyz")
expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0])
xshp, yshp, zshp = ((2, 2) for _ in "xyz")
expr = contract_expression(xshp, [0, 1], yshp, [1, 2], zshp, [2, 3], [3, 0])
out = expr(x, y, z)
assert np.allclose(out, expected)
@pytest.mark.parametrize(
"string,constants",
[
("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]),
("bdef,cdkj,ji,ikeh,hbc,lfo", [0, 1, 2, 3]),
("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]),
("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]),
("ijab,acd,bce,df,ef->ji", [1, 2, 3, 4]),
("ab,cd,ad,cb", [1, 3]),
("ab,bc,cd", [0, 1]),
],
)
def test_contract_expression_with_constants(string: str, constants: List[int]) -> None:
views = build_views(string)
expected = contract(string, *views, optimize=False, use_blas=False)
shapes = [view.shape if hasattr(view, "shape") else () for view in views]
expr_args: List[Any] = []
ctrc_args = []
for i, (shape, view) in enumerate(zip(shapes, views)):
if i in constants:
expr_args.append(view)
else:
expr_args.append(shape)
ctrc_args.append(view)
expr = contract_expression(string, *expr_args, constants=constants)
out = expr(*ctrc_args)
assert np.allclose(expected, out)
@pytest.mark.parametrize("optimize", ["greedy", "optimal"])
@pytest.mark.parametrize("n", [4, 5])
@pytest.mark.parametrize("reg", [2, 3])
@pytest.mark.parametrize("n_out", [0, 2, 4])
@pytest.mark.parametrize("global_dim", [False, True])
def test_rand_equation(optimize: OptimizeKind, n: int, reg: int, n_out: int, global_dim: bool) -> None:
eq, _, size_dict = rand_equation(n, reg, n_out, d_min=2, d_max=5, seed=42, return_size_dict=True)
views = build_views(eq, size_dict)
expected = contract(eq, *views, optimize=False)
actual = contract(eq, *views, optimize=optimize)
assert np.allclose(expected, actual)
@pytest.mark.parametrize("equation", tests)
def test_linear_vs_ssa(equation: str) -> None:
views = build_views(equation)
linear_path, _ = contract_path(equation, *views)
ssa_path = linear_to_ssa(linear_path)
linear_path2 = ssa_to_linear(ssa_path)
assert linear_path2 == linear_path
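# Concrete illustration (hypothetical values, following the conversion rules):
# with three operands, the linear path [(0, 1), (0, 1)] first contracts
# operands 0 and 1, then the new intermediate with the remaining operand. In
# SSA form every intermediate gets a fresh id, so the same path reads
# [(0, 1), (2, 3)].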
def test_contract_path_supply_shapes() -> None:
eq = "ab,bc,cd"
shps = [(2, 3), (3, 4), (4, 5)]
contract_path(eq, *shps, shapes=True)
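# A hypothetical sketch of what the call above returns: with ``shapes=True``
# the operands are interpreted as shape tuples rather than arrays, so
#
#     path, info = contract_path("ab,bc,cd", (2, 3), (3, 4), (4, 5), shapes=True)
#
# gives a pairwise path such as [(0, 1), (0, 1)], plus a PathInfo object.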
@@ -0,0 +1,152 @@
"""
Tests a series of opt_einsum contraction paths to ensure the results are the same for different paths.
"""
from typing import Any, Tuple
import pytest
from opt_einsum import contract, contract_expression, contract_path
from opt_einsum.typing import PathType
# NumPy is required for the majority of this file
np = pytest.importorskip("numpy")
def test_contract_expression_checks() -> None:
# check optimize needed
with pytest.raises(ValueError):
contract_expression("ab,bc->ac", (2, 3), (3, 4), optimize=False)
# check sizes are still checked
with pytest.raises(ValueError):
contract_expression("ab,bc->ac", (2, 3), (3, 4), (42, 42))
# check if out given
out = np.empty((2, 4))
with pytest.raises(ValueError):
contract_expression("ab,bc->ac", (2, 3), (3, 4), out=out)
# check still get errors when wrong ranks supplied to expression
expr = contract_expression("ab,bc->ac", (2, 3), (3, 4))
# too few arguments
with pytest.raises(ValueError) as err:
expr(np.random.rand(2, 3))
assert "`ContractExpression` takes exactly 2" in str(err.value)
# too many arguments
with pytest.raises(ValueError) as err:
expr(np.random.rand(2, 3), np.random.rand(2, 3), np.random.rand(2, 3))
assert "`ContractExpression` takes exactly 2" in str(err.value)
# wrong shapes
with pytest.raises(ValueError) as err:
expr(np.random.rand(2, 3, 4), np.random.rand(3, 4))
assert "Internal error while evaluating `ContractExpression`" in str(err.value)
with pytest.raises(ValueError) as err:
expr(np.random.rand(2, 4), np.random.rand(3, 4, 5))
assert "Internal error while evaluating `ContractExpression`" in str(err.value)
with pytest.raises(ValueError) as err:
expr(np.random.rand(2, 3), np.random.rand(3, 4), out=np.random.rand(2, 4, 6))
assert "Internal error while evaluating `ContractExpression`" in str(err.value)
# should only be able to specify out
with pytest.raises(TypeError) as err_type:
expr(np.random.rand(2, 3), np.random.rand(3, 4), order="F") # type: ignore
assert "got an unexpected keyword" in str(err_type.value)
def test_broadcasting_contraction() -> None:
a = np.random.rand(1, 5, 4)
b = np.random.rand(4, 6)
c = np.random.rand(5, 6)
d = np.random.rand(10)
ein_scalar = contract("ijk,kl,jl", a, b, c, optimize=False)
opt_scalar = contract("ijk,kl,jl", a, b, c, optimize=True)
assert np.allclose(ein_scalar, opt_scalar)
result = ein_scalar * d
ein = contract("ijk,kl,jl,i->i", a, b, c, d, optimize=False)
opt = contract("ijk,kl,jl,i->i", a, b, c, d, optimize=True)
assert np.allclose(ein, result)
assert np.allclose(opt, result)
def test_broadcasting_contraction2() -> None:
a = np.random.rand(1, 1, 5, 4)
b = np.random.rand(4, 6)
c = np.random.rand(5, 6)
d = np.random.rand(7, 7)
ein_scalar = contract("abjk,kl,jl", a, b, c, optimize=False)
opt_scalar = contract("abjk,kl,jl", a, b, c, optimize=True)
assert np.allclose(ein_scalar, opt_scalar)
result = ein_scalar * d
ein = contract("abjk,kl,jl,ab->ab", a, b, c, d, optimize=False)
opt = contract("abjk,kl,jl,ab->ab", a, b, c, d, optimize=True)
assert np.allclose(ein, result)
assert np.allclose(opt, result)
def test_broadcasting_contraction3() -> None:
a = np.random.rand(1, 5, 4)
b = np.random.rand(4, 1, 6)
c = np.random.rand(5, 6)
d = np.random.rand(7, 7)
ein = contract("ajk,kbl,jl,ab->ab", a, b, c, d, optimize=False)
opt = contract("ajk,kbl,jl,ab->ab", a, b, c, d, optimize=True)
assert np.allclose(ein, opt)
def test_broadcasting_contraction4() -> None:
a = np.arange(64).reshape(2, 4, 8)
ein = contract("obk,ijk->ioj", a, a, optimize=False)
opt = contract("obk,ijk->ioj", a, a, optimize=True)
assert np.allclose(ein, opt)
def test_can_blas_on_healed_broadcast_dimensions() -> None:
expr = contract_expression("ab,bc,bd->acd", (5, 4), (1, 5), (4, 20))
# first contraction involves broadcasting
assert expr.contraction_list[0][2] == "bc,ab->bca"
assert expr.contraction_list[0][-1] is False
    # but then it is healed, so GEMM is usable
assert expr.contraction_list[1][2] == "bca,bd->acd"
assert expr.contraction_list[1][-1] == "GEMM"
def test_pathinfo_for_empty_contraction() -> None:
eq = "->"
arrays = (1.0,)
path: PathType = []
_, info = contract_path(eq, *arrays, optimize=path)
# some info is built lazily, so check repr
assert repr(info)
assert info.largest_intermediate == 1
@pytest.mark.parametrize(
"expression, operands",
[
[",,->", (5, 5.0, 2.0j)],
["ab,->", ([[5, 5], [2.0, 1]], 2.0j)],
["ab,bc->ac", ([[5, 5], [2.0, 1]], [[2.0, 1], [3.0, 4]])],
["ab,->", ([[5, 5], [2.0, 1]], True)],
],
)
def test_contract_with_assumed_shapes(expression: str, operands: Tuple[Any]) -> None:
"""Test that we can contract with assumed shapes, and that the output is correct. This is required as we need to infer intermediate shape sizes."""
benchmark = np.einsum(expression, *operands)
result = contract(expression, *operands, optimize=True)
assert np.allclose(benchmark, result)
@@ -0,0 +1,279 @@
"""
Tests the input parsing for opt_einsum. Duplicates the np.einsum input tests.
"""
from typing import Any, List
import pytest
from opt_einsum import contract, contract_path
from opt_einsum.typing import ArrayType
np = pytest.importorskip("numpy")
def build_views(string: str) -> List[ArrayType]:
"""Builds random numpy arrays for testing by using a fixed size dictionary and an input string."""
chars = "abcdefghij"
sizes_array = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4])
sizes = dict(zip(chars, sizes_array))
views = []
string = string.replace("...", "ij")
terms = string.split("->")[0].split(",")
for term in terms:
dims = [sizes[x] for x in term]
views.append(np.random.rand(*dims))
return views
def test_type_errors() -> None:
# subscripts must be a string
with pytest.raises(TypeError):
contract(0, 0)
# out parameter must be an array
with pytest.raises(TypeError):
contract("", 0, out="test")
# order parameter must be a valid order
# changed in Numpy 1.19, see https://github.com/numpy/numpy/commit/35b0a051c19265f5643f6011ee11e31d30c8bc4c
with pytest.raises((TypeError, ValueError)):
contract("", 0, order="W") # type: ignore
# casting parameter must be a valid casting
with pytest.raises(ValueError):
contract("", 0, casting="blah") # type: ignore
# dtype parameter must be a valid dtype
with pytest.raises(TypeError):
contract("", 0, dtype="bad_data_type")
# other keyword arguments are rejected
with pytest.raises(TypeError):
contract("", 0, bad_arg=0)
# issue 4528 revealed a segfault with this call
with pytest.raises(TypeError):
contract(*(None,) * 63)
# Cannot have two ->
with pytest.raises(ValueError):
contract("->,->", 0, 5)
# Undefined symbol lhs
with pytest.raises(ValueError):
contract("&,a->", 0, 5)
# Undefined symbol rhs
with pytest.raises(ValueError):
contract("a,a->&", 0, 5)
# Catch ellipsis errors
string = "...a->...a"
views = build_views(string)
    # Subscript list must contain Ellipsis or a hashable and comparable object
with pytest.raises(TypeError):
contract(views[0], [Ellipsis, 0], [Ellipsis, ["a"]])
with pytest.raises(TypeError):
contract(views[0], [Ellipsis, {}], [Ellipsis, "a"])
@pytest.mark.parametrize("contract_fn", [contract, contract_path])
def test_value_errors(contract_fn: Any) -> None:
with pytest.raises(ValueError):
contract_fn("")
# subscripts must be a string
with pytest.raises(TypeError):
contract_fn(0, 0)
# invalid subscript character
with pytest.raises(ValueError):
contract_fn("i%...", [0, 0])
with pytest.raises(ValueError):
contract_fn("...j$", [0, 0])
with pytest.raises(ValueError):
contract_fn("i->&", [0, 0])
# number of operands must match count in subscripts string
with pytest.raises(ValueError):
contract_fn("", 0, 0)
with pytest.raises(ValueError):
contract_fn(",", 0, [0], [0])
with pytest.raises(ValueError):
contract_fn(",", [0])
# can't have more subscripts than dimensions in the operand
with pytest.raises(ValueError):
contract_fn("i", 0)
with pytest.raises(ValueError):
contract_fn("ij", [0, 0])
with pytest.raises(ValueError):
contract_fn("...i", 0)
with pytest.raises(ValueError):
contract_fn("i...j", [0, 0])
with pytest.raises(ValueError):
contract_fn("i...", 0)
with pytest.raises(ValueError):
contract_fn("ij...", [0, 0])
# invalid ellipsis
with pytest.raises(ValueError):
contract_fn("i..", [0, 0])
with pytest.raises(ValueError):
contract_fn(".i...", [0, 0])
with pytest.raises(ValueError):
contract_fn("j->..j", [0, 0])
with pytest.raises(ValueError):
contract_fn("j->.j...", [0, 0])
# output subscripts must appear in input
with pytest.raises(ValueError):
contract_fn("i->ij", [0, 0])
# output subscripts may only be specified once
with pytest.raises(ValueError):
contract_fn("ij->jij", [[0, 0], [0, 0]])
    # dimensions must match when being collapsed
with pytest.raises(ValueError):
contract_fn("ii", np.arange(6).reshape(2, 3))
with pytest.raises(ValueError):
contract_fn("ii->i", np.arange(6).reshape(2, 3))
# broadcasting to new dimensions must be enabled explicitly
with pytest.raises(ValueError):
contract_fn("i", np.arange(6).reshape(2, 3))
with pytest.raises(TypeError):
contract_fn("ij->ij", [[0, 1], [0, 1]], bad_kwarg=True)
@pytest.mark.parametrize(
"string",
[
        # Ellipsis
"...a->...",
"a...->...",
"a...a->...a",
"...,...",
"a,b",
"...a,...b",
],
)
def test_compare(string: str) -> None:
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(string, *views)
assert np.allclose(ein, opt)
opt = contract(string, *views, optimize="optimal")
assert np.allclose(ein, opt)
def test_ellipse_input1() -> None:
string = "...a->..."
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(views[0], [Ellipsis, 0], [Ellipsis])
assert np.allclose(ein, opt)
def test_ellipse_input2() -> None:
string = "...a"
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(views[0], [Ellipsis, 0])
assert np.allclose(ein, opt)
def test_ellipse_input3() -> None:
string = "...a->...a"
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(views[0], [Ellipsis, 0], [Ellipsis, 0])
assert np.allclose(ein, opt)
def test_ellipse_input4() -> None:
string = "...b,...a->..."
views = build_views(string)
ein = contract(string, *views, optimize=False)
opt = contract(views[0], [Ellipsis, 1], views[1], [Ellipsis, 0], [Ellipsis])
assert np.allclose(ein, opt)
def test_singleton_dimension_broadcast() -> None:
# singleton dimensions broadcast (gh-10343)
p = np.ones((10, 2))
q = np.ones((1, 2))
ein = contract("ij,ij->j", p, q, optimize=False)
opt = contract("ij,ij->j", p, q, optimize=True)
assert np.allclose(ein, opt)
assert np.allclose(opt, [10.0, 10.0])
p = np.ones((1, 5))
q = np.ones((5, 5))
for optimize in (True, False):
        res1 = contract("...ij,...jk->...ik", p, p, optimize=optimize)
res2 = contract("...ij,...jk->...ik", p, q, optimize=optimize)
assert np.allclose(res1, res2)
assert np.allclose(res2, np.full((1, 5), 5))
def test_large_int_input_format() -> None:
string = "ab,bc,cd"
x, y, z = build_views(string)
string_output = contract(string, x, y, z)
int_output = contract(x, (1000, 1001), y, (1001, 1002), z, (1002, 1003))
assert np.allclose(string_output, int_output)
for i in range(10):
transpose_output = contract(x, (i + 1, i))
assert np.allclose(transpose_output, x.T)
def test_hashable_object_input_format() -> None:
string = "ab,bc,cd"
x, y, z = build_views(string)
string_output = contract(string, x, y, z)
hash_output1 = contract(x, ("left", "bond1"), y, ("bond1", "bond2"), z, ("bond2", "right"))
hash_output2 = contract(
x,
("left", "bond1"),
y,
("bond1", "bond2"),
z,
("bond2", "right"),
("left", "right"),
)
assert np.allclose(string_output, hash_output1)
assert np.allclose(hash_output1, hash_output2)
for i in range(1, 10):
transpose_output = contract(x, ("b" * i, "a" * i))
assert np.allclose(transpose_output, x.T)
@@ -0,0 +1,74 @@
"""
Directly tests various parser utility functions.
"""
from typing import Any, Tuple
import pytest
from opt_einsum.parser import get_shape, get_symbol, parse_einsum_input
from opt_einsum.testing import build_arrays_from_tuples
def test_get_symbol() -> None:
assert get_symbol(2) == "c"
assert get_symbol(200000) == "\U00031540"
# Ensure we skip surrogates '[\uD800-\uDFFF]'
assert get_symbol(55295) == "\ud88b"
assert get_symbol(55296) == "\ue000"
assert get_symbol(57343) == "\ue7ff"
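# A plausible reading of these asserts (the implementation details are an
# assumption): indices 0-51 map to the letters a-z and A-Z, larger indices map
# to Unicode code points at a fixed offset (roughly ``chr(i + 140)``), and from
# i = 55296 the mapping jumps past the surrogate block U+D800-U+DFFF, landing
# at U+E000 as asserted above.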
def test_parse_einsum_input() -> None:
eq = "ab,bc,cd"
ops = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)])
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *ops])
assert input_subscripts == eq
assert output_subscript == "ad"
assert operands == ops
def test_parse_einsum_input_shapes_error() -> None:
eq = "ab,bc,cd"
ops = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)])
with pytest.raises(ValueError):
_ = parse_einsum_input([eq, *ops], shapes=True)
def test_parse_einsum_input_shapes() -> None:
eq = "ab,bc,cd"
shapes = [(2, 3), (3, 4), (4, 5)]
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True)
assert input_subscripts == eq
assert output_subscript == "ad"
assert shapes == operands
def test_parse_with_ellipsis() -> None:
eq = "...a,ab"
shapes = [(2, 3), (3, 4)]
input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True)
assert input_subscripts == "da,ab"
assert output_subscript == "db"
assert shapes == operands
@pytest.mark.parametrize(
"array, shape",
[
[[5], (1,)],
[[5, 5], (2,)],
[(5, 5), (2,)],
[[[[[[5, 2]]]]], (1, 1, 1, 1, 2)],
[[[[[["abcdef", "b"]]]]], (1, 1, 1, 1, 2)],
["A", ()],
[b"A", ()],
[True, ()],
[5, ()],
[5.0, ()],
[5.0 + 0j, ()],
],
)
def test_get_shapes(array: Any, shape: Tuple[int, ...]) -> None:
assert get_shape(array) == shape
@@ -0,0 +1,534 @@
"""
Tests the accuracy of the opt_einsum paths in addition to unit tests for
the various path helper functions.
"""
import itertools
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Dict, List, Optional
import pytest
import opt_einsum as oe
from opt_einsum.testing import build_shapes, rand_equation
from opt_einsum.typing import ArrayIndexType, OptimizeKind, PathType, TensorShapeType
explicit_path_tests = {
"GEMM1": (
[set("abd"), set("ac"), set("bdc")],
set(""),
{"a": 1, "b": 2, "c": 3, "d": 4},
),
"Inner1": (
[set("abcd"), set("abc"), set("bc")],
set(""),
{"a": 5, "b": 2, "c": 3, "d": 4},
),
}
# note that these tests have no unique solution due to the chosen dimensions
path_edge_tests = [
["greedy", "eb,cb,fb->cef", ((0, 2), (0, 1))],
["branch-all", "eb,cb,fb->cef", ((0, 2), (0, 1))],
["branch-2", "eb,cb,fb->cef", ((0, 2), (0, 1))],
["optimal", "eb,cb,fb->cef", ((0, 2), (0, 1))],
["dp", "eb,cb,fb->cef", ((1, 2), (0, 1))],
["greedy", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
["branch-all", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
["branch-2", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
["optimal", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
["optimal", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
["dp", "dd,fb,be,cdb->cef", ((0, 3), (0, 2), (0, 1))],
["greedy", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
["branch-all", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
["branch-2", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
["optimal", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
["dp", "bca,cdb,dbf,afc->", ((1, 2), (1, 2), (0, 1))],
["greedy", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 1), (0, 1))],
["branch-all", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
["branch-2", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
["optimal", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
["dp", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
]
# note that these tests have no unique solution due to the chosen dimensions
path_scalar_tests = [
[
"a,->a",
1,
],
["ab,->ab", 1],
[",a,->a", 2],
[",,a,->a", 3],
[",,->", 2],
]
def check_path(test_output: PathType, benchmark: PathType, bypass: bool = False) -> bool:
if not isinstance(test_output, list):
return False
if len(test_output) != len(benchmark):
return False
ret = True
for pos in range(len(test_output)):
ret &= isinstance(test_output[pos], tuple)
ret &= test_output[pos] == list(benchmark)[pos]
return ret
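# Illustrative check (hypothetical values): ``check_path`` expects a list of
# tuples that matches the benchmark element-wise, e.g.
#
#     check_path([(0, 2), (0, 1)], ((0, 2), (0, 1)))  # -> True
#     check_path(((0, 2), (0, 1)), ((0, 2), (0, 1)))  # -> False (not a list)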
def assert_contract_order(func: Any, test_data: Any, max_size: int, benchmark: PathType) -> None:
test_output = func(test_data[0], test_data[1], test_data[2], max_size)
assert check_path(test_output, benchmark)
def test_size_by_dict() -> None:
sizes_dict = {}
for ind, val in zip("abcdez", [2, 5, 9, 11, 13, 0]):
sizes_dict[ind] = val
path_func = oe.helpers.compute_size_by_dict
assert 1 == path_func("", sizes_dict)
assert 2 == path_func("a", sizes_dict)
assert 5 == path_func("b", sizes_dict)
assert 0 == path_func("z", sizes_dict)
assert 0 == path_func("az", sizes_dict)
assert 0 == path_func("zbc", sizes_dict)
assert 104 == path_func("aaae", sizes_dict)
assert 12870 == path_func("abcde", sizes_dict)
def test_flop_cost() -> None:
size_dict = {v: 10 for v in "abcdef"}
# Loop over an array
assert 10 == oe.helpers.flop_count("a", False, 1, size_dict)
# Hadamard product (*)
assert 10 == oe.helpers.flop_count("a", False, 2, size_dict)
assert 100 == oe.helpers.flop_count("ab", False, 2, size_dict)
# Inner product (+, *)
assert 20 == oe.helpers.flop_count("a", True, 2, size_dict)
assert 200 == oe.helpers.flop_count("ab", True, 2, size_dict)
# Inner product x3 (+, *, *)
assert 30 == oe.helpers.flop_count("a", True, 3, size_dict)
# GEMM
assert 2000 == oe.helpers.flop_count("abc", True, 2, size_dict)
def test_bad_path_option() -> None:
with pytest.raises(KeyError):
oe.contract("a,b,c", [1], [2], [3], optimize="optimall", shapes=True) # type: ignore
def test_explicit_path() -> None:
pytest.importorskip("numpy")
x = oe.contract("a,b,c", [1], [2], [3], optimize=[(1, 2), (0, 1)])
assert x.item() == 6
def test_path_optimal() -> None:
test_func = oe.paths.optimal
test_data = explicit_path_tests["GEMM1"]
assert_contract_order(test_func, test_data, 5000, [(0, 2), (0, 1)])
assert_contract_order(test_func, test_data, 0, [(0, 1, 2)])
def test_path_greedy() -> None:
test_func = oe.paths.greedy
test_data = explicit_path_tests["GEMM1"]
assert_contract_order(test_func, test_data, 5000, [(0, 2), (0, 1)])
assert_contract_order(test_func, test_data, 0, [(0, 1, 2)])
def test_memory_paths() -> None:
expression = "abc,bdef,fghj,cem,mhk,ljk->adgl"
views = build_shapes(expression)
# Test tiny memory limit
path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=5, shapes=True)
assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])
path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=5, shapes=True)
assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])
# Check the possibilities, greedy is capped
path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=-1, shapes=True)
assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])
path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=-1, shapes=True)
assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])
@pytest.mark.parametrize("alg,expression,order", path_edge_tests)
def test_path_edge_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
views = build_shapes(expression)
    # Check that the expected path is found
path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True)
assert check_path(path_ret[0], order)
@pytest.mark.parametrize("expression,order", path_scalar_tests)
@pytest.mark.parametrize("alg", oe.paths._PATH_OPTIONS)
def test_path_scalar_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
views = build_shapes(expression)
    # Check that the path has the expected number of contractions
    path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True)
    assert len(path_ret[0]) == order
def test_optimal_edge_cases() -> None:
# Edge test5
expression = "a,ac,ab,ad,cd,bd,bc->"
edge_test4 = build_shapes(expression, dimension_dict={"a": 20, "b": 20, "c": 20, "d": 20})
path, _ = oe.contract_path(expression, *edge_test4, optimize="greedy", memory_limit="max_input", shapes=True)
assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])
path, _ = oe.contract_path(expression, *edge_test4, optimize="optimal", memory_limit="max_input", shapes=True)
assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])
def test_greedy_edge_cases() -> None:
expression = "abc,cfd,dbe,efa"
dim_dict = {k: 20 for k in expression.replace(",", "")}
tensors = build_shapes(expression, dimension_dict=dim_dict)
path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit="max_input", shapes=True)
assert check_path(path, [(0, 1, 2, 3)])
path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit=-1, shapes=True)
assert check_path(path, [(0, 1), (0, 2), (0, 1)])
def test_dp_edge_cases_dimension_1() -> None:
eq = "nlp,nlq,pl->n"
shapes = [(1, 1, 1), (1, 1, 1), (1, 1)]
info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")[1]
assert max(info.scale_list) == 3
def test_dp_edge_cases_all_singlet_indices() -> None:
eq = "a,bcd,efg->"
shapes = [(2,), (2, 2, 2), (2, 2, 2)]
info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")[1]
assert max(info.scale_list) == 3
def test_custom_dp_can_optimize_for_outer_products() -> None:
eq = "a,b,abc->c"
da, db, dc = 2, 2, 3
shapes = [(da,), (db,), (da, db, dc)]
opt1 = oe.DynamicProgramming(search_outer=False)
opt2 = oe.DynamicProgramming(search_outer=True)
info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]
assert info2.opt_cost < info1.opt_cost
def test_custom_dp_can_optimize_for_size() -> None:
eq, shapes = rand_equation(10, 4, seed=43)
opt1 = oe.DynamicProgramming(minimize="flops")
opt2 = oe.DynamicProgramming(minimize="size")
info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]
assert info1.opt_cost < info2.opt_cost
assert info1.largest_intermediate > info2.largest_intermediate
def test_custom_dp_can_set_cost_cap() -> None:
eq, shapes = rand_equation(5, 3, seed=42)
opt1 = oe.DynamicProgramming(cost_cap=True)
opt2 = oe.DynamicProgramming(cost_cap=False)
opt3 = oe.DynamicProgramming(cost_cap=100)
info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]
info3 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt3)[1]
assert info1.opt_cost == info2.opt_cost == info3.opt_cost
@pytest.mark.parametrize(
"minimize,cost,width,path",
[
("flops", 663054, 18900, [(4, 5), (2, 5), (2, 7), (5, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
("size", 1114440, 2016, [(2, 7), (3, 8), (3, 7), (2, 6), (1, 5), (1, 4), (1, 3), (1, 2), (0, 1)]),
("write", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
("combo", 973518, 2016, [(4, 5), (2, 5), (6, 7), (2, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
("limit", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
("combo-256", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
("limit-256", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
],
)
def test_custom_dp_can_set_minimize(minimize: str, cost: int, width: int, path: PathType) -> None:
eq, shapes = rand_equation(10, 4, seed=43)
opt = oe.DynamicProgramming(minimize=minimize)
info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)[1]
assert info.path == path
assert info.opt_cost == cost
assert info.largest_intermediate == width
def test_dp_errors_when_no_contractions_found() -> None:
eq, shapes = rand_equation(10, 3, seed=42)
# first get the actual minimum cost
opt = oe.DynamicProgramming(minimize="size")
_, info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)
mincost = info.largest_intermediate
# check we can still find it without minimizing size explicitly
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost, optimize="dp")
# but check just below this threshold raises
with pytest.raises(RuntimeError):
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost - 1, optimize="dp")
@pytest.mark.parametrize("optimize", ["greedy", "branch-2", "branch-all", "optimal", "dp"])
def test_can_optimize_outer_products(optimize: OptimizeKind) -> None:
a, b, c = ((10, 10) for _ in range(3))
d = (10, 2)
assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize, shapes=True)[0] == [
(2, 3),
(0, 2),
(0, 1),
]
@pytest.mark.parametrize("num_symbols", [2, 3, 26, 26 + 26, 256 - 140, 300])
def test_large_path(num_symbols: int) -> None:
symbols = "".join(oe.get_symbol(i) for i in range(num_symbols))
dimension_dict = dict(zip(symbols, itertools.cycle([2, 3, 4])))
expression = ",".join(symbols[t : t + 2] for t in range(num_symbols - 1))
tensors = build_shapes(expression, dimension_dict=dimension_dict)
# Check that path construction does not crash
oe.contract_path(expression, *tensors, optimize="greedy", shapes=True)
def test_custom_random_greedy() -> None:
np = pytest.importorskip("numpy")
eq, shapes = rand_equation(10, 4, seed=42)
views = list(map(np.ones, shapes))
with pytest.raises(ValueError):
oe.RandomGreedy(minimize="something")
optimizer = oe.RandomGreedy(max_repeats=10, minimize="flops")
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert len(optimizer.costs) == 10
assert len(optimizer.sizes) == 10
assert path == optimizer.path
assert optimizer.best["flops"] == min(optimizer.costs)
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
# check can change settings and run again
optimizer.temperature = 0.0
optimizer.max_repeats = 6
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert len(optimizer.costs) == 16
assert len(optimizer.sizes) == 16
assert path == optimizer.path
assert optimizer.best["size"] == min(optimizer.sizes)
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
# check error if we try and reuse the optimizer on a different expression
eq, shapes = rand_equation(10, 4, seed=41)
views = list(map(np.ones, shapes))
with pytest.raises(ValueError):
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
def test_custom_branchbound() -> None:
np = pytest.importorskip("numpy")
eq, shapes = rand_equation(8, 4, seed=42)
views = list(map(np.ones, shapes))
optimizer = oe.BranchBound(nbranch=2, cutoff_flops_factor=10, minimize="size")
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert path == optimizer.path
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
# tweak settings and run again
optimizer.nbranch = 3
optimizer.cutoff_flops_factor = 4
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert path == optimizer.path
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
# check error if we try and reuse the optimizer on a different expression
eq, shapes = rand_equation(8, 4, seed=41)
views = list(map(np.ones, shapes))
with pytest.raises(ValueError):
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
def test_branchbound_validation() -> None:
with pytest.raises(ValueError):
oe.BranchBound(nbranch=0)
def test_parallel_random_greedy() -> None:
np = pytest.importorskip("numpy")
pool = ProcessPoolExecutor(2)
eq, shapes = rand_equation(10, 4, seed=42)
views = list(map(np.ones, shapes))
optimizer = oe.RandomGreedy(max_repeats=10, parallel=pool)
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert len(optimizer.costs) == 10
assert len(optimizer.sizes) == 10
assert path == optimizer.path
assert optimizer.parallel is pool
assert optimizer._executor is pool
assert optimizer.best["flops"] == min(optimizer.costs)
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
# now switch to max time algorithm
optimizer.max_repeats = int(1e6)
optimizer.max_time = 0.2
optimizer.parallel = 2
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
assert len(optimizer.costs) > 10
assert len(optimizer.sizes) > 10
assert path == optimizer.path
assert optimizer.best["flops"] == min(optimizer.costs)
assert path_info.largest_intermediate == optimizer.best["size"]
assert path_info.opt_cost == optimizer.best["flops"]
optimizer.parallel = True
assert optimizer._executor is not None
assert optimizer._executor is not pool
are_done = [f.running() or f.done() for f in optimizer._futures]
assert all(are_done)
def test_custom_path_optimizer() -> None:
np = pytest.importorskip("numpy")
class NaiveOptimizer(oe.paths.PathOptimizer):
def __call__(
self,
inputs: List[ArrayIndexType],
output: ArrayIndexType,
size_dict: Dict[str, int],
memory_limit: Optional[int] = None,
) -> PathType:
self.was_used = True
return [(0, 1)] * (len(inputs) - 1)
eq, shapes = rand_equation(5, 3, seed=42, d_max=3)
views = list(map(np.ones, shapes))
exp = oe.contract(eq, *views, optimize=False)
optimizer = NaiveOptimizer()
out = oe.contract(eq, *views, optimize=optimizer)
assert exp == out
assert optimizer.was_used
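# Any object implementing `__call__(inputs, output, size_dict, memory_limit)`
# and returning a path can be passed as `optimize=`; subclassing
# `oe.paths.PathOptimizer` simply marks that intent, as above.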
def test_custom_random_optimizer() -> None:
np = pytest.importorskip("numpy")
class NaiveRandomOptimizer(oe.path_random.RandomOptimizer):
@staticmethod
def random_path(
r: int, n: int, inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int]
) -> Any:
"""Picks a completely random contraction order."""
np.random.seed(r)
ssa_path: List[TensorShapeType] = []
remaining = set(range(n))
while len(remaining) > 1:
i, j = np.random.choice(list(remaining), size=2, replace=False)
remaining.add(n + len(ssa_path))
remaining.remove(i)
remaining.remove(j)
ssa_path.append((i, j))
cost, size = oe.path_random.ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
return ssa_path, cost, size
def setup(self, inputs: Any, output: Any, size_dict: Any) -> Any:
self.was_used = True
n = len(inputs)
trial_fn = self.random_path
trial_args = (n, inputs, output, size_dict)
return trial_fn, trial_args
eq, shapes = rand_equation(5, 3, seed=42, d_max=3)
views = list(map(np.ones, shapes))
exp = oe.contract(eq, *views, optimize=False)
optimizer = NaiveRandomOptimizer(max_repeats=16)
out = oe.contract(eq, *views, optimize=optimizer)
assert exp == out
assert optimizer.was_used
assert len(optimizer.costs) == 16
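# Note on the SSA convention used above: inputs are numbered 0..n-1 and each
# contraction result takes the next free id (n, n + 1, ...), which is why
# `random_path` adds `n + len(ssa_path)` to `remaining` before removing the
# two contracted ids.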
def test_optimizer_registration() -> None:
def custom_optimizer(
inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int]
) -> PathType:
return [(0, 1)] * (len(inputs) - 1)
with pytest.raises(KeyError):
oe.paths.register_path_fn("optimal", custom_optimizer)
oe.paths.register_path_fn("custom", custom_optimizer)
assert "custom" in oe.paths._PATH_OPTIONS
eq = "ab,bc,cd"
shapes = [(2, 3), (3, 4), (4, 5)]
path, _ = oe.contract_path(eq, *shapes, shapes=True, optimize="custom") # type: ignore
assert path == [(0, 1), (0, 1)]
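    # registration mutates the module-global table, so clean up to keep tests isolated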
del oe.paths._PATH_OPTIONS["custom"]
def test_path_with_assumed_shapes() -> None:
path, _ = oe.contract_path("ab,bc,cd", [[5, 3]], [[2], [4]], [[3, 2]])
assert path == [(0, 1), (0, 1)]
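    # shapes are inferred directly from the nested-list operands, here
    # (1, 2), (2, 1) and (1, 2), so no `shapes=True` flag is needed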
@@ -0,0 +1,390 @@
import itertools
import weakref
from collections import Counter
from typing import Any
import pytest
from opt_einsum import contract, contract_expression, contract_path, get_symbol, shared_intermediates
from opt_einsum.backends import to_cupy, to_torch
from opt_einsum.contract import _einsum
from opt_einsum.parser import parse_einsum_input
from opt_einsum.sharing import count_cached_ops, currently_sharing, get_sharing_cache
from opt_einsum.testing import build_views
from opt_einsum.typing import BackendType
pytest.importorskip("numpy")
try:
import numpy as np # type: ignore
numpy_if_found = "numpy"
except ImportError:
numpy_if_found = pytest.param("numpy", marks=[pytest.mark.skip(reason="NumPy not installed.")]) # type: ignore
try:
import cupy # noqa
cupy_if_found = "cupy"
except ImportError:
cupy_if_found = pytest.param("cupy", marks=[pytest.mark.skip(reason="CuPy not installed.")]) # type: ignore
try:
import torch # type: ignore # noqa
torch_if_found = "torch"
except ImportError:
torch_if_found = pytest.param("torch", marks=[pytest.mark.skip(reason="PyTorch not installed.")]) # type: ignore
backends = [numpy_if_found, torch_if_found, cupy_if_found]
equations = [
"ab,bc->ca",
"abc,bcd,dea",
"abc,def->fedcba",
"abc,bcd,df->fa",
    # equations for which plain einsum is preferred over tensordot
"ijk,ikj",
"i,j->ij",
"ijk,k->ij",
"AB,BC->CA",
]
to_backend = {
"numpy": lambda x: x,
"torch": to_torch,
"cupy": to_cupy,
}
@pytest.mark.parametrize("eq", equations)
@pytest.mark.parametrize("backend", backends)
def test_sharing_value(eq: str, backend: BackendType) -> None:
views = build_views(eq)
shapes = [v.shape for v in views]
expr = contract_expression(eq, *shapes)
expected = expr(*views, backend=backend)
with shared_intermediates():
actual = expr(*views, backend=backend)
assert (actual == expected).all()
@pytest.mark.parametrize("backend", backends)
def test_complete_sharing(backend: BackendType) -> None:
eq = "ab,bc,cd->"
views = build_views(eq)
expr = contract_expression(eq, *(v.shape for v in views))
print("-" * 40)
print("Without sharing:")
with shared_intermediates() as cache:
expr(*views, backend=backend)
expected = count_cached_ops(cache)
print("-" * 40)
print("With sharing:")
with shared_intermediates() as cache:
expr(*views, backend=backend)
expr(*views, backend=backend)
actual = count_cached_ops(cache)
print("-" * 40)
print(f"Without sharing: {expected} expressions")
print(f"With sharing: {actual} expressions")
assert actual == expected
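# Sketch (illustrative, not a test) of the pattern the sharing tests check:
# evaluating several related contractions inside one `shared_intermediates`
# block so overlapping sub-contractions are computed only once.
def _example_shared_intermediates():
    import numpy as np

    from opt_einsum import contract, shared_intermediates

    x, y, z = np.ones((4, 5)), np.ones((5, 6)), np.ones((6, 7))
    with shared_intermediates():
        a = contract("ab,bc,cd->a", x, y, z)  # populates the cache
        b = contract("ab,bc,cd->b", x, y, z)  # reuses overlapping intermediates
    return a, b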
@pytest.mark.parametrize("backend", backends)
def test_sharing_reused_cache(backend: BackendType) -> None:
eq = "ab,bc,cd->"
views = build_views(eq)
expr = contract_expression(eq, *(v.shape for v in views))
print("-" * 40)
print("Without sharing:")
with shared_intermediates() as cache:
expr(*views, backend=backend)
expected = count_cached_ops(cache)
print("-" * 40)
print("With sharing:")
with shared_intermediates() as cache:
expr(*views, backend=backend)
with shared_intermediates(cache):
expr(*views, backend=backend)
actual = count_cached_ops(cache)
print("-" * 40)
print(f"Without sharing: {expected} expressions")
print(f"With sharing: {actual} expressions")
assert actual == expected
@pytest.mark.parametrize("backend", backends)
def test_no_sharing_separate_cache(backend: BackendType) -> None:
eq = "ab,bc,cd->"
views = build_views(eq)
expr = contract_expression(eq, *(v.shape for v in views))
print("-" * 40)
print("Without sharing:")
with shared_intermediates() as cache:
expr(*views, backend=backend)
expected = count_cached_ops(cache)
    expected.update(count_cached_ops(cache))  # running twice without sharing should cost double
print("-" * 40)
print("With sharing:")
with shared_intermediates() as cache1:
expr(*views, backend=backend)
actual = count_cached_ops(cache1)
with shared_intermediates() as cache2:
expr(*views, backend=backend)
actual.update(count_cached_ops(cache2))
print("-" * 40)
print(f"Without sharing: {expected} expressions")
print(f"With sharing: {actual} expressions")
assert actual == expected
@pytest.mark.parametrize("backend", backends)
def test_sharing_nesting(backend: BackendType) -> None:
eqs = ["ab,bc,cd->a", "ab,bc,cd->b", "ab,bc,cd->c", "ab,bc,cd->c"]
views = build_views(eqs[0])
shapes = [v.shape for v in views]
refs: Any = weakref.WeakValueDictionary()
def method1(views):
with shared_intermediates():
w = contract_expression(eqs[0], *shapes)(*views, backend=backend)
x = contract_expression(eqs[2], *shapes)(*views, backend=backend)
result = contract_expression("a,b->", w.shape, x.shape)(w, x, backend=backend)
refs["w"] = w
refs["x"] = x
del w, x
assert "w" in refs
assert "x" in refs
assert "w" not in refs, "cache leakage"
assert "x" not in refs, "cache leakage"
return result
def method2(views):
with shared_intermediates():
y = contract_expression(eqs[2], *shapes)(*views, backend=backend)
z = contract_expression(eqs[3], *shapes)(*views, backend=backend)
refs["y"] = y
refs["z"] = z
result = contract_expression("c,d->", y.shape, z.shape)(y, z, backend=backend)
result = result + method1(views) # nest method1 in method2
del y, z
assert "y" in refs
assert "z" in refs
assert "y" not in refs
assert "z" not in refs
method1(views)
method2(views)
@pytest.mark.parametrize("eq", equations)
@pytest.mark.parametrize("backend", backends)
def test_sharing_modulo_commutativity(eq: str, backend: BackendType) -> None:
ops = tuple(to_backend[backend](x) for x in build_views(eq))
inputs, output, _ = parse_einsum_input([eq] + list(ops))
inputs_list = inputs.split(",")
print("-" * 40)
print("Without sharing:")
with shared_intermediates() as cache:
_einsum(eq, *ops, backend=backend)
expected = count_cached_ops(cache)
print("-" * 40)
print("With sharing:")
with shared_intermediates() as cache:
for permuted in itertools.permutations(zip(inputs_list, ops)):
permuted_inputs = [p[0] for p in permuted]
permuted_ops = [p[1] for p in permuted]
permuted_eq = "{}->{}".format(",".join(permuted_inputs), output)
_einsum(permuted_eq, *permuted_ops, backend=backend)
actual = count_cached_ops(cache)
print("-" * 40)
print(f"Without sharing: {expected} expressions")
print(f"With sharing: {actual} expressions")
assert actual == expected
@pytest.mark.parametrize("backend", backends)
def test_partial_sharing(backend: BackendType) -> None:
eq = "ab,bc,de->"
x, y, z1 = build_views(eq) # type: ignore
z2 = 2.0 * z1 - 1.0
expr = contract_expression(eq, x.shape, y.shape, z1.shape)
print("-" * 40)
print("Without sharing:")
num_exprs_nosharing: Any = Counter()
with shared_intermediates() as cache:
expr(x, y, z1, backend=backend)
num_exprs_nosharing.update(count_cached_ops(cache))
with shared_intermediates() as cache:
expr(x, y, z2, backend=backend)
num_exprs_nosharing.update(count_cached_ops(cache))
print("-" * 40)
print("With sharing:")
with shared_intermediates() as cache:
expr(x, y, z1, backend=backend)
expr(x, y, z2, backend=backend)
num_exprs_sharing = count_cached_ops(cache)
print("-" * 40)
print(f"Without sharing: {num_exprs_nosharing} expressions")
print(f"With sharing: {num_exprs_sharing} expressions")
assert num_exprs_nosharing["einsum"] > num_exprs_sharing["einsum"]
@pytest.mark.parametrize("backend", backends)
def test_sharing_with_constants(backend: BackendType) -> None:
inputs = "ij,jk,kl"
outputs = "ijkl"
equations = [f"{inputs}->{output}" for output in outputs]
shapes = (2, 3), (3, 4), (4, 5)
constants = {0, 2}
ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
var = np.random.rand(*shapes[1])
expected = [contract_expression(eq, *shapes)(ops[0], var, ops[2]) for eq in equations]
with shared_intermediates():
actual = [contract_expression(eq, *ops, constants=constants)(var) for eq in equations]
for dim, expected_dim, actual_dim in zip(outputs, expected, actual):
assert np.allclose(expected_dim, actual_dim), f"error at {dim}"
@pytest.mark.parametrize("size", [3, 4, 5])
@pytest.mark.parametrize("backend", backends)
def test_chain(size: int, backend: BackendType) -> None:
xs = [np.random.rand(2, 2) for _ in range(size)]
shapes = [x.shape for x in xs]
alphabet = "".join(get_symbol(i) for i in range(size + 1))
names = [alphabet[i : i + 2] for i in range(size)]
inputs = ",".join(names)
with shared_intermediates():
print(inputs)
for i in range(size + 1):
target = alphabet[i]
eq = f"{inputs}->{target}"
path_info = contract_path(eq, *xs)
print(path_info[1])
expr = contract_expression(eq, *shapes)
expr(*xs, backend=backend)
print("-" * 40)
@pytest.mark.parametrize("size", [3, 4, 5, 10])
@pytest.mark.parametrize("backend", backends)
def test_chain_2(size: int, backend: BackendType) -> None:
xs = [np.random.rand(2, 2) for _ in range(size)]
shapes = [x.shape for x in xs]
alphabet = "".join(get_symbol(i) for i in range(size + 1))
names = [alphabet[i : i + 2] for i in range(size)]
inputs = ",".join(names)
with shared_intermediates():
print(inputs)
for i in range(size):
target = alphabet[i : i + 2]
eq = f"{inputs}->{target}"
path_info = contract_path(eq, *xs)
print(path_info[1])
expr = contract_expression(eq, *shapes)
expr(*xs, backend=backend)
print("-" * 40)
def _compute_cost(cache):
counts = count_cached_ops(cache)
return counts["einsum"] + counts["tensordot"]
@pytest.mark.parametrize("backend", backends)
def test_chain_2_growth(backend: BackendType) -> None:
sizes = list(range(1, 21))
costs = []
for size in sizes:
xs = [np.random.rand(2, 2) for _ in range(size)]
alphabet = "".join(get_symbol(i) for i in range(size + 1))
names = [alphabet[i : i + 2] for i in range(size)]
inputs = ",".join(names)
with shared_intermediates() as cache:
for i in range(size):
target = alphabet[i : i + 2]
eq = f"{inputs}->{target}"
expr = contract_expression(eq, *(x.shape for x in xs))
expr(*xs, backend=backend)
costs.append(_compute_cost(cache))
print(f"sizes = {repr(sizes)}")
print(f"costs = {repr(costs)}")
for size, cost in zip(sizes, costs):
print(f"{size}\t{cost}")
@pytest.mark.parametrize("size", [3, 4, 5])
@pytest.mark.parametrize("backend", backends)
def test_chain_sharing(size: int, backend: BackendType) -> None:
xs = [np.random.rand(2, 2) for _ in range(size)]
alphabet = "".join(get_symbol(i) for i in range(size + 1))
names = [alphabet[i : i + 2] for i in range(size)]
inputs = ",".join(names)
num_exprs_nosharing = 0
for i in range(size + 1):
with shared_intermediates() as cache:
target = alphabet[i]
eq = f"{inputs}->{target}"
expr = contract_expression(eq, *tuple(x.shape for x in xs))
expr(*xs, backend=backend)
num_exprs_nosharing += _compute_cost(cache)
with shared_intermediates() as cache:
print(inputs)
for i in range(size + 1):
target = alphabet[i]
eq = f"{inputs}->{target}"
path_info = contract_path(eq, *xs)
print(path_info[1])
expr = contract_expression(eq, *[x.shape for x in xs])
expr(*xs, backend=backend)
num_exprs_sharing = _compute_cost(cache)
print("-" * 40)
print(f"Without sharing: {num_exprs_nosharing} expressions")
print(f"With sharing: {num_exprs_sharing} expressions")
assert num_exprs_nosharing > num_exprs_sharing
def test_multithreaded_sharing() -> None:
from multiprocessing.pool import ThreadPool
def fn():
x, y, z = build_views("ab,bc,cd")
with shared_intermediates():
contract("ab,bc,cd->a", x, y, z)
contract("ab,bc,cd->b", x, y, z)
return len(get_sharing_cache())
expected = fn()
pool = ThreadPool(8)
fs = [pool.apply_async(fn) for _ in range(16)]
assert not currently_sharing()
assert [f.get() for f in fs] == [expected] * 16
pool.close()
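# The sharing cache is tracked per thread, which is what this test verifies:
# each worker builds a cache of the expected size while the main thread never
# reports itself as currently sharing.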
@@ -0,0 +1,27 @@
"""Types used in the opt_einsum package."""
from collections import namedtuple
from typing import Any, Callable, Collection, Dict, FrozenSet, List, Literal, Optional, Tuple, Union
TensorShapeType = Tuple[int, ...]
PathType = Collection[TensorShapeType]
ArrayType = Any
ArrayIndexType = FrozenSet[str]
ArrayShaped = namedtuple("ArrayShaped", ["shape"])
ContractionListType = List[Tuple[Any, ArrayIndexType, str, Optional[Tuple[str, ...]], Union[str, bool]]]
PathSearchFunctionType = Callable[[List[ArrayIndexType], ArrayIndexType, Dict[str, int], Optional[int]], PathType]
# Contract kwargs
OptimizeKind = Union[
None,
bool,
Literal[
"optimal", "dp", "greedy", "random-greedy", "random-greedy-128", "branch-all", "branch-2", "auto", "auto-hq"
],
PathType,
PathSearchFunctionType,
]
BackendType = Literal["auto", "object", "autograd", "cupy", "dask", "jax", "theano", "tensorflow", "torch", "libjax"]
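# Illustrative examples of values admitted by OptimizeKind (a usage sketch,
# not part of this module's API):
#   optimize="greedy"             # a Literal strategy name
#   optimize=[(0, 1), (0, 1)]     # an explicit PathType
#   optimize=my_search_fn         # a PathSearchFunctionType callable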