opt_einsum/__init__.py
@@ -0,0 +1,29 @@
"""Main init function for opt_einsum."""

from opt_einsum import blas, helpers, path_random, paths
from opt_einsum._version import __version__
from opt_einsum.contract import contract, contract_expression, contract_path
from opt_einsum.parser import get_symbol
from opt_einsum.path_random import RandomGreedy
from opt_einsum.paths import BranchBound, DynamicProgramming
from opt_einsum.sharing import shared_intermediates

__all__ = [
    "__version__",
    "blas",
    "helpers",
    "path_random",
    "paths",
    "contract",
    "contract_expression",
    "contract_path",
    "get_symbol",
    "RandomGreedy",
    "BranchBound",
    "DynamicProgramming",
    "shared_intermediates",
]


paths.register_path_fn("random-greedy", path_random.random_greedy)
paths.register_path_fn("random-greedy-128", path_random.random_greedy_128)
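The two `register_path_fn` calls above make the random-greedy samplers addressable by name. A minimal usage sketch (not part of the commit, assuming `numpy` is installed): the registered strings become valid `optimize=` arguments to `contract`.

```python
import numpy as np
import opt_einsum as oe

a, b, c = np.random.rand(8, 9), np.random.rand(9, 10), np.random.rand(10, 11)

# "random-greedy" runs repeated randomized greedy trials and keeps the best
# path; "random-greedy-128" is, per the registered name, a 128-repeat variant
out = oe.contract("ij,jk,kl->il", a, b, c, optimize="random-greedy")
```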
opt_einsum/_version.py
@@ -0,0 +1,16 @@
# file generated by setuptools_scm
# don't change, don't track in version control
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union
    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '3.4.0'
__version_tuple__ = version_tuple = (3, 4, 0)
opt_einsum/backends/__init__.py
@@ -0,0 +1,28 @@
"""Compute backends for opt_einsum."""

# Backends
from opt_einsum.backends.cupy import to_cupy
from opt_einsum.backends.dispatch import (
    build_expression,
    evaluate_constants,
    get_func,
    has_backend,
    has_einsum,
    has_tensordot,
)
from opt_einsum.backends.tensorflow import to_tensorflow
from opt_einsum.backends.theano import to_theano
from opt_einsum.backends.torch import to_torch

__all__ = [
    "get_func",
    "has_einsum",
    "has_tensordot",
    "build_expression",
    "evaluate_constants",
    "has_backend",
    "to_tensorflow",
    "to_theano",
    "to_cupy",
    "to_torch",
]
opt_einsum/backends/cupy.py
@@ -0,0 +1,32 @@
"""Required functions for optimized contractions of numpy arrays using cupy."""

from opt_einsum.helpers import has_array_interface
from opt_einsum.sharing import to_backend_cache_wrap

__all__ = ["to_cupy", "build_expression", "evaluate_constants"]


@to_backend_cache_wrap
def to_cupy(array):  # pragma: no cover
    import cupy

    if has_array_interface(array):
        return cupy.asarray(array)

    return array


def build_expression(_, expr):  # pragma: no cover
    """Build a cupy function based on ``arrays`` and ``expr``."""

    def cupy_contract(*arrays):
        return expr._contract([to_cupy(x) for x in arrays], backend="cupy").get()

    return cupy_contract


def evaluate_constants(const_arrays, expr):  # pragma: no cover
    """Convert constant arguments to cupy arrays, and perform any possible
    constant contractions.
    """
    return expr(*[to_cupy(x) for x in const_arrays], backend="cupy", evaluate_constants=True)
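A hypothetical usage sketch (assuming a CUDA machine with `cupy` installed): evaluating a reusable expression with `backend="cupy"` routes through `build_expression` above, so numpy inputs move to the GPU via `to_cupy` and the result comes back as a numpy array via `.get()`.

```python
import numpy as np
import opt_einsum as oe

expr = oe.contract_expression("ab,bc->ac", (32, 32), (32, 32))
x, y = np.random.rand(32, 32), np.random.rand(32, 32)

out = expr(x, y, backend="cupy")  # computed on the GPU, returned as numpy
```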
opt_einsum/backends/dispatch.py
@@ -0,0 +1,156 @@
"""Handles dispatching array operations to the correct backend library, as well
as converting arrays to backend formats and then potentially storing them as
constants.
"""

import importlib
from typing import Any, Dict, Tuple

from opt_einsum.backends import cupy as _cupy
from opt_einsum.backends import jax as _jax
from opt_einsum.backends import object_arrays
from opt_einsum.backends import tensorflow as _tensorflow
from opt_einsum.backends import theano as _theano
from opt_einsum.backends import torch as _torch

__all__ = [
    "get_func",
    "has_einsum",
    "has_tensordot",
    "build_expression",
    "evaluate_constants",
    "has_backend",
]

# known non top-level imports
_aliases = {
    "dask": "dask.array",
    "theano": "theano.tensor",
    "torch": "opt_einsum.backends.torch",
    "jax": "jax.numpy",
    "jaxlib": "jax.numpy",
    "autograd": "autograd.numpy",
    "mars": "mars.tensor",
}


def _import_func(func: str, backend: str, default: Any = None) -> Any:
    """Try and import ``{backend}.{func}``.
    If the library is installed and ``func`` is found, return the function;
    otherwise, if ``default`` is provided, return ``default``;
    otherwise raise an error.
    """
    try:
        lib = importlib.import_module(_aliases.get(backend, backend))
        return getattr(lib, func) if default is None else getattr(lib, func, default)
    except AttributeError:
        error_msg = (
            "{} doesn't seem to provide the function {} - see "
            "https://optimized-einsum.readthedocs.io/en/latest/backends.html "
            "for details on which functions are required for which contractions."
        )
        raise AttributeError(error_msg.format(backend, func))


# manually cache functions as python2 doesn't support functools.lru_cache
# other libs will be added to this if needed, but pre-populate with numpy
_cached_funcs: Dict[Tuple[str, str], Any] = {
    ("einsum", "object"): object_arrays.object_einsum,
}

try:
    import numpy as np  # type: ignore

    _cached_funcs[("tensordot", "numpy")] = np.tensordot
    _cached_funcs[("transpose", "numpy")] = np.transpose
    _cached_funcs[("einsum", "numpy")] = np.einsum
    # also pre-populate with the arbitrary object backend
    _cached_funcs[("tensordot", "object")] = np.tensordot
    _cached_funcs[("transpose", "object")] = np.transpose
except ModuleNotFoundError:
    pass


def get_func(func: str, backend: str = "numpy", default: Any = None) -> Any:
    """Return ``{backend}.{func}``, e.g. ``numpy.einsum``,
    or a default func if provided. Cache result.
    """
    try:
        return _cached_funcs[func, backend]
    except KeyError:
        fn = _import_func(func, backend, default)
        _cached_funcs[func, backend] = fn
        return fn


# mark libs with einsum, else try to use tensordot/transpose as much as possible
_has_einsum: Dict[str, bool] = {}


def has_einsum(backend: str) -> bool:
    """Check if ``{backend}.einsum`` exists, cache result for performance."""
    try:
        return _has_einsum[backend]
    except KeyError:
        try:
            get_func("einsum", backend)
            _has_einsum[backend] = True
        except AttributeError:
            _has_einsum[backend] = False

        return _has_einsum[backend]


_has_tensordot: Dict[str, bool] = {}


def has_tensordot(backend: str) -> bool:
    """Check if ``{backend}.tensordot`` exists, cache result for performance."""
    try:
        return _has_tensordot[backend]
    except KeyError:
        try:
            get_func("tensordot", backend)
            _has_tensordot[backend] = True
        except AttributeError:
            _has_tensordot[backend] = False

        return _has_tensordot[backend]


# Dispatch to correct expression backend
# these are the backends which support explicit to-and-from numpy conversion
CONVERT_BACKENDS = {
    "tensorflow": _tensorflow.build_expression,
    "theano": _theano.build_expression,
    "cupy": _cupy.build_expression,
    "torch": _torch.build_expression,
    "jax": _jax.build_expression,
}

EVAL_CONSTS_BACKENDS = {
    "tensorflow": _tensorflow.evaluate_constants,
    "theano": _theano.evaluate_constants,
    "cupy": _cupy.evaluate_constants,
    "torch": _torch.evaluate_constants,
    "jax": _jax.evaluate_constants,
}


def build_expression(backend, arrays, expr):
    """Build an expression, based on ``expr`` and initial arrays ``arrays``,
    that evaluates using backend ``backend``.
    """
    return CONVERT_BACKENDS[backend](arrays, expr)


def evaluate_constants(backend, arrays, expr):
    """Convert constant arrays to the correct backend, and perform as much of
    the contraction of ``expr`` with these as possible.
    """
    return EVAL_CONSTS_BACKENDS[backend](arrays, expr)


def has_backend(backend: str) -> bool:
    """Checks if the backend is known."""
    return backend.lower() in CONVERT_BACKENDS
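A small demonstration of the dispatch cache (a sketch requiring only `numpy`, exercising the functions defined above):

```python
import numpy as np
from opt_einsum.backends.dispatch import get_func, has_backend, has_einsum

einsum = get_func("einsum", "numpy")  # pre-cached above: this is numpy.einsum
assert einsum is np.einsum
assert has_einsum("numpy")            # memoized in _has_einsum after first lookup
assert has_backend("torch")           # "torch" is a known convert-backend
assert not has_backend("no-such-lib")
```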
opt_einsum/backends/jax.py
@@ -0,0 +1,45 @@
"""Required functions for optimized contractions of numpy arrays using jax."""

from opt_einsum.sharing import to_backend_cache_wrap

__all__ = ["build_expression", "evaluate_constants"]

_JAX = None


def _get_jax_and_to_jax():
    global _JAX
    if _JAX is None:
        import jax  # type: ignore

        @to_backend_cache_wrap
        @jax.jit
        def to_jax(x):
            return x

        _JAX = jax, to_jax

    return _JAX


def build_expression(_, expr):  # pragma: no cover
    """Build a jax function based on ``arrays`` and ``expr``."""
    jax, _ = _get_jax_and_to_jax()

    jax_expr = jax.jit(expr._contract)

    def jax_contract(*arrays):
        import numpy as np  # type: ignore

        return np.asarray(jax_expr(arrays))

    return jax_contract


def evaluate_constants(const_arrays, expr):  # pragma: no cover
    """Convert constant arguments to jax arrays, and perform any possible
    constant contractions.
    """
    jax, to_jax = _get_jax_and_to_jax()

    return expr(*[to_jax(x) for x in const_arrays], backend="jax", evaluate_constants=True)
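A hypothetical sketch (assuming `jax` is installed): `backend="jax"` on a reusable expression dispatches to `build_expression` above, jit-compiling the contraction on the first call.

```python
import numpy as np
import opt_einsum as oe

expr = oe.contract_expression("ij,jk,kl->il", (8, 8), (8, 8), (8, 8))
arrays = [np.random.rand(8, 8) for _ in range(3)]

out = expr(*arrays, backend="jax")  # jit-compiled once, returned as numpy
```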
opt_einsum/backends/object_arrays.py
@@ -0,0 +1,59 @@
"""Functions for performing contractions with array elements which are objects."""

import functools
import operator

from opt_einsum.typing import ArrayType


def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType:
    """An ``einsum`` implementation for ``numpy`` arrays with object dtype.
    The loop is performed in python, meaning the objects themselves need
    only implement ``__mul__`` and ``__add__`` for the contraction to be
    computed. This may be useful when, for example, computing expressions of
    tensors with symbolic elements, but note it will be very slow when compared
    to ``numpy.einsum`` and numeric data types!

    Parameters
    ----------
    eq : str
        The contraction string, should specify output.
    arrays : sequence of arrays
        These can be any indexable arrays as long as addition and
        multiplication are defined on the elements.

    Returns
    -------
    out : numpy.ndarray
        The output tensor, with ``dtype=object``.
    """
    import numpy as np  # type: ignore

    # when called by ``opt_einsum`` we will always be given a full eq
    lhs, output = eq.split("->")
    inputs = lhs.split(",")

    sizes = {}
    for term, array in zip(inputs, arrays):
        for k, d in zip(term, array.shape):
            sizes[k] = d

    out_size = tuple(sizes[k] for k in output)
    out = np.empty(out_size, dtype=object)

    inner = tuple(k for k in sizes if k not in output)
    inner_size = tuple(sizes[k] for k in inner)

    for coo_o in np.ndindex(*out_size):
        coord = dict(zip(output, coo_o))

        def gen_inner_sum():
            for coo_i in np.ndindex(*inner_size):
                coord.update(dict(zip(inner, coo_i)))
                locs = (tuple(coord[k] for k in term) for term in inputs)
                elements = (array[loc] for array, loc in zip(arrays, locs))
                yield functools.reduce(operator.mul, elements)

        out[coo_o] = functools.reduce(operator.add, gen_inner_sum())

    return out
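A worked example of the pure-python loop above: because the elements only need `__mul__` and `__add__`, exact rational arithmetic is preserved.

```python
from fractions import Fraction

import numpy as np
from opt_einsum.backends.object_arrays import object_einsum

a = np.array([[Fraction(1, 2), Fraction(1, 3)],
              [Fraction(1, 5), Fraction(1, 7)]], dtype=object)
b = np.array([Fraction(2), Fraction(3)], dtype=object)

out = object_einsum("ab,b->a", a, b)
print(out)  # [Fraction(2, 1) Fraction(29, 35)] - no floating point rounding
```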
opt_einsum/backends/tensorflow.py
@@ -0,0 +1,123 @@
"""Required functions for optimized contractions of numpy arrays using tensorflow."""

from opt_einsum.helpers import has_array_interface
from opt_einsum.sharing import to_backend_cache_wrap

__all__ = ["to_tensorflow", "build_expression", "evaluate_constants"]

_CACHED_TF_DEVICE = None


def _get_tensorflow_and_device():
    global _CACHED_TF_DEVICE

    if _CACHED_TF_DEVICE is None:
        import tensorflow as tf  # type: ignore

        try:
            eager = tf.executing_eagerly()
        except AttributeError:
            try:
                eager = tf.contrib.eager.in_eager_mode()
            except AttributeError:
                eager = False

        device = tf.test.gpu_device_name()
        if not device:
            device = "cpu"

        _CACHED_TF_DEVICE = tf, device, eager

    return _CACHED_TF_DEVICE


@to_backend_cache_wrap(constants=True)
def to_tensorflow(array, constant=False):
    """Convert a numpy array to a ``tensorflow.placeholder`` instance."""
    tf, device, eager = _get_tensorflow_and_device()

    if eager:
        if has_array_interface(array):
            with tf.device(device):
                return tf.convert_to_tensor(array)

        return array

    if has_array_interface(array):
        if constant:
            return tf.convert_to_tensor(array)

        return tf.placeholder(array.dtype, array.shape)

    return array


# Standard graph mode


def build_expression_graph(arrays, expr):
    """Build a tensorflow function based on ``arrays`` and ``expr``."""
    tf, _, _ = _get_tensorflow_and_device()

    placeholders = [to_tensorflow(array) for array in arrays]
    graph = expr._contract(placeholders, backend="tensorflow")

    def tensorflow_contract(*arrays):
        session = tf.get_default_session()
        # only want to feed placeholders - constant tensors already have values
        feed_dict = {p: a for p, a in zip(placeholders, arrays) if p.op.type == "Placeholder"}
        return session.run(graph, feed_dict=feed_dict)

    return tensorflow_contract


def evaluate_constants_graph(const_arrays, expr):
    """Convert constant arguments to tensorflow constants, and perform any
    possible constant contractions. Requires evaluating a tensorflow graph.
    """
    tf, _, _ = _get_tensorflow_and_device()

    # compute the partial graph of new inputs
    const_arrays = [to_tensorflow(x, constant=True) for x in const_arrays]
    new_ops, new_contraction_list = expr(*const_arrays, backend="tensorflow", evaluate_constants=True)

    # evaluate the new inputs and convert back to tensorflow, maintaining None as non-consts
    session = tf.get_default_session()
    new_consts = iter(session.run([x for x in new_ops if x is not None]))
    new_ops = [None if x is None else to_tensorflow(next(new_consts), constant=True) for x in new_ops]

    return new_ops, new_contraction_list


# Eager execution mode


def build_expression_eager(_, expr):
    """Build an eager tensorflow function based on ``arrays`` and ``expr``."""

    def tensorflow_eager_contract(*arrays):
        return expr._contract([to_tensorflow(x) for x in arrays], backend="tensorflow").numpy()

    return tensorflow_eager_contract


def evaluate_constants_eager(const_arrays, expr):
    """Convert constant arguments to tensorflow_eager arrays, and perform any
    possible constant contractions.
    """
    return expr(*[to_tensorflow(x) for x in const_arrays], backend="tensorflow", evaluate_constants=True)


# Dispatch to eager or graph mode


def build_expression(arrays, expr):
    _, _, eager = _get_tensorflow_and_device()
    fn = build_expression_eager if eager else build_expression_graph
    return fn(arrays, expr)


def evaluate_constants(const_arrays, expr):
    _, _, eager = _get_tensorflow_and_device()
    fn = evaluate_constants_eager if eager else evaluate_constants_graph
    return fn(const_arrays, expr)
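A hypothetical sketch (assuming a modern, eagerly-executing tensorflow install): the dispatch above selects the eager branch, so no session management is needed and results come back via `.numpy()`.

```python
import numpy as np
import opt_einsum as oe

expr = oe.contract_expression("ab,bc->ac", (16, 16), (16, 16))
x, y = np.random.rand(16, 16), np.random.rand(16, 16)

out = expr(x, y, backend="tensorflow")  # numpy in, numpy out
```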
opt_einsum/backends/theano.py
@@ -0,0 +1,48 @@
"""Required functions for optimized contractions of numpy arrays using theano."""

from opt_einsum.helpers import has_array_interface
from opt_einsum.sharing import to_backend_cache_wrap

__all__ = ["to_theano", "build_expression", "evaluate_constants"]


@to_backend_cache_wrap(constants=True)
def to_theano(array, constant=False):
    """Convert a numpy array to ``theano.tensor.TensorType`` instance."""
    import theano  # type: ignore

    if has_array_interface(array):
        if constant:
            return theano.tensor.constant(array)

        return theano.tensor.TensorType(dtype=array.dtype, broadcastable=[False] * len(array.shape))()

    return array


def build_expression(arrays, expr):
    """Build a theano function based on ``arrays`` and ``expr``."""
    import theano

    in_vars = [to_theano(array) for array in arrays]
    out_var = expr._contract(in_vars, backend="theano")

    # don't supply constants to graph
    graph_ins = [x for x in in_vars if not isinstance(x, theano.tensor.TensorConstant)]
    graph = theano.function(graph_ins, out_var)

    def theano_contract(*arrays):
        return graph(*[x for x in arrays if not isinstance(x, theano.tensor.TensorConstant)])

    return theano_contract


def evaluate_constants(const_arrays, expr):
    # compute the partial graph of new inputs
    const_arrays = [to_theano(x, constant=True) for x in const_arrays]
    new_ops, new_contraction_list = expr(*const_arrays, backend="theano", evaluate_constants=True)

    # evaluate the new inputs and convert to theano shared tensors
    new_ops = [None if x is None else to_theano(x.eval(), constant=True) for x in new_ops]

    return new_ops, new_contraction_list
opt_einsum/backends/torch.py
@@ -0,0 +1,130 @@
"""Required functions for optimized contractions of numpy arrays using pytorch."""

from opt_einsum.helpers import has_array_interface
from opt_einsum.parser import convert_to_valid_einsum_chars
from opt_einsum.sharing import to_backend_cache_wrap

__all__ = [
    "transpose",
    "einsum",
    "tensordot",
    "to_torch",
    "build_expression",
    "evaluate_constants",
]

_TORCH_DEVICE = None
_TORCH_HAS_TENSORDOT = None

_torch_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"


def _get_torch_and_device():
    global _TORCH_DEVICE
    global _TORCH_HAS_TENSORDOT

    if _TORCH_DEVICE is None:
        import torch  # type: ignore

        device = "cuda" if torch.cuda.is_available() else "cpu"
        _TORCH_DEVICE = torch, device
        _TORCH_HAS_TENSORDOT = hasattr(torch, "tensordot")

    return _TORCH_DEVICE


def transpose(a, axes):
    """Normal torch transpose is only valid for 2D matrices."""
    return a.permute(*axes)


def einsum(equation, *operands, **kwargs):
    """Variadic version of torch.einsum to match numpy api."""
    # rename symbols to support PyTorch 0.4.1 and earlier,
    # which allow only symbols a-z.
    equation = convert_to_valid_einsum_chars(equation)

    torch, _ = _get_torch_and_device()
    return torch.einsum(equation, operands)


def tensordot(x, y, axes=2):
    """Simple translation of tensordot syntax to einsum."""
    torch, _ = _get_torch_and_device()

    if _TORCH_HAS_TENSORDOT:
        return torch.tensordot(x, y, dims=axes)

    xnd = x.ndimension()
    ynd = y.ndimension()

    # convert int argument to (list[int], list[int])
    if isinstance(axes, int):
        axes = range(xnd - axes, xnd), range(axes)

    # convert (int, int) to (list[int], list[int])
    if isinstance(axes[0], int):
        axes = (axes[0],), axes[1]
    if isinstance(axes[1], int):
        axes = axes[0], (axes[1],)

    # initialize empty indices
    x_ix = [None] * xnd
    y_ix = [None] * ynd
    out_ix = []

    # fill in repeated indices
    available_ix = iter(_torch_symbols_base)
    for ax1, ax2 in zip(*axes):
        repeat = next(available_ix)
        x_ix[ax1] = repeat
        y_ix[ax2] = repeat

    # fill in the rest, and maintain output order
    for i in range(xnd):
        if x_ix[i] is None:
            leave = next(available_ix)
            x_ix[i] = leave
            out_ix.append(leave)
    for i in range(ynd):
        if y_ix[i] is None:
            leave = next(available_ix)
            y_ix[i] = leave
            out_ix.append(leave)

    # form full string and contract!
    einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
    return einsum(einsum_str, x, y)


@to_backend_cache_wrap
def to_torch(array):
    torch, device = _get_torch_and_device()

    if has_array_interface(array):
        return torch.from_numpy(array).to(device)

    return array


def build_expression(_, expr):  # pragma: no cover
    """Build a torch function based on ``arrays`` and ``expr``."""

    def torch_contract(*arrays):
        torch_arrays = [to_torch(x) for x in arrays]
        torch_out = expr._contract(torch_arrays, backend="torch")

        if torch_out.device.type == "cpu":
            return torch_out.numpy()

        return torch_out.cpu().numpy()

    return torch_contract


def evaluate_constants(const_arrays, expr):
    """Convert constant arguments to torch, and perform any possible constant
    contractions.
    """
    const_arrays = [to_torch(x) for x in const_arrays]
    return expr(*const_arrays, backend="torch", evaluate_constants=True)
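The `tensordot` fallback above hands out symbols to the contracted axes first, then to the remaining axes of `x` and `y` in order. A standalone numpy check of the equation that translation produces for `axes=([2], [0])` (an illustration, not part of the commit):

```python
import numpy as np

x = np.random.rand(2, 3, 4)
y = np.random.rand(4, 5)

# the contracted axis pair gets 'a' first; x's free axes get 'b', 'c' and
# y's free axis gets 'd', so the fallback builds "bca,ad->bcd"
assert np.allclose(
    np.tensordot(x, y, axes=([2], [0])),
    np.einsum("bca,ad->bcd", x, y),
)
```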
opt_einsum/blas.py
@@ -0,0 +1,122 @@
"""Determines if a contraction can use BLAS or not."""

from typing import List, Sequence, Tuple, Union

from opt_einsum.typing import ArrayIndexType

__all__ = ["can_blas"]


def can_blas(
    inputs: List[str],
    result: str,
    idx_removed: ArrayIndexType,
    shapes: Union[Sequence[Tuple[int, ...]], None] = None,
) -> Union[str, bool]:
    """Checks if we can use a BLAS call.

    Parameters
    ----------
    inputs : list of str
        Specifies the subscripts for summation.
    result : str
        Resulting summation.
    idx_removed : set
        Indices that are removed in the summation.
    shapes : sequence of tuple[int], optional
        If given, check also that none of the indices are broadcast dimensions.

    Returns
    -------
    type : str or bool
        The type of BLAS call to be used or False if none.

    Notes
    -----
    We assume several operations are not efficient such as a transposed
    DDOT, therefore 'ijk,jki->' should prefer einsum. These return the blas
    type appended with "/EINSUM" to differentiate when they can still be done
    with tensordot if required, e.g. when a backend has no einsum.

    Examples
    --------
    >>> can_blas(['ij', 'jk'], 'ik', set('j'))
    'GEMM'

    >>> can_blas(['ijj', 'jk'], 'ik', set('j'))
    False

    >>> can_blas(['ab', 'cd'], 'abcd', set())
    'OUTER/EINSUM'

    >>> # looks like GEMM but actually 'j' is broadcast:
    >>> can_blas(['ij', 'jk'], 'ik', set('j'), shapes=[(4, 1), (5, 6)])
    False
    """
    # Can only do two operands
    if len(inputs) != 2:
        return False

    input_left, input_right = inputs

    for c in set(input_left + input_right):
        # can't deal with repeated indices on same input or more than 2 total
        nl, nr = input_left.count(c), input_right.count(c)
        if (nl > 1) or (nr > 1) or (nl + nr > 2):
            return False

        # can't do implicit summation or dimension collapse e.g.
        #     "ab,bc->c" (implicitly sum over 'a')
        #     "ab,ca->ca" (take diagonal of 'a')
        if nl + nr - 1 == int(c in result):
            return False

    # check for broadcast indices e.g:
    #     "ij,jk->ik" (but one of the 'j' dimensions is broadcast up)
    if shapes is not None:
        for c in idx_removed:
            if shapes[0][input_left.find(c)] != shapes[1][input_right.find(c)]:
                return False

    # Prefer einsum if not removing indices
    # (N.B. tensordot outer faster for large arrays?)
    if len(idx_removed) == 0:
        return "OUTER/EINSUM"

    # Build a few temporaries
    sets = [set(x) for x in inputs]
    keep_left = sets[0] - idx_removed
    keep_right = sets[1] - idx_removed
    rs = len(idx_removed)

    # DDOT
    if inputs[0] == inputs[1]:
        return "DOT"

    # DDOT does not make sense if you have to transpose - prefer einsum
    elif sets[0] == sets[1]:
        return "DOT/EINSUM"

    # GEMM no transpose
    if input_left[-rs:] == input_right[:rs]:
        return "GEMM"

    # GEMM transpose both
    elif input_left[:rs] == input_right[-rs:]:
        return "GEMM"

    # GEMM transpose right
    elif input_left[-rs:] == input_right[-rs:]:
        return "GEMM"

    # GEMM transpose left
    elif input_left[:rs] == input_right[:rs]:
        return "GEMM"

    # Einsum is faster than vectordot if we have to copy
    elif (len(keep_left) == 0) or (len(keep_right) == 0):
        return "GEMV/EINSUM"

    # Conventional tensordot
    else:
        return "TDOT"
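A few more classifications against the rules above (a sketch assuming the module layout of this commit):

```python
from opt_einsum.blas import can_blas

print(can_blas(["ij", "jk"], "ik", {"j"}))            # 'GEMM' (no transposes)
print(can_blas(["ji", "jk"], "ik", {"j"}))            # 'GEMM' (transpose left)
print(can_blas(["ijk", "ijk"], "", {"i", "j", "k"}))  # 'DOT' (identical terms)
print(can_blas(["acb", "cd"], "abd", {"c"}))          # 'TDOT' (copies needed)
```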
(diff for one file suppressed because it is too large)
opt_einsum/helpers.py
@@ -0,0 +1,151 @@
"""Contains helper functions for opt_einsum testing scripts."""

from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Tuple, overload

from opt_einsum.typing import ArrayIndexType, ArrayType

__all__ = ["compute_size_by_dict", "find_contraction", "flop_count"]

_valid_chars = "abcdefghijklmopqABC"
_sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]
_default_dim_dict = dict(zip(_valid_chars, _sizes))


@overload
def compute_size_by_dict(indices: Iterable[int], idx_dict: List[int]) -> int: ...


@overload
def compute_size_by_dict(indices: Collection[str], idx_dict: Dict[str, int]) -> int: ...


def compute_size_by_dict(indices: Any, idx_dict: Any) -> int:
    """Computes the product of the elements in indices based on the dictionary
    idx_dict.

    Parameters
    ----------
    indices : iterable
        Indices to base the product on.
    idx_dict : dictionary
        Dictionary of index sizes.

    Returns
    -------
    ret : int
        The resulting product.

    Examples
    --------
    >>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
    90

    """
    ret = 1
    for i in indices:  # lgtm [py/iteration-string-and-sequence]
        ret *= idx_dict[i]
    return ret


def find_contraction(
    positions: Collection[int],
    input_sets: List[ArrayIndexType],
    output_set: ArrayIndexType,
) -> Tuple[FrozenSet[str], List[ArrayIndexType], ArrayIndexType, ArrayIndexType]:
    """Finds the contraction for a given set of input and output sets.

    Parameters
    ----------
    positions : iterable
        Integer positions of terms used in the contraction.
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript.
    output_set : set
        Set that represents the rhs side of the overall einsum subscript.

    Returns
    -------
    new_result : set
        The indices of the resulting contraction.
    remaining : list
        List of sets that have not been contracted, the new set is appended to
        the end of this list.
    idx_removed : set
        Indices removed from the entire contraction.
    idx_contraction : set
        The indices used in the current contraction.

    Examples
    --------
    # A simple dot product test case
    >>> pos = (0, 1)
    >>> isets = [set('ab'), set('bc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})

    # A more complex case with additional terms in the contraction
    >>> pos = (0, 2)
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set('ac')
    >>> find_contraction(pos, isets, oset)
    ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
    """
    remaining = list(input_sets)
    inputs = (remaining.pop(i) for i in sorted(positions, reverse=True))
    idx_contract = frozenset.union(*inputs)
    idx_remain = output_set.union(*remaining)

    new_result = idx_remain & idx_contract
    idx_removed = idx_contract - new_result
    remaining.append(new_result)

    return new_result, remaining, idx_removed, idx_contract


def flop_count(
    idx_contraction: Collection[str],
    inner: bool,
    num_terms: int,
    size_dictionary: Dict[str, int],
) -> int:
    """Computes the number of FLOPS in the contraction.

    Parameters
    ----------
    idx_contraction : iterable
        The indices involved in the contraction.
    inner : bool
        Does this contraction require an inner product?
    num_terms : int
        The number of terms in a contraction.
    size_dictionary : dict
        The size of each of the indices in idx_contraction.

    Returns
    -------
    flop_count : int
        The total number of FLOPS required for the contraction.

    Examples
    --------
    >>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30

    >>> flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60

    """
    overall_size = compute_size_by_dict(idx_contraction, size_dictionary)
    op_factor = max(1, num_terms - 1)
    if inner:
        op_factor += 1

    return overall_size * op_factor


def has_array_interface(array: ArrayType) -> bool:
    """Check whether ``array`` exposes the numpy ``__array_interface__``."""
    return hasattr(array, "__array_interface__")
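Tying the helpers together for a single matrix-multiply step (a sketch using only the functions above):

```python
from opt_einsum.helpers import compute_size_by_dict, find_contraction, flop_count

sizes = {"a": 2, "b": 3, "c": 5}

new_result, remaining, idx_removed, idx_contract = find_contraction(
    (0, 1), [frozenset("ab"), frozenset("bc")], frozenset("ac")
)
assert idx_removed == frozenset("b")             # 'b' is summed away

print(flop_count(idx_contract, True, 2, sizes))  # 60: 2*3*5 scaled by 2 ops
print(compute_size_by_dict(new_result, sizes))   # 10: the 'ac' output size
```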
opt_einsum/parser.py
@@ -0,0 +1,415 @@
"""A functionally equivalent parser of the numpy.einsum input parser."""

import itertools
from typing import Any, Dict, Iterator, List, Sequence, Tuple

from opt_einsum.typing import ArrayType, TensorShapeType

__all__ = [
    "is_valid_einsum_char",
    "has_valid_einsum_chars_only",
    "get_symbol",
    "get_shape",
    "gen_unused_symbols",
    "convert_to_valid_einsum_chars",
    "alpha_canonicalize",
    "find_output_str",
    "find_output_shape",
    "possibly_convert_to_numpy",
    "parse_einsum_input",
]

_einsum_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"


def is_valid_einsum_char(x: str) -> bool:
    """Check if the character ``x`` is valid for numpy einsum.

    **Examples:**

    ```python
    is_valid_einsum_char("a")
    #> True

    is_valid_einsum_char("Ǵ")
    #> False
    ```
    """
    return (x in _einsum_symbols_base) or (x in ",->.")


def has_valid_einsum_chars_only(einsum_str: str) -> bool:
    """Check if ``einsum_str`` contains only valid characters for numpy einsum.

    **Examples:**

    ```python
    has_valid_einsum_chars_only("abAZ")
    #> True

    has_valid_einsum_chars_only("Över")
    #> False
    ```
    """
    return all(map(is_valid_einsum_char, einsum_str))


def get_symbol(i: int) -> str:
    """Get the symbol corresponding to int ``i`` - runs through the usual 52
    letters before resorting to unicode characters, starting at ``chr(192)`` and skipping surrogates.

    **Examples:**

    ```python
    get_symbol(2)
    #> 'c'

    get_symbol(200)
    #> 'Ŕ'

    get_symbol(20000)
    #> '京'
    ```
    """
    if i < 52:
        return _einsum_symbols_base[i]
    elif i >= 55296:
        # skip chr(55296) - chr(57343), the surrogate range
        return chr(i + 2048)
    else:
        return chr(i + 140)


def gen_unused_symbols(used: str, n: int) -> Iterator[str]:
    """Generate ``n`` symbols that are not already in ``used``.

    **Examples:**
    ```python
    list(oe.parser.gen_unused_symbols("abd", 2))
    #> ['c', 'e']
    ```
    """
    i = cnt = 0
    while cnt < n:
        s = get_symbol(i)
        i += 1
        if s in used:
            continue
        yield s
        cnt += 1


def convert_to_valid_einsum_chars(einsum_str: str) -> str:
    """Convert the str ``einsum_str`` to contain only the alphabetic characters
    valid for numpy einsum. If there are too many symbols, let the backend
    throw an error.

    Examples
    --------
    >>> oe.parser.convert_to_valid_einsum_chars("Ĥěļļö")
    'cbdda'
    """
    symbols = sorted(set(einsum_str) - set(",->"))
    replacer = {x: get_symbol(i) for i, x in enumerate(symbols)}
    return "".join(replacer.get(x, x) for x in einsum_str)


def alpha_canonicalize(equation: str) -> str:
    """Alpha convert an equation in an order-independent canonical way.

    Examples
    --------
    >>> oe.parser.alpha_canonicalize("dcba")
    'abcd'

    >>> oe.parser.alpha_canonicalize("Ĥěļļö")
    'abccd'
    """
    rename: Dict[str, str] = {}
    for name in equation:
        if name in ".,->":
            continue
        if name not in rename:
            rename[name] = get_symbol(len(rename))
    return "".join(rename.get(x, x) for x in equation)


def find_output_str(subscripts: str) -> str:
    """Find the output string for the inputs ``subscripts`` under canonical einstein summation rules.
    That is, repeated indices are summed over by default.

    Examples
    --------
    >>> oe.parser.find_output_str("ab,bc")
    'ac'

    >>> oe.parser.find_output_str("a,b")
    'ab'

    >>> oe.parser.find_output_str("a,a,b,b")
    ''
    """
    tmp_subscripts = subscripts.replace(",", "")
    return "".join(s for s in sorted(set(tmp_subscripts)) if tmp_subscripts.count(s) == 1)


def find_output_shape(inputs: List[str], shapes: List[TensorShapeType], output: str) -> TensorShapeType:
    """Find the output shape for given inputs, shapes and output string, taking
    into account broadcasting.

    Examples
    --------
    >>> oe.parser.find_output_shape(["ab", "bc"], [(2, 3), (3, 4)], "ac")
    (2, 4)

    # Broadcasting is accounted for
    >>> oe.parser.find_output_shape(["a", "a"], [(4, ), (1, )], "a")
    (4,)
    """
    return tuple(max(shape[loc] for shape, loc in zip(shapes, [x.find(c) for x in inputs]) if loc >= 0) for c in output)


_BaseTypes = (bool, int, float, complex, str, bytes)


def get_shape(x: Any) -> TensorShapeType:
    """Get the shape of the array-like object `x`. If `x` is not array-like, raise an error.

    Array-like objects are those that have a `shape` attribute, are sequences of BaseTypes, or are BaseTypes.
    BaseTypes are defined as `bool`, `int`, `float`, `complex`, `str`, and `bytes`.
    """
    if hasattr(x, "shape"):
        return x.shape
    elif isinstance(x, _BaseTypes):
        return ()
    elif isinstance(x, Sequence):
        shape = []
        while isinstance(x, Sequence) and not isinstance(x, _BaseTypes):
            shape.append(len(x))
            x = x[0]
        return tuple(shape)
    else:
        raise ValueError(f"Cannot determine the shape of {x}, can only determine the shape of array-like objects.")


def possibly_convert_to_numpy(x: Any) -> Any:
    """Convert things without a 'shape' to ndarrays, but leave everything else.

    Examples
    --------
    >>> oe.parser.possibly_convert_to_numpy(5)
    array(5)

    >>> oe.parser.possibly_convert_to_numpy([5, 3])
    array([5, 3])

    >>> oe.parser.possibly_convert_to_numpy(np.array([5, 3]))
    array([5, 3])

    # Any class with a shape is passed through
    >>> class Shape:
    ...     def __init__(self, shape):
    ...         self.shape = shape
    ...

    >>> myshape = Shape((5, 5))
    >>> oe.parser.possibly_convert_to_numpy(myshape)
    <__main__.Shape object at 0x10f850710>
    """
    if not hasattr(x, "shape"):
        try:
            import numpy as np  # type: ignore
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "numpy is required to convert non-array objects to arrays. This function will be deprecated in the future."
            )

        return np.asanyarray(x)
    else:
        return x


def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str:
    """Convert user custom subscripts list to subscript string according to `symbol_map`.

    Examples
    --------
    >>> oe.parser.convert_subscripts(['abc', 'def'], {'abc':'a', 'def':'b'})
    'ab'
    >>> oe.parser.convert_subscripts([Ellipsis, object], {object:'a'})
    '...a'
    """
    new_sub = ""
    for s in old_sub:
        if s is Ellipsis:
            new_sub += "..."
        else:
            # no need to try/except here because symbol_map has already been checked
            new_sub += symbol_map[s]
    return new_sub


def convert_interleaved_input(operands: Sequence[Any]) -> Tuple[str, Tuple[Any, ...]]:
    """Convert 'interleaved' input to standard einsum input."""
    tmp_operands = list(operands)
    operand_list = []
    subscript_list = []
    for _ in range(len(operands) // 2):
        operand_list.append(tmp_operands.pop(0))
        subscript_list.append(tmp_operands.pop(0))

    output_list = tmp_operands[-1] if len(tmp_operands) else None

    # build a map from user symbols to single-character symbols based on `get_symbol`
    # The map retains the intrinsic order of user symbols.
    try:
        # collect all user symbols
        symbol_set = set(itertools.chain.from_iterable(subscript_list))

        # remove Ellipsis because it can not be compared with other objects
        symbol_set.discard(Ellipsis)

        # build the map based on sorted user symbols, retaining the order we lost in the `set`
        symbol_map = {symbol: get_symbol(idx) for idx, symbol in enumerate(sorted(symbol_set))}

    except TypeError:  # unhashable or uncomparable object
        raise TypeError(
            "For this input type lists must contain either Ellipsis "
            "or hashable and comparable objects (e.g. int, str)."
        )

    subscripts = ",".join(convert_subscripts(sub, symbol_map) for sub in subscript_list)
    if output_list is not None:
        subscripts += "->"
        subscripts += convert_subscripts(output_list, symbol_map)

    return subscripts, tuple(operand_list)


def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]:
    """A reproduction of numpy's C-side einsum parsing, in python.

    Parameters:
        operands: Intakes the same inputs as `contract_path`, but NOT the keyword args. The only
            supported keyword argument is:
        shapes: Whether ``parse_einsum_input`` should assume arrays (the default) or
            array shapes have been supplied.

    Returns:
        input_strings: Parsed input strings.
        output_string: Parsed output string.
        operands: The operands to use in the numpy contraction.

    Examples:
        The operand list is simplified to reduce printing:

        ```python
        >>> a = np.random.rand(4, 4)
        >>> b = np.random.rand(4, 4, 4)
        >>> parse_einsum_input(('...a,...a->...', a, b))
        ('za,xza', 'xz', [a, b])

        >>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
        ('za,xza', 'xz', [a, b])
        ```
    """
    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        if shapes:
            if any(hasattr(o, "shape") for o in operands[1:]):
                raise ValueError(
                    "shapes is set to True but given at least one operand looks like an array"
                    " (at least one operand has a shape attribute)."
                )
        operands = operands[1:]
    else:
        subscripts, operands = convert_interleaved_input(operands)

    if shapes:
        operand_shapes = operands
    else:
        operand_shapes = [get_shape(o) for o in operands]

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        ellipse_inds = "".join(gen_unused_symbols(used, max(len(x) for x in operand_shapes)))
        longest = 0

        # Do we have an output to account for?
        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(",")
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operand_shapes[num] == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(len(operand_shapes[num]), 1) - (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace("...", "")
                else:
                    split_subscripts[num] = sub.replace("...", ellipse_inds[-ellipse_count:])

        subscripts = ",".join(split_subscripts)

        # Figure out output ellipses
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses
            output_subscript = find_output_str(subscripts)
            normal_inds = "".join(sorted(set(output_subscript) - set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build the output string if it does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts, output_subscript = subscripts, find_output_str(subscripts)

    # Make sure output subscripts are unique and in the input
    for char in output_subscript:
        if output_subscript.count(char) != 1:
            raise ValueError(f"Output character '{char}' appeared more than once in the output.")
        if char not in input_subscripts:
            raise ValueError(f"Output character '{char}' did not appear in the input")

    # Make sure the number of operands is equivalent to the number of terms
    if len(input_subscripts.split(",")) != len(operands):
        raise ValueError(
            f"Number of einsum subscripts, {len(input_subscripts.split(','))}, must be equal to the "
            f"number of operands, {len(operands)}."
        )

    return input_subscripts, output_subscript, operands
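Round-trip sketches for the input styles handled above (illustrative; the returned operands are truncated from the printouts):

```python
import numpy as np
from opt_einsum.parser import parse_einsum_input

a, b = np.random.rand(4, 4), np.random.rand(4, 4)

# plain subscript string: the implicit output 'ac' is derived
print(parse_einsum_input(("ab,bc", a, b))[:2])
#> ('ab,bc', 'ac')

# shapes-only mode: pass shapes in place of arrays
print(parse_einsum_input(("ab,bc->ac", (4, 4), (4, 4)), shapes=True)[:2])
#> ('ab,bc', 'ac')

# interleaved (numpy-style) input with integer index labels
print(parse_einsum_input((a, [0, 1], b, [1, 2], [0, 2]))[:2])
#> ('ab,bc', 'ac')
```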
opt_einsum/path_random.py
@@ -0,0 +1,398 @@
"""Support for random optimizers, including the random-greedy path."""

import functools
import heapq
import math
import time
from collections import deque
from decimal import Decimal
from random import choices as random_choices
from random import seed as random_seed
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union

from opt_einsum import helpers, paths
from opt_einsum.typing import ArrayIndexType, ArrayType, PathType

__all__ = ["RandomGreedy", "random_greedy", "random_greedy_128"]


class RandomOptimizer(paths.PathOptimizer):
    """Base class for running any random path finder that benefits
    from repeated calling, possibly in a parallel fashion. Custom random
    optimizers should subclass this, and the `setup` method should be
    implemented with the following signature:

    ```python
    def setup(self, inputs, output, size_dict):
        # custom preparation here ...
        return trial_fn, trial_args
    ```

    Where `trial_fn` itself should have the signature:

    ```python
    def trial_fn(r, *trial_args):
        # custom computation of path here
        return ssa_path, cost, size
    ```

    Where `r` is the run number and could for example be used to seed a
    random number generator. See `RandomGreedy` for an example.


    Parameters:
        max_repeats: The maximum number of repeat trials to have.
        max_time: The maximum amount of time to run the algorithm for.
        minimize: Whether to favour paths that minimize the total estimated flop-count or
            the size of the largest intermediate created.
        parallel: Whether to parallelize the random trials, by default `False`. If
            `True`, use a `concurrent.futures.ProcessPoolExecutor` with the same
            number of processes as cores. If an integer is specified, use that many
            processes instead. Finally, you can supply a custom executor-pool which
            should have an API matching that of the python 3 standard library
            module `concurrent.futures`. Namely, a `submit` method that returns
            `Future` objects, themselves with `result` and `cancel` methods.
        pre_dispatch: If running in parallel, how many jobs to pre-dispatch so as to avoid
            submitting all jobs at once. Should also be more than twice the number
            of workers to avoid under-subscription. Default: 128.

    Attributes:
        path: The best path found so far.
        costs: The list of each trial's costs found so far.
        sizes: The list of each trial's largest intermediate size so far.
    """

    def __init__(
        self,
        max_repeats: int = 32,
        max_time: Optional[float] = None,
        minimize: str = "flops",
        parallel: Union[bool, Decimal, int] = False,
        pre_dispatch: int = 128,
    ):
        if minimize not in ("flops", "size"):
            raise ValueError("`minimize` should be one of {'flops', 'size'}.")

        self.max_repeats = max_repeats
        self.max_time = max_time
        self.minimize = minimize
        self.better = paths.get_better_fn(minimize)
        self._parallel: Union[bool, Decimal, int] = False
        self.parallel = parallel
        self.pre_dispatch = pre_dispatch

        self.costs: List[int] = []
        self.sizes: List[int] = []
        self.best: Dict[str, Any] = {"flops": float("inf"), "size": float("inf")}

        self._repeats_start = 0
        self._executor: Any
        self._futures: Any

    @property
    def path(self) -> PathType:
        """The best path found so far."""
        return paths.ssa_to_linear(self.best["ssa_path"])

    @property
    def parallel(self) -> Union[bool, Decimal, int]:
        return self._parallel

    @parallel.setter
    def parallel(self, parallel: Union[bool, Decimal, int]) -> None:
        # shutdown any previous executor if we are managing it
        if getattr(self, "_managing_executor", False):
            self._executor.shutdown()

        self._parallel = parallel
        self._managing_executor = False

        if parallel is False:
            self._executor = None
            return

        if parallel is True:
            from concurrent.futures import ProcessPoolExecutor

            self._executor = ProcessPoolExecutor()
            self._managing_executor = True
            return

        if isinstance(parallel, (int, Decimal)):
            from concurrent.futures import ProcessPoolExecutor

            self._executor = ProcessPoolExecutor(int(parallel))
            self._managing_executor = True
            return

        # assume a pool-executor has been supplied
        self._executor = parallel

    def _gen_results_parallel(self, repeats: Iterable[int], trial_fn: Any, args: Any) -> Generator[Any, None, None]:
        """Lazily generate results from an executor without submitting all jobs at once."""
        self._futures = deque()

        # the idea here is to submit at least ``pre_dispatch`` jobs *before* we
        # yield any results, then do both in tandem, before draining the queue
        for r in repeats:
            if len(self._futures) < self.pre_dispatch:
                self._futures.append(self._executor.submit(trial_fn, r, *args))
                continue
            yield self._futures.popleft().result()

        while self._futures:
            yield self._futures.popleft().result()

    def _cancel_futures(self) -> None:
        if self._executor is not None:
            for f in self._futures:
                f.cancel()

    def setup(
        self,
        inputs: List[ArrayIndexType],
        output: ArrayIndexType,
        size_dict: Dict[str, int],
    ) -> Tuple[Any, Any]:
        raise NotImplementedError

    def __call__(
        self,
        inputs: List[ArrayIndexType],
        output: ArrayIndexType,
        size_dict: Dict[str, int],
        memory_limit: Optional[int] = None,
    ) -> PathType:
        self._check_args_against_first_call(inputs, output, size_dict)

        # start a timer?
        if self.max_time is not None:
            t0 = time.time()

        trial_fn, trial_args = self.setup(inputs, output, size_dict)

        r_start = self._repeats_start + len(self.costs)
        r_stop = r_start + self.max_repeats
        repeats = range(r_start, r_stop)

        # create the trials lazily
        if self._executor is not None:
            trials = self._gen_results_parallel(repeats, trial_fn, trial_args)
        else:
            trials = (trial_fn(r, *trial_args) for r in repeats)

        # assess the trials
        for ssa_path, cost, size in trials:
            # keep track of all costs and sizes
            self.costs.append(cost)
            self.sizes.append(size)

            # check if we have found a new best
            found_new_best = self.better(cost, size, self.best["flops"], self.best["size"])

            if found_new_best:
                self.best["flops"] = cost
                self.best["size"] = size
                self.best["ssa_path"] = ssa_path

            # check if we have run out of time
            if (self.max_time is not None) and (time.time() > t0 + self.max_time):
                break

        self._cancel_futures()
        return self.path

    def __del__(self):
        # if we created the parallel pool-executor, shut it down
        if getattr(self, "_managing_executor", False):
            self._executor.shutdown()


def thermal_chooser(queue, remaining, nbranch=8, temperature=1, rel_temperature=True):
    """A contraction 'chooser' that weights possible contractions using a
    Boltzmann distribution. Explicitly, given costs `c_i` (with `c_0` the
    smallest), the relative weights, `w_i`, are computed as:

    $$w_i = exp( -(c_i - c_0) / temperature)$$

    Additionally, if `rel_temperature` is set, scale `temperature` by
    `abs(c_0)` to account for likely fluctuating cost magnitudes during the
    course of a contraction.

    Parameters:
        queue: The heapified list of candidate contractions.
        remaining: Mapping of remaining inputs' indices to the ssa id.
        temperature: When choosing a possible contraction, its relative probability will be
            proportional to `exp(-cost / temperature)`. Thus the larger
            `temperature` is, the further random paths will stray from the normal
            'greedy' path. Conversely, if set to zero, only paths with exactly the
            same cost as the best at each step will be explored.
        rel_temperature: Whether to normalize the `temperature` at each step to the scale of
            the best cost. This is generally beneficial as the magnitude of costs
            can vary significantly throughout a contraction.
        nbranch: How many potential paths to calculate probability for and choose from at each step.

    Returns:
        cost
        k1
        k2
        k12
    """
    n = 0
    choices = []
    while queue and n < nbranch:
        cost, k1, k2, k12 = heapq.heappop(queue)
        if k1 not in remaining or k2 not in remaining:
            continue  # candidate is obsolete
        choices.append((cost, k1, k2, k12))
        n += 1

    if n == 0:
        return None
    if n == 1:
        return choices[0]

    costs = [choice[0][0] for choice in choices]
    cmin = costs[0]

    # adjust by the overall scale to account for fluctuating absolute costs
    if rel_temperature:
        temperature *= max(1, abs(cmin))

    # compute relative probability for each potential contraction
    if temperature == 0.0:
        energies = [1 if c == cmin else 0 for c in costs]
    else:
        # shift by cmin for numerical reasons
        energies = [math.exp(-(c - cmin) / temperature) for c in costs]

    # randomly choose a contraction based on energies
    (chosen,) = random_choices(range(n), weights=energies)
    cost, k1, k2, k12 = choices.pop(chosen)

    # put the other choices back in the heap
    for other in choices:
        heapq.heappush(queue, other)

    return cost, k1, k2, k12
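A numeric illustration of the Boltzmann weighting used above (a standalone sketch):

```python
import math

costs = [10, 12, 15]  # heuristic costs of the surviving candidates
cmin, temperature = costs[0], 1.0

weights = [math.exp(-(c - cmin) / temperature) for c in costs]
print(weights)  # [1.0, 0.135..., 0.0067...] - cheaper moves dominate, but
                # worse ones are still explored occasionally
```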
|
||||
|
||||
def ssa_path_compute_cost(
    ssa_path: PathType,
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
) -> Tuple[int, int]:
    """Compute the flops and max size of an ssa path."""
    inputs = list(map(frozenset, inputs))
    output = frozenset(output)
    remaining = set(range(len(inputs)))
    total_cost = 0
    max_size = 0

    for i, j in ssa_path:
        k12, flops12 = paths.calc_k12_flops(inputs, output, remaining, i, j, size_dict)  # type: ignore
        remaining.discard(i)
        remaining.discard(j)
        remaining.add(len(inputs))
        inputs.append(k12)
        total_cost += flops12
        max_size = max(max_size, helpers.compute_size_by_dict(k12, size_dict))

    return total_cost, max_size
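
A hedged usage sketch of the helper above - scoring the single SSA contraction "ij,jk->ik" (the exact flop value follows opt_einsum's internal counting convention, so only the size is asserted here):

```python
from opt_einsum.path_random import ssa_path_compute_cost

flops, max_size = ssa_path_compute_cost(
    ssa_path=[(0, 1)],
    inputs=[set("ij"), set("jk")],
    output=set("ik"),
    size_dict={"i": 2, "j": 3, "k": 4},
)
assert max_size == 2 * 4  # the "ik" result has 8 elements
print(flops)
```
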
def _trial_greedy_ssa_path_and_cost(
    r: int,
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
    choose_fn: Any,
    cost_fn: Any,
) -> Tuple[PathType, int, int]:
    """A single, repeatable, greedy trial run. **Returns:** ``ssa_path``, cost and size."""
    if r == 0:
        # always start with the standard greedy approach
        choose_fn = None

    random_seed(r)

    ssa_path = paths.ssa_greedy_optimize(inputs, output, size_dict, choose_fn, cost_fn)
    cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)

    return ssa_path, cost, size


class RandomGreedy(RandomOptimizer):
    def __init__(
        self,
        cost_fn: str = "memory-removed-jitter",
        temperature: float = 1.0,
        rel_temperature: bool = True,
        nbranch: int = 8,
        **kwargs: Any,
    ):
        """Parameters:
        cost_fn: A function that returns a heuristic 'cost' of a potential contraction
            with which to sort candidates. Should have signature
            `cost_fn(size12, size1, size2, k12, k1, k2)`.
        temperature: When choosing a possible contraction, its relative probability will be
            proportional to `exp(-cost / temperature)`. Thus the larger
            `temperature` is, the further random paths will stray from the normal
            'greedy' path. Conversely, if set to zero, only paths with exactly the
            same cost as the best at each step will be explored.
        rel_temperature: Whether to normalize the ``temperature`` at each step to the scale of
            the best cost. This is generally beneficial as the magnitude of costs
            can vary significantly throughout a contraction. If False, the
            algorithm will end up branching when the absolute cost is low, but
            stick to the 'greedy' path when the cost is high - this can also be
            beneficial.
        nbranch: How many potential paths to calculate probability for and choose from at each step.
        kwargs: Supplied to RandomOptimizer.
        """
        self.cost_fn = cost_fn
        self.temperature = temperature
        self.rel_temperature = rel_temperature
        self.nbranch = nbranch
        super().__init__(**kwargs)

    @property
    def choose_fn(self) -> Any:
        """The function that chooses which contraction to take - make this a
        property so that ``temperature`` and ``nbranch`` etc. can be updated
        between runs.
        """
        if self.nbranch == 1:
            return None

        return functools.partial(
            thermal_chooser,
            temperature=self.temperature,
            nbranch=self.nbranch,
            rel_temperature=self.rel_temperature,
        )

    def setup(
        self,
        inputs: List[ArrayIndexType],
        output: ArrayIndexType,
        size_dict: Dict[str, int],
    ) -> Tuple[Any, Any]:
        fn = _trial_greedy_ssa_path_and_cost
        args = (inputs, output, size_dict, self.choose_fn, self.cost_fn)
        return fn, args


def random_greedy(
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    idx_dict: Dict[str, int],
    memory_limit: Optional[int] = None,
    **optimizer_kwargs: Any,
) -> PathType:
    """A simple wrapper around the `RandomGreedy` optimizer."""
    optimizer = RandomGreedy(**optimizer_kwargs)
    return optimizer(inputs, output, idx_dict, memory_limit)


random_greedy_128 = functools.partial(random_greedy, max_repeats=128)
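
A minimal usage sketch of the strategies registered above (requires numpy; the equation and sizes are illustrative):

```python
import numpy as np
from opt_einsum import RandomGreedy, contract_path

arrays = [np.random.rand(10, 10) for _ in range(4)]

# via the registered string name...
path, info = contract_path("ab,bc,cd,de->ae", *arrays, optimize="random-greedy")
print(info.opt_cost)

# ...or via an optimizer instance whose settings can be tuned between runs
opt = RandomGreedy(max_repeats=32, temperature=0.5)
path, info = contract_path("ab,bc,cd,de->ae", *arrays, optimize=opt)
print(opt.best["flops"])
```
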
File diff suppressed because it is too large
@@ -0,0 +1,216 @@
"""A module for sharing intermediates between contractions.

Copyright (c) 2018 Uber Technologies
"""

import contextlib
import functools
import numbers
import threading
from collections import Counter, defaultdict
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import Counter as CounterType

from opt_einsum.parser import alpha_canonicalize, parse_einsum_input
from opt_einsum.typing import ArrayType

CacheKeyType = Union[Tuple[str, str, int, Tuple[int, ...]], Tuple[str, int]]
CacheType = Dict[CacheKeyType, ArrayType]

__all__ = [
    "currently_sharing",
    "get_sharing_cache",
    "shared_intermediates",
    "count_cached_ops",
    "transpose_cache_wrap",
    "einsum_cache_wrap",
    "to_backend_cache_wrap",
]

_SHARING_STACK: Dict[int, List[CacheType]] = defaultdict(list)


def currently_sharing() -> bool:
    """Check if we are currently sharing a cache -- thread specific."""
    return threading.get_ident() in _SHARING_STACK


def get_sharing_cache() -> CacheType:
    """Return the most recent sharing cache -- thread specific."""
    return _SHARING_STACK[threading.get_ident()][-1]


def _add_sharing_cache(cache: CacheType) -> None:
    _SHARING_STACK[threading.get_ident()].append(cache)


def _remove_sharing_cache() -> None:
    tid = threading.get_ident()
    _SHARING_STACK[tid].pop()
    if not _SHARING_STACK[tid]:
        del _SHARING_STACK[tid]


@contextlib.contextmanager
def shared_intermediates(
    cache: Optional[CacheType] = None,
) -> Generator[CacheType, None, None]:
    """Context in which contract intermediate results are shared.

    Note that intermediate computations will not be garbage collected until
    1. this context exits, and
    2. the yielded cache is garbage collected (if it was captured).

    **Parameters:**

    - **cache** - *(dict)* If specified, a user-stored dict in which intermediate results will be stored. This can be used to interleave sharing contexts.

    **Returns:**

    - **cache** - *(dict)* A dictionary in which sharing results are stored. If ignored,
        sharing results will be garbage collected when this context is
        exited. This dict can be passed to another context to resume
        sharing.
    """
    if cache is None:
        cache = {}
    _add_sharing_cache(cache)
    try:
        yield cache
    finally:
        _remove_sharing_cache()


def count_cached_ops(cache: CacheType) -> CounterType[str]:
    """Returns a counter of the types of each op in the cache.
    This is useful for profiling to increase sharing.
    """
    return Counter(key[0] for key in cache.keys())
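
A minimal usage sketch of the context manager and the profiling helper above (requires numpy; the equations are illustrative):

```python
import numpy as np
from opt_einsum import contract, shared_intermediates
from opt_einsum.sharing import count_cached_ops

a, b, c = (np.random.rand(4, 4) for _ in range(3))

with shared_intermediates() as cache:
    contract("ab,bc,cd->ad", a, b, c)
    contract("ab,bc,cd->da", a, b, c)  # reuses the cached partial contractions
    print(count_cached_ops(cache))    # e.g. Counter({'tensor': ..., 'tensordot': ...})
```
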
def _save_tensors(*tensors: ArrayType) -> None:
    """Save tensors in the cache to prevent their ids from being recycled.
    This is needed to prevent false cache lookups.
    """
    cache = get_sharing_cache()
    for tensor in tensors:
        cache["tensor", id(tensor)] = tensor


def _memoize(key: CacheKeyType, fn: Any, *args: Any, **kwargs: Any) -> ArrayType:
    """Memoize ``fn(*args, **kwargs)`` using the given ``key``.
    Results will be stored in the innermost ``cache`` yielded by
    :func:`shared_intermediates`.
    """
    cache = get_sharing_cache()
    if key in cache:
        return cache[key]
    result = fn(*args, **kwargs)
    cache[key] = result
    return result


def transpose_cache_wrap(transpose: Any) -> Any:
    """Decorates a ``transpose()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """

    @functools.wraps(transpose)
    def cached_transpose(a, axes, backend="numpy"):
        if not currently_sharing():
            return transpose(a, axes, backend=backend)

        # hash by axes
        _save_tensors(a)
        axes = tuple(axes)
        key = "transpose", backend, id(a), axes
        return _memoize(key, transpose, a, axes, backend=backend)

    return cached_transpose


def tensordot_cache_wrap(tensordot: Any) -> Any:
    """Decorates a ``tensordot()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """

    @functools.wraps(tensordot)
    def cached_tensordot(x, y, axes=2, backend="numpy"):
        if not currently_sharing():
            return tensordot(x, y, axes, backend=backend)

        # hash based on the (axes_x, axes_y) form of axes
        _save_tensors(x, y)
        if isinstance(axes, numbers.Number):
            axes = (
                list(range(len(x.shape)))[len(x.shape) - axes :],
                list(range(len(y.shape)))[:axes],
            )
        axes = tuple(axes[0]), tuple(axes[1])
        key = "tensordot", backend, id(x), id(y), axes
        return _memoize(key, tensordot, x, y, axes, backend=backend)

    return cached_tensordot
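
How an integer ``axes`` argument is normalized to the ``(axes_x, axes_y)`` cache-key form in the branch above - a pure-Python illustration with example ranks:

```python
x_ndim, y_ndim, axes = 3, 3, 2
axes_pair = (
    tuple(range(x_ndim))[x_ndim - axes:],  # last `axes` dims of x -> (1, 2)
    tuple(range(y_ndim))[:axes],           # first `axes` dims of y -> (0, 1)
)
assert axes_pair == ((1, 2), (0, 1))
```
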
def einsum_cache_wrap(einsum: Any) -> Any:
    """Decorates an ``einsum()`` implementation to be memoized inside a
    :func:`shared_intermediates` context.
    """

    @functools.wraps(einsum)
    def cached_einsum(*args, **kwargs):
        if not currently_sharing():
            return einsum(*args, **kwargs)

        # hash modulo commutativity by computing a canonical ordering and names
        backend = kwargs.pop("backend", "numpy")
        equation = args[0]
        inputs, output, operands = parse_einsum_input(args)
        inputs = inputs.split(",")

        _save_tensors(*operands)

        # build the canonical key
        canonical = sorted(zip(inputs, map(id, operands)), key=lambda x: x[1])
        canonical_ids = tuple(id_ for _, id_ in canonical)
        canonical_inputs = ",".join(input_ for input_, _ in canonical)
        canonical_equation = alpha_canonicalize(canonical_inputs + "->" + output)

        key = "einsum", backend, canonical_equation, canonical_ids
        return _memoize(key, einsum, equation, *operands, backend=backend)

    return cached_einsum


def to_backend_cache_wrap(to_backend: Any = None, constants: Any = False) -> Any:
    """Decorates a ``to_backend()`` implementation to be memoized inside a
    :func:`shared_intermediates` context (e.g. ``to_cupy``, ``to_torch``).
    """
    # manage the case that the decorator is called with args
    if to_backend is None:
        return functools.partial(to_backend_cache_wrap, constants=constants)

    if constants:

        @functools.wraps(to_backend)
        def cached_to_backend(array, constant=False):
            if not currently_sharing():
                return to_backend(array, constant=constant)

            # hash by id
            key = to_backend.__name__, id(array), constant
            return _memoize(key, to_backend, array, constant=constant)

    else:

        @functools.wraps(to_backend)
        def cached_to_backend(array):
            if not currently_sharing():
                return to_backend(array)

            # hash by id
            key = to_backend.__name__, id(array)
            return _memoize(key, to_backend, array)

    return cached_to_backend
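
A toy sketch of the decorator in action (the converter name is hypothetical; caching is keyed on object identity inside a sharing context):

```python
from opt_einsum.sharing import shared_intermediates, to_backend_cache_wrap

calls = []

@to_backend_cache_wrap
def to_mybackend(array):  # hypothetical backend converter
    calls.append(id(array))
    return array

x = [1.0, 2.0]
with shared_intermediates():
    to_mybackend(x)
    to_mybackend(x)  # memoized: the wrapped function is not called again

assert len(calls) == 1
```
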
@@ -0,0 +1,224 @@
"""Testing routines for opt_einsum."""

import random
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload

import pytest

from opt_einsum.parser import get_symbol
from opt_einsum.typing import ArrayType, PathType, TensorShapeType

_valid_chars = "abcdefghijklmopqABC"
_sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]
_default_dim_dict = dict(zip(_valid_chars, _sizes))


def build_shapes(string: str, dimension_dict: Optional[Dict[str, int]] = None) -> Tuple[TensorShapeType, ...]:
    """Builds random tensor shapes for testing.

    Parameters:
        string: Comma-separated einsum terms to build shapes for
        dimension_dict: Dictionary of index sizes, defaults to index sizes of 2-7

    Returns:
        The resulting shapes.

    Examples:
        ```python
        >>> shapes = build_shapes('ab,bbc,c', {'a': 2, 'b': 3, 'c': 5})
        >>> shapes
        ((2, 3), (3, 3, 5), (5,))
        ```

    """
    if dimension_dict is None:
        dimension_dict = _default_dim_dict

    shapes = []
    terms = string.split("->")[0].split(",")
    for term in terms:
        dims = [dimension_dict[x] for x in term]
        shapes.append(tuple(dims))
    return tuple(shapes)


def build_views(
    string: str, dimension_dict: Optional[Dict[str, int]] = None, array_function: Optional[Any] = None
) -> Tuple[ArrayType, ...]:
    """Builds random numpy arrays for testing.

    Parameters:
        string: Comma-separated einsum terms to build arrays for
        dimension_dict: Dictionary of index sizes
        array_function: Function to build the arrays, defaults to np.random.rand

    Returns:
        The resulting views.

    Examples:
        ```python
        >>> view = build_views('abbc', {'a': 2, 'b': 3, 'c': 5})
        >>> view[0].shape
        (2, 3, 3, 5)
        ```

    """
    if array_function is None:
        np = pytest.importorskip("numpy")
        array_function = np.random.rand

    views = []
    for shape in build_shapes(string, dimension_dict=dimension_dict):
        if shape:
            views.append(array_function(*shape))
        else:
            views.append(random.random())
    return tuple(views)


@overload
def rand_equation(
    n: int,
    regularity: int,
    n_out: int = ...,
    d_min: int = ...,
    d_max: int = ...,
    seed: Optional[int] = ...,
    global_dim: bool = ...,
    *,
    return_size_dict: Literal[True],
) -> Tuple[str, PathType, Dict[str, int]]: ...


@overload
def rand_equation(
    n: int,
    regularity: int,
    n_out: int = ...,
    d_min: int = ...,
    d_max: int = ...,
    seed: Optional[int] = ...,
    global_dim: bool = ...,
    return_size_dict: Literal[False] = ...,
) -> Tuple[str, PathType]: ...


def rand_equation(
    n: int,
    regularity: int,
    n_out: int = 0,
    d_min: int = 2,
    d_max: int = 9,
    seed: Optional[int] = None,
    global_dim: bool = False,
    return_size_dict: bool = False,
) -> Union[Tuple[str, PathType, Dict[str, int]], Tuple[str, PathType]]:
    """Generate a random contraction and shapes.

    Parameters:
        n: Number of array arguments.
        regularity: 'Regularity' of the contraction graph. This essentially determines how
            many indices each tensor shares with others on average.
        n_out: Number of output indices (i.e. the number of non-contracted indices).
            Defaults to 0, i.e., a contraction resulting in a scalar.
        d_min: Minimum dimension size.
        d_max: Maximum dimension size.
        seed: If not None, seed numpy's random generator with this.
        global_dim: Add a global, 'broadcast', dimension to every operand.
        return_size_dict: Return the mapping of indices to sizes.

    Returns:
        eq: The equation string.
        shapes: The array shapes.
        size_dict: The dict of index sizes, only returned if ``return_size_dict=True``.

    Examples:
        ```python
        >>> eq, shapes = rand_equation(n=10, regularity=4, n_out=5, seed=42)
        >>> eq
        'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda'

        >>> shapes
        [(9, 5, 4, 5, 4),
         (4, 4, 8, 5),
         (9, 4, 6, 9),
         (6, 6),
         (6, 9, 7, 8),
         (4,),
         (9, 3, 9, 4, 9),
         (6, 8, 4, 6, 8, 6, 3),
         (4, 7, 8, 8, 6, 9, 6),
         (9, 5, 3, 3, 9, 5)]
        ```
    """
    np = pytest.importorskip("numpy")
    if seed is not None:
        np.random.seed(seed)

    # total number of indices
    num_inds = n * regularity // 2 + n_out
    inputs = ["" for _ in range(n)]
    output = []

    size_dict = {get_symbol(i): np.random.randint(d_min, d_max + 1) for i in range(num_inds)}

    # generate a list of indices to place either once or twice
    def gen():
        for i, ix in enumerate(size_dict):
            # generate an outer index
            if i < n_out:
                output.append(ix)
                yield ix
            # generate a bond
            else:
                yield ix
                yield ix

    # add the indices randomly to the inputs
    for i, ix in enumerate(np.random.permutation(list(gen()))):
        # make sure all inputs have at least one index
        if i < n:
            inputs[i] += ix
        else:
            # don't add any traces on the same op
            where = np.random.randint(0, n)
            while ix in inputs[where]:
                where = np.random.randint(0, n)

            inputs[where] += ix

    # possibly add the same global dim to every arg
    if global_dim:
        gdim = get_symbol(num_inds)
        size_dict[gdim] = np.random.randint(d_min, d_max + 1)
        for i in range(n):
            inputs[i] += gdim
        output += gdim

    # randomly transpose the output indices and form equation
    output = "".join(np.random.permutation(output))  # type: ignore
    eq = "{}->{}".format(",".join(inputs), output)

    # make the shapes
    shapes = [tuple(size_dict[ix] for ix in op) for op in inputs]

    ret = (eq, shapes)

    if return_size_dict:
        return ret + (size_dict,)
    else:
        return ret


def build_arrays_from_tuples(path: PathType) -> List[Any]:
    """Build random numpy arrays from a path.

    Parameters:
        path: The path to build arrays from.

    Returns:
        The resulting arrays.
    """
    np = pytest.importorskip("numpy")

    return [np.random.rand(*x) for x in path]
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,462 @@
from typing import Set

import pytest

from opt_einsum import backends, contract, contract_expression, sharing
from opt_einsum.contract import ArrayShaped, infer_backend, parse_backend
from opt_einsum.testing import build_views

try:
    # needed so tensorflow doesn't allocate all gpu mem
    try:
        from tensorflow import ConfigProto  # type: ignore
        from tensorflow import Session as TFSession
    except ImportError:
        from tensorflow.compat.v1 import ConfigProto  # type: ignore
        from tensorflow.compat.v1 import Session as TFSession
    _TF_CONFIG = ConfigProto()
    _TF_CONFIG.gpu_options.allow_growth = True
except ImportError:
    pass


tests = [
    "ab,bc->ca",
    "abc,bcd,dea",
    "abc,def->fedcba",
    "abc,bcd,df->fa",
    # test 'prefer einsum' ops
    "ijk,ikj",
    "i,j->ij",
    "ijk,k->ij",
    "AB,BC->CA",
]


@pytest.mark.parametrize("string", tests)
def test_tensorflow(string: str) -> None:
    np = pytest.importorskip("numpy")
    pytest.importorskip("tensorflow")

    views = build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    opt = np.empty_like(ein)

    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    sess = TFSession(config=_TF_CONFIG)
    with sess.as_default():
        expr(*views, backend="tensorflow", out=opt)
    sess.close()

    assert np.allclose(ein, opt)

    # test non-conversion mode
    tensorflow_views = [backends.to_tensorflow(view) for view in views]
    expr(*tensorflow_views)


@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_tensorflow_with_constants(constants: Set[int]) -> None:
    np = pytest.importorskip("numpy")
    tf = pytest.importorskip("tensorflow")

    eq = "ij,jk,kl->li"
    shapes = (2, 3), (3, 4), (4, 5)
    (non_const,) = {0, 1, 2} - constants
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check tensorflow
    with TFSession(config=_TF_CONFIG).as_default():
        res_got = expr(var, backend="tensorflow")
    assert all(
        array is None or infer_backend(array) == "tensorflow" for array in expr._evaluated_constants["tensorflow"]
    )
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend="numpy")
    assert np.allclose(res_exp, res_got2)

    # check tensorflow call returns tensorflow still
    res_got3 = expr(backends.to_tensorflow(var))
    assert isinstance(res_got3, tf.Tensor)


@pytest.mark.parametrize("string", tests)
def test_tensorflow_with_sharing(string: str) -> None:
    np = pytest.importorskip("numpy")
    tf = pytest.importorskip("tensorflow")

    views = build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)

    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    sess = TFSession(config=_TF_CONFIG)

    with sess.as_default(), sharing.shared_intermediates() as cache:
        tfl1 = expr(*views, backend="tensorflow")
        assert sharing.get_sharing_cache() is cache
        cache_sz = len(cache)
        assert cache_sz > 0
        tfl2 = expr(*views, backend="tensorflow")
        assert len(cache) == cache_sz

    assert all(isinstance(t, tf.Tensor) for t in cache.values())

    assert np.allclose(ein, tfl1)
    assert np.allclose(ein, tfl2)


@pytest.mark.parametrize("string", tests)
def test_theano(string: str) -> None:
    np = pytest.importorskip("numpy")
    theano = pytest.importorskip("theano")

    views = build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]

    expr = contract_expression(string, *shps, optimize=True)

    opt = expr(*views, backend="theano")
    assert np.allclose(ein, opt)

    # test non-conversion mode
    theano_views = [backends.to_theano(view) for view in views]
    theano_opt = expr(*theano_views)
    assert isinstance(theano_opt, theano.tensor.TensorVariable)


@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_theano_with_constants(constants: Set[int]) -> None:
    np = pytest.importorskip("numpy")
    theano = pytest.importorskip("theano")

    eq = "ij,jk,kl->li"
    shapes = (2, 3), (3, 4), (4, 5)
    (non_const,) = {0, 1, 2} - constants
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check theano
    res_got = expr(var, backend="theano")
    assert all(array is None or infer_backend(array) == "theano" for array in expr._evaluated_constants["theano"])
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend="numpy")
    assert np.allclose(res_exp, res_got2)

    # check theano call returns theano still
    res_got3 = expr(backends.to_theano(var))
    assert isinstance(res_got3, theano.tensor.TensorVariable)


@pytest.mark.parametrize("string", tests)
def test_theano_with_sharing(string: str) -> None:
    np = pytest.importorskip("numpy")
    theano = pytest.importorskip("theano")

    views = build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)

    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    with sharing.shared_intermediates() as cache:
        thn1 = expr(*views, backend="theano")
        assert sharing.get_sharing_cache() is cache
        cache_sz = len(cache)
        assert cache_sz > 0
        thn2 = expr(*views, backend="theano")
        assert len(cache) == cache_sz

    assert all(isinstance(t, theano.tensor.TensorVariable) for t in cache.values())

    assert np.allclose(ein, thn1)
    assert np.allclose(ein, thn2)


@pytest.mark.parametrize("string", tests)
def test_cupy(string: str) -> None:
    np = pytest.importorskip("numpy")  # pragma: no cover
    cupy = pytest.importorskip("cupy")

    views = build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]

    expr = contract_expression(string, *shps, optimize=True)

    opt = expr(*views, backend="cupy")
    assert np.allclose(ein, opt)

    # test non-conversion mode
    cupy_views = [backends.to_cupy(view) for view in views]
    cupy_opt = expr(*cupy_views)
    assert isinstance(cupy_opt, cupy.ndarray)
    assert np.allclose(ein, cupy.asnumpy(cupy_opt))


@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_cupy_with_constants(constants: Set[int]) -> None:
    np = pytest.importorskip("numpy")  # pragma: no cover
    cupy = pytest.importorskip("cupy")

    eq = "ij,jk,kl->li"
    shapes = (2, 3), (3, 4), (4, 5)
    (non_const,) = {0, 1, 2} - constants
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check cupy
    res_got = expr(var, backend="cupy")
    # check cupy versions of constants exist
    assert all(array is None or infer_backend(array) == "cupy" for array in expr._evaluated_constants["cupy"])
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend="numpy")
    assert np.allclose(res_exp, res_got2)

    # check cupy call returns cupy still
    res_got3 = expr(cupy.asarray(var))
    assert isinstance(res_got3, cupy.ndarray)
    assert np.allclose(res_exp, res_got3.get())


@pytest.mark.parametrize("string", tests)
def test_jax(string: str) -> None:
    np = pytest.importorskip("numpy")  # pragma: no cover
    pytest.importorskip("jax")

    views = build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]

    expr = contract_expression(string, *shps, optimize=True)

    opt = expr(*views, backend="jax")
    assert np.allclose(ein, opt)
    assert isinstance(opt, np.ndarray)


@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_jax_with_constants(constants: Set[int]) -> None:
    jax = pytest.importorskip("jax")
    key = jax.random.PRNGKey(42)

    eq = "ij,jk,kl->li"
    shapes = (2, 3), (3, 4), (4, 5)
    (non_const,) = {0, 1, 2} - constants
    ops = [jax.random.uniform(key, shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = jax.random.uniform(key, shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check jax
    res_got = expr(var, backend="jax")
    # check jax versions of constants exist
    assert all(array is None or infer_backend(array).startswith("jax") for array in expr._evaluated_constants["jax"])
    assert jax.numpy.sum(jax.numpy.abs(res_exp - res_got)) < 1e-8


def test_jax_jit_gradient() -> None:
    jax = pytest.importorskip("jax")
    key = jax.random.PRNGKey(42)

    eq = "ij,jk,kl->"
    shapes = (2, 3), (3, 4), (4, 2)
    views = [jax.random.uniform(key, s) for s in shapes]
    expr = contract_expression(eq, *shapes)
    x0 = expr(*views)

    jit_expr = jax.jit(expr)
    x1 = jit_expr(*views).item()
    assert x1 == pytest.approx(x0, rel=1e-5)

    # jax only takes the gradient w.r.t. the first argument
    grad_expr = jax.jit(jax.grad(lambda views: expr(*views)))
    view_grads = grad_expr(views)
    assert all(v1.shape == v2.shape for v1, v2 in zip(views, view_grads))

    # taking a step along the gradient should reduce our 'loss'
    new_views = [v - 0.001 * dv for v, dv in zip(views, view_grads)]
    x2 = jit_expr(*new_views).item()
    assert x2 < x1


def test_autograd_gradient() -> None:
    np = pytest.importorskip("numpy")
    autograd = pytest.importorskip("autograd")

    eq = "ij,jk,kl->"
    shapes = (2, 3), (3, 4), (4, 2)
    views = [np.random.randn(*s) for s in shapes]
    expr = contract_expression(eq, *shapes)
    x0 = expr(*views)

    # autograd only takes the gradient w.r.t. the first argument
    grad_expr = autograd.grad(lambda views: expr(*views))
    view_grads = grad_expr(views)
    assert all(v1.shape == v2.shape for v1, v2 in zip(views, view_grads))

    # taking a step along the gradient should reduce our 'loss'
    new_views = [v - 0.001 * dv for v, dv in zip(views, view_grads)]
    x1 = expr(*new_views)
    assert x1 < x0


@pytest.mark.parametrize("string", tests)
def test_dask(string: str) -> None:
    np = pytest.importorskip("numpy")
    da = pytest.importorskip("dask.array")

    views = build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    # test non-conversion mode
    da_views = [da.from_array(x, chunks=(2)) for x in views]
    da_opt = expr(*da_views)

    # check type is maintained when not using numpy arrays
    assert isinstance(da_opt, da.Array)

    assert np.allclose(ein, np.array(da_opt))

    # try raw contract
    da_opt = contract(string, *da_views)
    assert isinstance(da_opt, da.Array)
    assert np.allclose(ein, np.array(da_opt))


@pytest.mark.parametrize("string", tests)
def test_sparse(string: str) -> None:
    np = pytest.importorskip("numpy")
    sparse = pytest.importorskip("sparse")

    views = build_views(string)

    # sparsify views so they don't become dense during contraction
    for view in views:
        np.random.seed(42)
        mask = np.random.choice([False, True], view.shape, True, [0.05, 0.95])
        view[mask] = 0

    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    # test non-conversion mode
    sparse_views = [sparse.COO.from_numpy(x) for x in views]
    sparse_opt = expr(*sparse_views)

    # if the expression returns a scalar, stop here
    if not ein.shape:
        assert pytest.approx(ein) == 0.0
        return

    # check type is maintained when not using numpy arrays
    assert isinstance(sparse_opt, sparse.COO)
    assert np.allclose(ein, sparse_opt.todense())

    # try raw contract
    sparse_opt = contract(string, *sparse_views)
    assert isinstance(sparse_opt, sparse.COO)
    assert np.allclose(ein, sparse_opt.todense())


@pytest.mark.parametrize("string", tests)
def test_torch(string: str) -> None:
    torch = pytest.importorskip("torch")

    views = build_views(string, array_function=torch.rand)
    ein = torch.einsum(string, *views)

    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    opt = expr(*views, backend="torch")
    torch.testing.assert_close(ein, opt)

    # test non-conversion mode
    torch_views = [backends.to_torch(view) for view in views]
    torch_opt = expr(*torch_views)
    assert isinstance(torch_opt, torch.Tensor)
    torch.testing.assert_close(ein, torch_opt)


@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_torch_with_constants(constants: Set[int]) -> None:
    torch = pytest.importorskip("torch")

    eq = "ij,jk,kl->li"
    shapes = (2, 3), (3, 4), (4, 5)
    (non_const,) = {0, 1, 2} - constants
    ops = [torch.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = torch.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)), backend="torch")

    expr = contract_expression(eq, *ops, constants=constants)

    # check torch
    res_got = expr(var, backend="torch")
    assert all(array is None or infer_backend(array) == "torch" for array in expr._evaluated_constants["torch"])
    torch.testing.assert_close(res_exp, res_got)

    # check can call again with the torch backend
    res_got2 = expr(var, backend="torch")
    torch.testing.assert_close(res_exp, res_got2)

    # check torch call returns torch still
    res_got3 = expr(backends.to_torch(var))
    assert isinstance(res_got3, torch.Tensor)
    torch.testing.assert_close(res_exp, res_got3)


def test_auto_backend_custom_array_no_tensordot() -> None:
    x = ArrayShaped((1, 2, 3))
    # ArrayShaped is an array-like object defined by opt_einsum which has no
    # tensordot, so the dispatcher should fall back to numpy
    assert infer_backend(x) == "opt_einsum"
    assert parse_backend([x], "auto") == "numpy"
    assert parse_backend([x], None) == "numpy"


@pytest.mark.parametrize("string", tests)
def test_object_arrays_backend(string: str) -> None:
    np = pytest.importorskip("numpy")
    views = build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    assert ein.dtype != object

    shps = [v.shape for v in views]
    expr = contract_expression(string, *shps, optimize=True)

    obj_views = [view.astype(object) for view in views]

    # try raw contract
    obj_opt = contract(string, *obj_views, backend="object")
    assert obj_opt.dtype == object
    assert np.allclose(ein, obj_opt.astype(float))

    # test expression
    obj_opt = expr(*obj_views, backend="object")
    assert obj_opt.dtype == object
    assert np.allclose(ein, obj_opt.astype(float))
@@ -0,0 +1,81 @@
"""
Tests the BLAS capability for the opt_einsum module.
"""

from typing import Any, Union

import pytest

from opt_einsum import blas, contract

blas_tests = [
    # DOT
    ((["k", "k"], "", set("k")), "DOT"),  # DDOT
    ((["ijk", "ijk"], "", set("ijk")), "DOT"),  # DDOT
    # GEMV?
    # GEMM
    ((["ij", "jk"], "ik", set("j")), "GEMM"),  # GEMM N N
    ((["ijl", "jlk"], "ik", set("jl")), "GEMM"),  # GEMM N N Tensor
    ((["ij", "kj"], "ik", set("j")), "GEMM"),  # GEMM N T
    ((["ijl", "kjl"], "ik", set("jl")), "GEMM"),  # GEMM N T Tensor
    ((["ji", "jk"], "ik", set("j")), "GEMM"),  # GEMM T N
    ((["jli", "jlk"], "ik", set("jl")), "GEMM"),  # GEMM T N Tensor
    ((["ji", "kj"], "ik", set("j")), "GEMM"),  # GEMM T T
    ((["jli", "kjl"], "ik", set("jl")), "GEMM"),  # GEMM T T Tensor
    # GEMM with final transpose
    ((["ij", "jk"], "ki", set("j")), "GEMM"),  # GEMM N N
    ((["ijl", "jlk"], "ki", set("jl")), "GEMM"),  # GEMM N N Tensor
    ((["ij", "kj"], "ki", set("j")), "GEMM"),  # GEMM N T
    ((["ijl", "kjl"], "ki", set("jl")), "GEMM"),  # GEMM N T Tensor
    ((["ji", "jk"], "ki", set("j")), "GEMM"),  # GEMM T N
    ((["jli", "jlk"], "ki", set("jl")), "GEMM"),  # GEMM T N Tensor
    ((["ji", "kj"], "ki", set("j")), "GEMM"),  # GEMM T T
    ((["jli", "kjl"], "ki", set("jl")), "GEMM"),  # GEMM T T Tensor
    # Tensor Dot (requires copy), let's not deal with this for now
    ((["ilj", "jlk"], "ik", set("jl")), "TDOT"),  # FT GEMM N N Tensor
    ((["ijl", "ljk"], "ik", set("jl")), "TDOT"),  # ST GEMM N N Tensor
    ((["ilj", "kjl"], "ik", set("jl")), "TDOT"),  # FT GEMM N T Tensor
    ((["ijl", "klj"], "ik", set("jl")), "TDOT"),  # ST GEMM N T Tensor
    ((["lji", "jlk"], "ik", set("jl")), "TDOT"),  # FT GEMM T N Tensor
    ((["jli", "ljk"], "ik", set("jl")), "TDOT"),  # ST GEMM T N Tensor
    ((["lji", "jlk"], "ik", set("jl")), "TDOT"),  # FT GEMM T N Tensor
    ((["jli", "ljk"], "ik", set("jl")), "TDOT"),  # ST GEMM T N Tensor
    # Tensor Dot (requires copy), let's not deal with this for now - with transpose
    ((["ilj", "jlk"], "ik", set("lj")), "TDOT"),  # FT GEMM N N Tensor
    ((["ijl", "ljk"], "ik", set("lj")), "TDOT"),  # ST GEMM N N Tensor
    ((["ilj", "kjl"], "ik", set("lj")), "TDOT"),  # FT GEMM N T Tensor
    ((["ijl", "klj"], "ik", set("lj")), "TDOT"),  # ST GEMM N T Tensor
    ((["lji", "jlk"], "ik", set("lj")), "TDOT"),  # FT GEMM T N Tensor
    ((["jli", "ljk"], "ik", set("lj")), "TDOT"),  # ST GEMM T N Tensor
    ((["lji", "jlk"], "ik", set("lj")), "TDOT"),  # FT GEMM T N Tensor
    ((["jli", "ljk"], "ik", set("lj")), "TDOT"),  # ST GEMM T N Tensor
    # Other
    ((["ijk", "ikj"], "", set("ijk")), "DOT/EINSUM"),  # Transpose DOT
    ((["i", "j"], "ij", set()), "OUTER/EINSUM"),  # Outer
    ((["ijk", "ik"], "j", set("ik")), "GEMV/EINSUM"),  # Matrix-vector
    ((["ijj", "jk"], "ik", set("j")), False),  # Double index
    ((["ijk", "j"], "ij", set()), False),  # Index sum 1
    ((["ij", "ij"], "ij", set()), False),  # Index sum 2
]


@pytest.mark.parametrize("inp,benchmark", blas_tests)
def test_can_blas(inp: Any, benchmark: Union[str, bool]) -> None:
    result = blas.can_blas(*inp)
    assert result == benchmark
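
For reference, `can_blas` can also be called directly; the calls below mirror two entries from the table above:

```python
from opt_einsum import blas

print(blas.can_blas(["ij", "jk"], "ik", set("j")))   # -> 'GEMM'
print(blas.can_blas(["ijj", "jk"], "ik", set("j")))  # -> False (repeated index)
```
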
def test_blas_out() -> None:
    np = pytest.importorskip("numpy")

    a = np.random.rand(4, 4)
    b = np.random.rand(4, 4)
    c = np.random.rand(4, 4)
    d = np.empty((4, 4))

    contract("ij,jk->ik", a, b, out=d)
    np.testing.assert_allclose(d, np.dot(a, b))
    assert np.allclose(d, np.dot(a, b))

    contract("ij,jk,kl->il", a, b, c, out=d)
    np.testing.assert_allclose(d, np.dot(a, b).dot(c))
@@ -0,0 +1,279 @@
"""
Tests a series of opt_einsum contraction paths to ensure the results are the same for different paths
"""

from typing import Any, List

import pytest

from opt_einsum import contract, contract_expression, contract_path
from opt_einsum.paths import _PATH_OPTIONS, linear_to_ssa, ssa_to_linear
from opt_einsum.testing import build_views, rand_equation
from opt_einsum.typing import OptimizeKind

# NumPy is required for the majority of this file
np = pytest.importorskip("numpy")


tests = [
    # Test scalar-like operations
    "a,->a",
    "ab,->ab",
    ",ab,->ab",
    ",,->",
    # Test hadamard-like products
    "a,ab,abc->abc",
    "a,b,ab->ab",
    # Test index-transformations
    "ea,fb,gc,hd,abcd->efgh",
    "ea,fb,abcd,gc,hd->efgh",
    "abcd,ea,fb,gc,hd->efgh",
    # Test complex contractions
    "acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
    "acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
    "cd,bdhe,aidb,hgca,gc,hgibcd,hgac",
    "abhe,hidj,jgba,hiab,gab",
    "bde,cdh,agdb,hica,ibd,hgicd,hiac",
    "chd,bde,agbc,hiad,hgc,hgi,hiad",
    "chd,bde,agbc,hiad,bdi,cgh,agdb",
    "bdhe,acad,hiab,agac,hibd",
    # Test collapse
    "ab,ab,c->",
    "ab,ab,c->c",
    "ab,ab,cd,cd->",
    "ab,ab,cd,cd->ac",
    "ab,ab,cd,cd->cd",
    "ab,ab,cd,cd,ef,ef->",
    # Test outer products
    "ab,cd,ef->abcdef",
    "ab,cd,ef->acdf",
    "ab,cd,de->abcde",
    "ab,cd,de->be",
    "ab,bcd,cd->abcd",
    "ab,bcd,cd->abd",
    # Random test cases that have previously failed
    "eb,cb,fb->cef",
    "dd,fb,be,cdb->cef",
    "bca,cdb,dbf,afc->",
    "dcc,fce,ea,dbf->ab",
    "fdf,cdd,ccd,afe->ae",
    "abcd,ad",
    "ed,fcd,ff,bcf->be",
    "baa,dcf,af,cde->be",
    "bd,db,eac->ace",
    "fff,fae,bef,def->abd",
    "efc,dbc,acf,fd->abe",
    # Inner products
    "ab,ab",
    "ab,ba",
    "abc,abc",
    "abc,bac",
    "abc,cba",
    # GEMM test cases
    "ab,bc",
    "ab,cb",
    "ba,bc",
    "ba,cb",
    "abcd,cd",
    "abcd,ab",
    "abcd,cdef",
    "abcd,cdef->feba",
    "abcd,efdc",
    # Inner then dot
    "aab,bc->ac",
    "ab,bcc->ac",
    "aab,bcc->ac",
    "baa,bcc->ac",
    "aab,ccb->ac",
    # Randomly built test cases
    "aab,fa,df,ecc->bde",
    "ecb,fef,bad,ed->ac",
    "bcf,bbb,fbf,fc->",
    "bb,ff,be->e",
    "bcb,bb,fc,fff->",
    "fbb,dfd,fc,fc->",
    "afd,ba,cc,dc->bf",
    "adb,bc,fa,cfc->d",
    "bbd,bda,fc,db->acf",
    "dba,ead,cad->bce",
    "aef,fbc,dca->bde",
]


@pytest.mark.parametrize("optimize", (True, False, None))
def test_contract_plain_types(optimize: OptimizeKind) -> None:
    expr = "ij,jk,kl->il"
    ops = [np.random.rand(2, 2), np.random.rand(2, 2), np.random.rand(2, 2)]

    path = contract_path(expr, *ops, optimize=optimize)
    assert len(path) == 2

    result = contract(expr, *ops, optimize=optimize)
    assert result.shape == (2, 2)


@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare(optimize: OptimizeKind, string: str) -> None:
    views = build_views(string)

    ein = contract(string, *views, optimize=False, use_blas=False)
    opt = contract(string, *views, optimize=optimize, use_blas=False)
    assert np.allclose(ein, opt)


@pytest.mark.parametrize("string", tests)
def test_drop_in_replacement(string: str) -> None:
    views = build_views(string)
    opt = contract(string, *views)
    assert np.allclose(opt, np.einsum(string, *views))


@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_greek(optimize: OptimizeKind, string: str) -> None:
    views = build_views(string)

    ein = contract(string, *views, optimize=False, use_blas=False)

    # convert to greek
    string = "".join(chr(ord(c) + 848) if c not in ",->." else c for c in string)

    opt = contract(string, *views, optimize=optimize, use_blas=False)
    assert np.allclose(ein, opt)


@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_blas(optimize: OptimizeKind, string: str) -> None:
    views = build_views(string)

    ein = contract(string, *views, optimize=False)
    opt = contract(string, *views, optimize=optimize)
    assert np.allclose(ein, opt)


@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_blas_greek(optimize: OptimizeKind, string: str) -> None:
    views = build_views(string)

    ein = contract(string, *views, optimize=False)

    # convert to greek
    string = "".join(chr(ord(c) + 848) if c not in ",->." else c for c in string)

    opt = contract(string, *views, optimize=optimize)
    assert np.allclose(ein, opt)


def test_some_non_alphabet_maintains_order() -> None:
    # 'c beta a' should automatically go to -> 'a c beta'
    string = "c" + chr(ord("b") + 848) + "a"
    # but beta is temporarily replaced with 'b', for which 'cba->abc',
    # so check that the manual output ordering kicks in:
    x = np.random.rand(2, 3, 4)
    assert np.allclose(contract(string, x), contract("cxa", x))


def test_printing():
    string = "bbd,bda,fc,db->acf"
    views = build_views(string)

    ein = contract_path(string, *views)
    assert len(str(ein[1])) == 728


@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
@pytest.mark.parametrize("use_blas", [False, True])
@pytest.mark.parametrize("out_spec", [False, True])
def test_contract_expressions(string: str, optimize: OptimizeKind, use_blas: bool, out_spec: bool) -> None:
    views = build_views(string)
    shapes = [view.shape if hasattr(view, "shape") else () for view in views]
    expected = contract(string, *views, optimize=False, use_blas=False)

    expr = contract_expression(string, *shapes, optimize=optimize, use_blas=use_blas)

    if out_spec and ("->" in string) and (string[-2:] != "->"):
        (out,) = build_views(string.split("->")[1])
        expr(*views, out=out)
    else:
        out = expr(*views)

    assert np.allclose(out, expected)

    # check representations
    assert string in expr.__repr__()
    assert string in expr.__str__()


def test_contract_expression_interleaved_input() -> None:
    x, y, z = (np.random.randn(2, 2) for _ in "xyz")
    expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0])
    xshp, yshp, zshp = ((2, 2) for _ in "xyz")
    expr = contract_expression(xshp, [0, 1], yshp, [1, 2], zshp, [2, 3], [3, 0])
    out = expr(x, y, z)
    assert np.allclose(out, expected)


@pytest.mark.parametrize(
    "string,constants",
    [
        ("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]),
        ("bdef,cdkj,ji,ikeh,hbc,lfo", [0, 1, 2, 3]),
        ("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]),
        ("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]),
        ("ijab,acd,bce,df,ef->ji", [1, 2, 3, 4]),
        ("ab,cd,ad,cb", [1, 3]),
        ("ab,bc,cd", [0, 1]),
    ],
)
def test_contract_expression_with_constants(string: str, constants: List[int]) -> None:
    views = build_views(string)
    expected = contract(string, *views, optimize=False, use_blas=False)

    shapes = [view.shape if hasattr(view, "shape") else () for view in views]

    expr_args: List[Any] = []
    ctrc_args = []
    for i, (shape, view) in enumerate(zip(shapes, views)):
        if i in constants:
            expr_args.append(view)
        else:
            expr_args.append(shape)
            ctrc_args.append(view)

    expr = contract_expression(string, *expr_args, constants=constants)
    out = expr(*ctrc_args)
    assert np.allclose(expected, out)


@pytest.mark.parametrize("optimize", ["greedy", "optimal"])
@pytest.mark.parametrize("n", [4, 5])
@pytest.mark.parametrize("reg", [2, 3])
@pytest.mark.parametrize("n_out", [0, 2, 4])
@pytest.mark.parametrize("global_dim", [False, True])
def test_rand_equation(optimize: OptimizeKind, n: int, reg: int, n_out: int, global_dim: bool) -> None:
    eq, _, size_dict = rand_equation(n, reg, n_out, d_min=2, d_max=5, seed=42, return_size_dict=True)
    views = build_views(eq, size_dict)

    expected = contract(eq, *views, optimize=False)
    actual = contract(eq, *views, optimize=optimize)

    assert np.allclose(expected, actual)


@pytest.mark.parametrize("equation", tests)
def test_linear_vs_ssa(equation: str) -> None:
    views = build_views(equation)
    linear_path, _ = contract_path(equation, *views)
    ssa_path = linear_to_ssa(linear_path)
    linear_path2 = ssa_to_linear(ssa_path)
    assert linear_path2 == linear_path


def test_contract_path_supply_shapes() -> None:
    eq = "ab,bc,cd"
    shps = [(2, 3), (3, 4), (4, 5)]
    contract_path(eq, *shps, shapes=True)
@@ -0,0 +1,152 @@
"""
Tests a series of opt_einsum contraction paths to ensure the results are the same for different paths
"""

from typing import Any, Tuple

import pytest

from opt_einsum import contract, contract_expression, contract_path
from opt_einsum.typing import PathType

# NumPy is required for the majority of this file
np = pytest.importorskip("numpy")


def test_contract_expression_checks() -> None:
    # check optimize needed
    with pytest.raises(ValueError):
        contract_expression("ab,bc->ac", (2, 3), (3, 4), optimize=False)

    # check sizes are still checked
    with pytest.raises(ValueError):
        contract_expression("ab,bc->ac", (2, 3), (3, 4), (42, 42))

    # check error if out is given up front
    out = np.empty((2, 4))
    with pytest.raises(ValueError):
        contract_expression("ab,bc->ac", (2, 3), (3, 4), out=out)

    # check still get errors when wrong ranks supplied to expression
    expr = contract_expression("ab,bc->ac", (2, 3), (3, 4))

    # too few arguments
    with pytest.raises(ValueError) as err:
        expr(np.random.rand(2, 3))
    assert "`ContractExpression` takes exactly 2" in str(err.value)

    # too many arguments
    with pytest.raises(ValueError) as err:
        expr(np.random.rand(2, 3), np.random.rand(2, 3), np.random.rand(2, 3))
    assert "`ContractExpression` takes exactly 2" in str(err.value)

    # wrong shapes
    with pytest.raises(ValueError) as err:
        expr(np.random.rand(2, 3, 4), np.random.rand(3, 4))
    assert "Internal error while evaluating `ContractExpression`" in str(err.value)
    with pytest.raises(ValueError) as err:
        expr(np.random.rand(2, 4), np.random.rand(3, 4, 5))
    assert "Internal error while evaluating `ContractExpression`" in str(err.value)
    with pytest.raises(ValueError) as err:
        expr(np.random.rand(2, 3), np.random.rand(3, 4), out=np.random.rand(2, 4, 6))
    assert "Internal error while evaluating `ContractExpression`" in str(err.value)

    # should only be able to specify out
    with pytest.raises(TypeError) as err_type:
        expr(np.random.rand(2, 3), np.random.rand(3, 4), order="F")  # type: ignore
    assert "got an unexpected keyword" in str(err_type.value)


def test_broadcasting_contraction() -> None:
    a = np.random.rand(1, 5, 4)
    b = np.random.rand(4, 6)
    c = np.random.rand(5, 6)
    d = np.random.rand(10)

    ein_scalar = contract("ijk,kl,jl", a, b, c, optimize=False)
    opt_scalar = contract("ijk,kl,jl", a, b, c, optimize=True)
    assert np.allclose(ein_scalar, opt_scalar)

    result = ein_scalar * d

    ein = contract("ijk,kl,jl,i->i", a, b, c, d, optimize=False)
    opt = contract("ijk,kl,jl,i->i", a, b, c, d, optimize=True)

    assert np.allclose(ein, result)
    assert np.allclose(opt, result)


def test_broadcasting_contraction2() -> None:
    a = np.random.rand(1, 1, 5, 4)
    b = np.random.rand(4, 6)
    c = np.random.rand(5, 6)
    d = np.random.rand(7, 7)

    ein_scalar = contract("abjk,kl,jl", a, b, c, optimize=False)
    opt_scalar = contract("abjk,kl,jl", a, b, c, optimize=True)
    assert np.allclose(ein_scalar, opt_scalar)

    result = ein_scalar * d

    ein = contract("abjk,kl,jl,ab->ab", a, b, c, d, optimize=False)
    opt = contract("abjk,kl,jl,ab->ab", a, b, c, d, optimize=True)

    assert np.allclose(ein, result)
    assert np.allclose(opt, result)


def test_broadcasting_contraction3() -> None:
    a = np.random.rand(1, 5, 4)
    b = np.random.rand(4, 1, 6)
    c = np.random.rand(5, 6)
    d = np.random.rand(7, 7)

    ein = contract("ajk,kbl,jl,ab->ab", a, b, c, d, optimize=False)
    opt = contract("ajk,kbl,jl,ab->ab", a, b, c, d, optimize=True)

    assert np.allclose(ein, opt)


def test_broadcasting_contraction4() -> None:
    a = np.arange(64).reshape(2, 4, 8)
    ein = contract("obk,ijk->ioj", a, a, optimize=False)
    opt = contract("obk,ijk->ioj", a, a, optimize=True)

    assert np.allclose(ein, opt)


def test_can_blas_on_healed_broadcast_dimensions() -> None:
    expr = contract_expression("ab,bc,bd->acd", (5, 4), (1, 5), (4, 20))
    # first contraction involves broadcasting
    assert expr.contraction_list[0][2] == "bc,ab->bca"
    assert expr.contraction_list[0][-1] is False
    # but then it is healed and GEMM is usable
    assert expr.contraction_list[1][2] == "bca,bd->acd"
    assert expr.contraction_list[1][-1] == "GEMM"


def test_pathinfo_for_empty_contraction() -> None:
    eq = "->"
    arrays = (1.0,)
    path: PathType = []
    _, info = contract_path(eq, *arrays, optimize=path)
    # some info is built lazily, so check repr
    assert repr(info)
    assert info.largest_intermediate == 1


@pytest.mark.parametrize(
    "expression, operands",
    [
        [",,->", (5, 5.0, 2.0j)],
        ["ab,->", ([[5, 5], [2.0, 1]], 2.0j)],
        ["ab,bc->ac", ([[5, 5], [2.0, 1]], [[2.0, 1], [3.0, 4]])],
        ["ab,->", ([[5, 5], [2.0, 1]], True)],
    ],
)
def test_contract_with_assumed_shapes(expression: str, operands: Tuple[Any]) -> None:
    """Test that we can contract with assumed shapes, and that the output is correct.
    This is required as we need to infer intermediate shape sizes."""
    benchmark = np.einsum(expression, *operands)
    result = contract(expression, *operands, optimize=True)
    assert np.allclose(benchmark, result)
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
Tests the input parsing for opt_einsum. Duplicates the np.einsum input tests.
|
||||
"""
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
import pytest
|
||||
|
||||
from opt_einsum import contract, contract_path
|
||||
from opt_einsum.typing import ArrayType
|
||||
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
|
||||
def build_views(string: str) -> List[ArrayType]:
|
||||
"""Builds random numpy arrays for testing by using a fixed size dictionary and an input string."""
|
||||
|
||||
chars = "abcdefghij"
|
||||
sizes_array = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4])
|
||||
sizes = dict(zip(chars, sizes_array))
|
||||
|
||||
views = []
|
||||
|
||||
string = string.replace("...", "ij")
|
||||
|
||||
terms = string.split("->")[0].split(",")
|
||||
for term in terms:
|
||||
dims = [sizes[x] for x in term]
|
||||
views.append(np.random.rand(*dims))
|
||||
return views
|
||||
def test_type_errors() -> None:
    # subscripts must be a string
    with pytest.raises(TypeError):
        contract(0, 0)

    # out parameter must be an array
    with pytest.raises(TypeError):
        contract("", 0, out="test")

    # order parameter must be a valid order
    # changed in NumPy 1.19, see https://github.com/numpy/numpy/commit/35b0a051c19265f5643f6011ee11e31d30c8bc4c
    with pytest.raises((TypeError, ValueError)):
        contract("", 0, order="W")  # type: ignore

    # casting parameter must be a valid casting
    with pytest.raises(ValueError):
        contract("", 0, casting="blah")  # type: ignore

    # dtype parameter must be a valid dtype
    with pytest.raises(TypeError):
        contract("", 0, dtype="bad_data_type")

    # other keyword arguments are rejected
    with pytest.raises(TypeError):
        contract("", 0, bad_arg=0)

    # issue 4528 revealed a segfault with this call
    with pytest.raises(TypeError):
        contract(*(None,) * 63)

    # cannot have two '->'
    with pytest.raises(ValueError):
        contract("->,->", 0, 5)

    # undefined symbol on the lhs
    with pytest.raises(ValueError):
        contract("&,a->", 0, 5)

    # undefined symbol on the rhs
    with pytest.raises(ValueError):
        contract("a,a->&", 0, 5)

    # catch ellipsis errors
    string = "...a->...a"
    views = build_views(string)

    # subscript list must contain Ellipsis or a hashable and comparable object
    with pytest.raises(TypeError):
        contract(views[0], [Ellipsis, 0], [Ellipsis, ["a"]])

    with pytest.raises(TypeError):
        contract(views[0], [Ellipsis, {}], [Ellipsis, "a"])


@pytest.mark.parametrize("contract_fn", [contract, contract_path])
def test_value_errors(contract_fn: Any) -> None:
    # an empty subscripts string is invalid
    with pytest.raises(ValueError):
        contract_fn("")

    # subscripts must be a string
    with pytest.raises(TypeError):
        contract_fn(0, 0)

    # invalid subscript character
    with pytest.raises(ValueError):
        contract_fn("i%...", [0, 0])
    with pytest.raises(ValueError):
        contract_fn("...j$", [0, 0])
    with pytest.raises(ValueError):
        contract_fn("i->&", [0, 0])

    # number of operands must match the count in the subscripts string
    with pytest.raises(ValueError):
        contract_fn("", 0, 0)
    with pytest.raises(ValueError):
        contract_fn(",", 0, [0], [0])
    with pytest.raises(ValueError):
        contract_fn(",", [0])

    # can't have more subscripts than dimensions in the operand
    with pytest.raises(ValueError):
        contract_fn("i", 0)
    with pytest.raises(ValueError):
        contract_fn("ij", [0, 0])
    with pytest.raises(ValueError):
        contract_fn("...i", 0)
    with pytest.raises(ValueError):
        contract_fn("i...j", [0, 0])
    with pytest.raises(ValueError):
        contract_fn("i...", 0)
    with pytest.raises(ValueError):
        contract_fn("ij...", [0, 0])

    # invalid ellipsis
    with pytest.raises(ValueError):
        contract_fn("i..", [0, 0])
    with pytest.raises(ValueError):
        contract_fn(".i...", [0, 0])
    with pytest.raises(ValueError):
        contract_fn("j->..j", [0, 0])
    with pytest.raises(ValueError):
        contract_fn("j->.j...", [0, 0])

    # output subscripts must appear in the input
    with pytest.raises(ValueError):
        contract_fn("i->ij", [0, 0])

    # output subscripts may only be specified once
    with pytest.raises(ValueError):
        contract_fn("ij->jij", [[0, 0], [0, 0]])

    # dimensions must match when being collapsed
    with pytest.raises(ValueError):
        contract_fn("ii", np.arange(6).reshape(2, 3))
    with pytest.raises(ValueError):
        contract_fn("ii->i", np.arange(6).reshape(2, 3))

    # broadcasting to new dimensions must be enabled explicitly
    with pytest.raises(ValueError):
        contract_fn("i", np.arange(6).reshape(2, 3))

    # unknown keyword arguments are rejected
    with pytest.raises(TypeError):
        contract_fn("ij->ij", [[0, 1], [0, 1]], bad_kwarg=True)


@pytest.mark.parametrize(
    "string",
    [
        # ellipsis cases
        "...a->...",
        "a...->...",
        "a...a->...a",
        "...,...",
        "a,b",
        "...a,...b",
    ],
)
def test_compare(string: str) -> None:
    views = build_views(string)

    ein = contract(string, *views, optimize=False)
    opt = contract(string, *views)
    assert np.allclose(ein, opt)

    opt = contract(string, *views, optimize="optimal")
    assert np.allclose(ein, opt)


def test_ellipse_input1() -> None:
    string = "...a->..."
    views = build_views(string)

    ein = contract(string, *views, optimize=False)
    opt = contract(views[0], [Ellipsis, 0], [Ellipsis])
    assert np.allclose(ein, opt)


def test_ellipse_input2() -> None:
    string = "...a"
    views = build_views(string)

    ein = contract(string, *views, optimize=False)
    opt = contract(views[0], [Ellipsis, 0])
    assert np.allclose(ein, opt)


def test_ellipse_input3() -> None:
    string = "...a->...a"
    views = build_views(string)

    ein = contract(string, *views, optimize=False)
    opt = contract(views[0], [Ellipsis, 0], [Ellipsis, 0])
    assert np.allclose(ein, opt)


def test_ellipse_input4() -> None:
    string = "...b,...a->..."
    views = build_views(string)

    ein = contract(string, *views, optimize=False)
    opt = contract(views[0], [Ellipsis, 1], views[1], [Ellipsis, 0], [Ellipsis])
    assert np.allclose(ein, opt)


def test_singleton_dimension_broadcast() -> None:
    # singleton dimensions broadcast (gh-10343)
    p = np.ones((10, 2))
    q = np.ones((1, 2))

    ein = contract("ij,ij->j", p, q, optimize=False)
    opt = contract("ij,ij->j", p, q, optimize=True)
    assert np.allclose(ein, opt)
    assert np.allclose(opt, [10.0, 10.0])

    p = np.ones((1, 5))
    q = np.ones((5, 5))

    for optimize in (True, False):
        res1 = (contract("...ij,...jk->...ik", p, p, optimize=optimize),)
        res2 = contract("...ij,...jk->...ik", p, q, optimize=optimize)
        assert np.allclose(res1, res2)
        assert np.allclose(res2, np.full((1, 5), 5))


def test_large_int_input_format() -> None:
    string = "ab,bc,cd"
    x, y, z = build_views(string)
    string_output = contract(string, x, y, z)
    int_output = contract(x, (1000, 1001), y, (1001, 1002), z, (1002, 1003))
    assert np.allclose(string_output, int_output)
    for i in range(10):
        transpose_output = contract(x, (i + 1, i))
        assert np.allclose(transpose_output, x.T)


def test_hashable_object_input_format() -> None:
    string = "ab,bc,cd"
    x, y, z = build_views(string)
    string_output = contract(string, x, y, z)
    hash_output1 = contract(x, ("left", "bond1"), y, ("bond1", "bond2"), z, ("bond2", "right"))
    hash_output2 = contract(
        x,
        ("left", "bond1"),
        y,
        ("bond1", "bond2"),
        z,
        ("bond2", "right"),
        ("left", "right"),
    )
    assert np.allclose(string_output, hash_output1)
    assert np.allclose(hash_output1, hash_output2)
    for i in range(1, 10):
        transpose_output = contract(x, ("b" * i, "a" * i))
        assert np.allclose(transpose_output, x.T)
@@ -0,0 +1,74 @@
"""Directly tests various parser utility functions."""

from typing import Any, Tuple

import pytest

from opt_einsum.parser import get_shape, get_symbol, parse_einsum_input
from opt_einsum.testing import build_arrays_from_tuples


def test_get_symbol() -> None:
    assert get_symbol(2) == "c"
    assert get_symbol(200000) == "\U00031540"
    # ensure we skip the surrogates '[\uD800-\uDFFF]'
    assert get_symbol(55295) == "\ud88b"
    assert get_symbol(55296) == "\ue000"
    assert get_symbol(57343) == "\ue7ff"

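# Editor's note: a hedged sketch (not the library's code) of an index-to-symbol
# mapping consistent with the asserts above -- the first 52 indices map to Latin
# letters, and sufficiently large indices are offset past the UTF-16 surrogate block.
def _example_symbol_map(i: int) -> str:
    letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
    if i < 52:
        return letters[i]
    if i >= 55296:
        return chr(i + 2048)  # jump over the surrogate range, landing at U+E000
    return chr(i + 140)
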
def test_parse_einsum_input() -> None:
    eq = "ab,bc,cd"
    ops = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)])
    input_subscripts, output_subscript, operands = parse_einsum_input([eq, *ops])
    assert input_subscripts == eq
    assert output_subscript == "ad"
    assert operands == ops


def test_parse_einsum_input_shapes_error() -> None:
    eq = "ab,bc,cd"
    ops = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)])

    with pytest.raises(ValueError):
        _ = parse_einsum_input([eq, *ops], shapes=True)


def test_parse_einsum_input_shapes() -> None:
    eq = "ab,bc,cd"
    shapes = [(2, 3), (3, 4), (4, 5)]
    input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True)
    assert input_subscripts == eq
    assert output_subscript == "ad"
    assert shapes == operands


def test_parse_with_ellipsis() -> None:
    eq = "...a,ab"
    shapes = [(2, 3), (3, 4)]
    input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True)
    assert input_subscripts == "da,ab"
    assert output_subscript == "db"
    assert shapes == operands


@pytest.mark.parametrize(
    "array, shape",
    [
        [[5], (1,)],
        [[5, 5], (2,)],
        [(5, 5), (2,)],
        [[[[[[5, 2]]]]], (1, 1, 1, 1, 2)],
        [[[[[["abcdef", "b"]]]]], (1, 1, 1, 1, 2)],
        ["A", ()],
        [b"A", ()],
        [True, ()],
        [5, ()],
        [5.0, ()],
        [5.0 + 0j, ()],
    ],
)
def test_get_shapes(array: Any, shape: Tuple[int]) -> None:
    assert get_shape(array) == shape

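# Editor's note: a minimal, hedged sketch (not the library's implementation) of the
# recursion the cases above imply -- descend through nested sequences, treating
# strings, bytes, and scalars as 0-d leaves.
def _example_get_shape(array: Any) -> Tuple[int, ...]:
    if isinstance(array, (str, bytes)) or not hasattr(array, "__len__"):
        return ()
    return (len(array),) + _example_get_shape(array[0])
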
@@ -0,0 +1,534 @@
"""Tests the accuracy of the opt_einsum paths in addition to unit tests for
the various path helper functions."""

import itertools
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Dict, List, Optional

import pytest

import opt_einsum as oe
from opt_einsum.testing import build_shapes, rand_equation
from opt_einsum.typing import ArrayIndexType, OptimizeKind, PathType, TensorShapeType

explicit_path_tests = {
    "GEMM1": (
        [set("abd"), set("ac"), set("bdc")],
        set(""),
        {"a": 1, "b": 2, "c": 3, "d": 4},
    ),
    "Inner1": (
        [set("abcd"), set("abc"), set("bc")],
        set(""),
        {"a": 5, "b": 2, "c": 3, "d": 4},
    ),
}

# note that these tests have no unique solution due to the chosen dimensions
path_edge_tests = [
    ["greedy", "eb,cb,fb->cef", ((0, 2), (0, 1))],
    ["branch-all", "eb,cb,fb->cef", ((0, 2), (0, 1))],
    ["branch-2", "eb,cb,fb->cef", ((0, 2), (0, 1))],
    ["optimal", "eb,cb,fb->cef", ((0, 2), (0, 1))],
    ["dp", "eb,cb,fb->cef", ((1, 2), (0, 1))],
    ["greedy", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
    ["branch-all", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
    ["branch-2", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
    ["optimal", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
    ["dp", "dd,fb,be,cdb->cef", ((0, 3), (0, 2), (0, 1))],
    ["greedy", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
    ["branch-all", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
    ["branch-2", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
    ["optimal", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
    ["dp", "bca,cdb,dbf,afc->", ((1, 2), (1, 2), (0, 1))],
    ["greedy", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 1), (0, 1))],
    ["branch-all", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
    ["branch-2", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
    ["optimal", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
    ["dp", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
]

# these scalar cases only pin down the expected path length
path_scalar_tests = [
    ["a,->a", 1],
    ["ab,->ab", 1],
    [",a,->a", 2],
    [",,a,->a", 3],
    [",,->", 2],
]


def check_path(test_output: PathType, benchmark: PathType, bypass: bool = False) -> bool:
    if not isinstance(test_output, list):
        return False

    if len(test_output) != len(benchmark):
        return False

    ret = True
    for pos in range(len(test_output)):
        ret &= isinstance(test_output[pos], tuple)
        ret &= test_output[pos] == list(benchmark)[pos]
    return ret


def assert_contract_order(func: Any, test_data: Any, max_size: int, benchmark: PathType) -> None:
    test_output = func(test_data[0], test_data[1], test_data[2], max_size)
    assert check_path(test_output, benchmark)


def test_size_by_dict() -> None:
    sizes_dict = {}
    for ind, val in zip("abcdez", [2, 5, 9, 11, 13, 0]):
        sizes_dict[ind] = val

    path_func = oe.helpers.compute_size_by_dict

    assert 1 == path_func("", sizes_dict)
    assert 2 == path_func("a", sizes_dict)
    assert 5 == path_func("b", sizes_dict)

    assert 0 == path_func("z", sizes_dict)
    assert 0 == path_func("az", sizes_dict)
    assert 0 == path_func("zbc", sizes_dict)

    assert 104 == path_func("aaae", sizes_dict)
    assert 12870 == path_func("abcde", sizes_dict)

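# Editor's note: a hedged illustration (not in the original file) of what
# compute_size_by_dict computes -- the product of the sizes of the given indices,
# e.g. "aaae" -> 2 * 2 * 2 * 13 = 104, matching the asserts above.
def _example_size_by_dict() -> None:
    sizes = {"a": 2, "e": 13}
    product = 1
    for index in "aaae":
        product *= sizes[index]
    assert product == 104 == oe.helpers.compute_size_by_dict("aaae", sizes)
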
def test_flop_cost() -> None:
    size_dict = {v: 10 for v in "abcdef"}

    # loop over an array
    assert 10 == oe.helpers.flop_count("a", False, 1, size_dict)

    # Hadamard product (*)
    assert 10 == oe.helpers.flop_count("a", False, 2, size_dict)
    assert 100 == oe.helpers.flop_count("ab", False, 2, size_dict)

    # inner product (+, *)
    assert 20 == oe.helpers.flop_count("a", True, 2, size_dict)
    assert 200 == oe.helpers.flop_count("ab", True, 2, size_dict)

    # inner product x3 (+, *, *)
    assert 30 == oe.helpers.flop_count("a", True, 3, size_dict)

    # GEMM
    assert 2000 == oe.helpers.flop_count("abc", True, 2, size_dict)

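# Editor's note: a hedged reading of the numbers above (not the library's code):
# the cost appears to be prod(sizes) * factor, with factor = max(1, num_terms - 1)
# plus 1 more when an inner (summed) contraction is involved. For the GEMM case
# "abc" with all dimensions of size 10: 10**3 * ((2 - 1) + 1) = 2000.
def _example_flop_count_arithmetic() -> None:
    overall_size = 10**3  # product of the sizes of "abc"
    op_factor = (2 - 1) + 1  # two terms, plus one for the inner sum
    assert overall_size * op_factor == 2000
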
def test_bad_path_option() -> None:
    with pytest.raises(KeyError):
        oe.contract("a,b,c", [1], [2], [3], optimize="optimall", shapes=True)  # type: ignore


def test_explicit_path() -> None:
    pytest.importorskip("numpy")
    x = oe.contract("a,b,c", [1], [2], [3], optimize=[(1, 2), (0, 1)])
    assert x.item() == 6


def test_path_optimal() -> None:
    test_func = oe.paths.optimal

    test_data = explicit_path_tests["GEMM1"]
    assert_contract_order(test_func, test_data, 5000, [(0, 2), (0, 1)])
    assert_contract_order(test_func, test_data, 0, [(0, 1, 2)])


def test_path_greedy() -> None:
    test_func = oe.paths.greedy

    test_data = explicit_path_tests["GEMM1"]
    assert_contract_order(test_func, test_data, 5000, [(0, 2), (0, 1)])
    assert_contract_order(test_func, test_data, 0, [(0, 1, 2)])


def test_memory_paths() -> None:
    expression = "abc,bdef,fghj,cem,mhk,ljk->adgl"

    views = build_shapes(expression)

    # test a tiny memory limit
    path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=5, shapes=True)
    assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])

    path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=5, shapes=True)
    assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])

    # check the other possibilities with an unbounded memory limit
    path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=-1, shapes=True)
    assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])

    path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=-1, shapes=True)
    assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])


@pytest.mark.parametrize("alg,expression,order", path_edge_tests)
def test_path_edge_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
    views = build_shapes(expression)

    path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True)
    assert check_path(path_ret[0], order)


@pytest.mark.parametrize("expression,order", path_scalar_tests)
@pytest.mark.parametrize("alg", oe.paths._PATH_OPTIONS)
def test_path_scalar_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
    views = build_shapes(expression)

    path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True)
    assert len(path_ret[0]) == order


def test_optimal_edge_cases() -> None:
    # edge test 4
    expression = "a,ac,ab,ad,cd,bd,bc->"
    edge_test4 = build_shapes(expression, dimension_dict={"a": 20, "b": 20, "c": 20, "d": 20})
    path, _ = oe.contract_path(expression, *edge_test4, optimize="greedy", memory_limit="max_input", shapes=True)
    assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])

    path, _ = oe.contract_path(expression, *edge_test4, optimize="optimal", memory_limit="max_input", shapes=True)
    assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])


def test_greedy_edge_cases() -> None:
    expression = "abc,cfd,dbe,efa"
    dim_dict = {k: 20 for k in expression.replace(",", "")}
    tensors = build_shapes(expression, dimension_dict=dim_dict)

    path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit="max_input", shapes=True)
    assert check_path(path, [(0, 1, 2, 3)])

    path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit=-1, shapes=True)
    assert check_path(path, [(0, 1), (0, 2), (0, 1)])


def test_dp_edge_cases_dimension_1() -> None:
    eq = "nlp,nlq,pl->n"
    shapes = [(1, 1, 1), (1, 1, 1), (1, 1)]
    info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")[1]
    assert max(info.scale_list) == 3


def test_dp_edge_cases_all_singlet_indices() -> None:
    eq = "a,bcd,efg->"
    shapes = [(2,), (2, 2, 2), (2, 2, 2)]
    info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")[1]
    assert max(info.scale_list) == 3


def test_custom_dp_can_optimize_for_outer_products() -> None:
    eq = "a,b,abc->c"

    da, db, dc = 2, 2, 3
    shapes = [(da,), (db,), (da, db, dc)]

    opt1 = oe.DynamicProgramming(search_outer=False)
    opt2 = oe.DynamicProgramming(search_outer=True)

    info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
    info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]

    assert info2.opt_cost < info1.opt_cost

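# Editor's note: a short, hedged usage sketch (not in the original file) -- a
# configured DynamicProgramming instance can be passed anywhere an optimize=
# strategy string is accepted, as the tests above and below do.
def _example_dp_optimizer_usage() -> None:
    opt = oe.DynamicProgramming(minimize="size", search_outer=True)
    path, info = oe.contract_path("a,b,abc->c", (2,), (2,), (2, 2, 3), shapes=True, optimize=opt)
    print(path, info.opt_cost, info.largest_intermediate)
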
def test_custom_dp_can_optimize_for_size() -> None:
    eq, shapes = rand_equation(10, 4, seed=43)

    opt1 = oe.DynamicProgramming(minimize="flops")
    opt2 = oe.DynamicProgramming(minimize="size")

    info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
    info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]

    assert info1.opt_cost < info2.opt_cost
    assert info1.largest_intermediate > info2.largest_intermediate


def test_custom_dp_can_set_cost_cap() -> None:
    eq, shapes = rand_equation(5, 3, seed=42)
    opt1 = oe.DynamicProgramming(cost_cap=True)
    opt2 = oe.DynamicProgramming(cost_cap=False)
    opt3 = oe.DynamicProgramming(cost_cap=100)
    info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
    info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]
    info3 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt3)[1]
    assert info1.opt_cost == info2.opt_cost == info3.opt_cost


@pytest.mark.parametrize(
    "minimize,cost,width,path",
    [
        ("flops", 663054, 18900, [(4, 5), (2, 5), (2, 7), (5, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
        ("size", 1114440, 2016, [(2, 7), (3, 8), (3, 7), (2, 6), (1, 5), (1, 4), (1, 3), (1, 2), (0, 1)]),
        ("write", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
        ("combo", 973518, 2016, [(4, 5), (2, 5), (6, 7), (2, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
        ("limit", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
        ("combo-256", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
        ("limit-256", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
    ],
)
def test_custom_dp_can_set_minimize(minimize: str, cost: int, width: int, path: PathType) -> None:
    eq, shapes = rand_equation(10, 4, seed=43)
    opt = oe.DynamicProgramming(minimize=minimize)
    info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)[1]
    assert info.path == path
    assert info.opt_cost == cost
    assert info.largest_intermediate == width


def test_dp_errors_when_no_contractions_found() -> None:
    eq, shapes = rand_equation(10, 3, seed=42)

    # first get the actual minimum size
    opt = oe.DynamicProgramming(minimize="size")
    _, info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)
    mincost = info.largest_intermediate

    # check we can still find it without minimizing size explicitly
    oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost, optimize="dp")

    # but check that just below this threshold raises
    with pytest.raises(RuntimeError):
        oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost - 1, optimize="dp")


@pytest.mark.parametrize("optimize", ["greedy", "branch-2", "branch-all", "optimal", "dp"])
def test_can_optimize_outer_products(optimize: OptimizeKind) -> None:
    a, b, c = ((10, 10) for _ in range(3))
    d = (10, 2)

    assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize, shapes=True)[0] == [
        (2, 3),
        (0, 2),
        (0, 1),
    ]


@pytest.mark.parametrize("num_symbols", [2, 3, 26, 26 + 26, 256 - 140, 300])
def test_large_path(num_symbols: int) -> None:
    symbols = "".join(oe.get_symbol(i) for i in range(num_symbols))
    dimension_dict = dict(zip(symbols, itertools.cycle([2, 3, 4])))
    expression = ",".join(symbols[t : t + 2] for t in range(num_symbols - 1))
    tensors = build_shapes(expression, dimension_dict=dimension_dict)

    # check that path construction does not crash
    oe.contract_path(expression, *tensors, optimize="greedy", shapes=True)

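# Editor's note: a hedged worked example (not in the original file) of the chain
# expression built above -- for num_symbols=4 the symbols are "abcd" and the
# expression is the pairwise chain "ab,bc,cd".
def _example_chain_expression() -> None:
    symbols = "".join(oe.get_symbol(i) for i in range(4))
    expression = ",".join(symbols[t : t + 2] for t in range(3))
    assert expression == "ab,bc,cd"
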
def test_custom_random_greedy() -> None:
    np = pytest.importorskip("numpy")

    eq, shapes = rand_equation(10, 4, seed=42)
    views = list(map(np.ones, shapes))

    with pytest.raises(ValueError):
        oe.RandomGreedy(minimize="something")

    optimizer = oe.RandomGreedy(max_repeats=10, minimize="flops")
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert len(optimizer.costs) == 10
    assert len(optimizer.sizes) == 10

    assert path == optimizer.path
    assert optimizer.best["flops"] == min(optimizer.costs)
    assert path_info.largest_intermediate == optimizer.best["size"]
    assert path_info.opt_cost == optimizer.best["flops"]

    # check that we can change settings and run again
    optimizer.temperature = 0.0
    optimizer.max_repeats = 6
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert len(optimizer.costs) == 16
    assert len(optimizer.sizes) == 16

    assert path == optimizer.path
    assert optimizer.best["size"] == min(optimizer.sizes)
    assert path_info.largest_intermediate == optimizer.best["size"]
    assert path_info.opt_cost == optimizer.best["flops"]

    # check the error if we try to reuse the optimizer on a different expression
    eq, shapes = rand_equation(10, 4, seed=41)
    views = list(map(np.ones, shapes))
    with pytest.raises(ValueError):
        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)


def test_custom_branchbound() -> None:
    np = pytest.importorskip("numpy")

    eq, shapes = rand_equation(8, 4, seed=42)
    views = list(map(np.ones, shapes))
    optimizer = oe.BranchBound(nbranch=2, cutoff_flops_factor=10, minimize="size")

    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert path == optimizer.path
    assert path_info.largest_intermediate == optimizer.best["size"]
    assert path_info.opt_cost == optimizer.best["flops"]

    # tweak the settings and run again
    optimizer.nbranch = 3
    optimizer.cutoff_flops_factor = 4
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert path == optimizer.path
    assert path_info.largest_intermediate == optimizer.best["size"]
    assert path_info.opt_cost == optimizer.best["flops"]

    # check the error if we try to reuse the optimizer on a different expression
    eq, shapes = rand_equation(8, 4, seed=41)
    views = list(map(np.ones, shapes))
    with pytest.raises(ValueError):
        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)


def test_branchbound_validation() -> None:
    with pytest.raises(ValueError):
        oe.BranchBound(nbranch=0)


def test_parallel_random_greedy() -> None:
    np = pytest.importorskip("numpy")

    pool = ProcessPoolExecutor(2)

    eq, shapes = rand_equation(10, 4, seed=42)
    views = list(map(np.ones, shapes))

    optimizer = oe.RandomGreedy(max_repeats=10, parallel=pool)
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert len(optimizer.costs) == 10
    assert len(optimizer.sizes) == 10

    assert path == optimizer.path
    assert optimizer.parallel is pool
    assert optimizer._executor is pool
    assert optimizer.best["flops"] == min(optimizer.costs)
    assert path_info.largest_intermediate == optimizer.best["size"]
    assert path_info.opt_cost == optimizer.best["flops"]

    # now switch to the max-time algorithm
    optimizer.max_repeats = int(1e6)
    optimizer.max_time = 0.2
    optimizer.parallel = 2

    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert len(optimizer.costs) > 10
    assert len(optimizer.sizes) > 10

    assert path == optimizer.path
    assert optimizer.best["flops"] == min(optimizer.costs)
    assert path_info.largest_intermediate == optimizer.best["size"]
    assert path_info.opt_cost == optimizer.best["flops"]

    optimizer.parallel = True
    assert optimizer._executor is not None
    assert optimizer._executor is not pool

    are_done = [f.running() or f.done() for f in optimizer._futures]
    assert all(are_done)


def test_custom_path_optimizer() -> None:
    np = pytest.importorskip("numpy")

    class NaiveOptimizer(oe.paths.PathOptimizer):
        def __call__(
            self,
            inputs: List[ArrayIndexType],
            output: ArrayIndexType,
            size_dict: Dict[str, int],
            memory_limit: Optional[int] = None,
        ) -> PathType:
            self.was_used = True
            return [(0, 1)] * (len(inputs) - 1)

    eq, shapes = rand_equation(5, 3, seed=42, d_max=3)
    views = list(map(np.ones, shapes))

    exp = oe.contract(eq, *views, optimize=False)

    optimizer = NaiveOptimizer()
    out = oe.contract(eq, *views, optimize=optimizer)
    assert exp == out
    assert optimizer.was_used


def test_custom_random_optimizer() -> None:
    np = pytest.importorskip("numpy")

    class NaiveRandomOptimizer(oe.path_random.RandomOptimizer):
        @staticmethod
        def random_path(
            r: int, n: int, inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int]
        ) -> Any:
            """Picks a completely random contraction order."""
            np.random.seed(r)
            ssa_path: List[TensorShapeType] = []
            remaining = set(range(n))
            while len(remaining) > 1:
                i, j = np.random.choice(list(remaining), size=2, replace=False)
                remaining.add(n + len(ssa_path))
                remaining.remove(i)
                remaining.remove(j)
                ssa_path.append((i, j))
            cost, size = oe.path_random.ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
            return ssa_path, cost, size

        def setup(self, inputs: Any, output: Any, size_dict: Any) -> Any:
            self.was_used = True
            n = len(inputs)
            trial_fn = self.random_path
            trial_args = (n, inputs, output, size_dict)
            return trial_fn, trial_args

    eq, shapes = rand_equation(5, 3, seed=42, d_max=3)
    views = list(map(np.ones, shapes))

    exp = oe.contract(eq, *views, optimize=False)

    optimizer = NaiveRandomOptimizer(max_repeats=16)
    out = oe.contract(eq, *views, optimize=optimizer)
    assert exp == out
    assert optimizer.was_used

    assert len(optimizer.costs) == 16


def test_optimizer_registration() -> None:
    def custom_optimizer(
        inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int]
    ) -> PathType:
        return [(0, 1)] * (len(inputs) - 1)

    # names already registered (such as "optimal") cannot be overwritten
    with pytest.raises(KeyError):
        oe.paths.register_path_fn("optimal", custom_optimizer)

    oe.paths.register_path_fn("custom", custom_optimizer)
    assert "custom" in oe.paths._PATH_OPTIONS

    eq = "ab,bc,cd"
    shapes = [(2, 3), (3, 4), (4, 5)]
    path, _ = oe.contract_path(eq, *shapes, shapes=True, optimize="custom")  # type: ignore
    assert path == [(0, 1), (0, 1)]
    del oe.paths._PATH_OPTIONS["custom"]

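# Editor's note: a hedged usage sketch (not in the original file) -- once registered,
# a custom path function (here under the hypothetical name "my-path") is addressable
# by string just like the built-in strategies registered in __init__.py.
def _example_registered_path_usage() -> None:
    def my_path(inputs, output, size_dict, memory_limit=None):
        return [(0, 1)] * (len(inputs) - 1)  # always contract the first pair

    oe.paths.register_path_fn("my-path", my_path)
    try:
        x = oe.contract("ab,bc,cd", [[1.0, 2.0]], [[3.0], [4.0]], [[5.0, 6.0]], optimize="my-path")
        print(x)
    finally:
        del oe.paths._PATH_OPTIONS["my-path"]  # clean up, mirroring the test above
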
def test_path_with_assumed_shapes() -> None:
    path, _ = oe.contract_path("ab,bc,cd", [[5, 3]], [[2], [4]], [[3, 2]])
    assert path == [(0, 1), (0, 1)]
@@ -0,0 +1,390 @@
import itertools
import weakref
from collections import Counter
from typing import Any

import pytest

from opt_einsum import contract, contract_expression, contract_path, get_symbol, shared_intermediates
from opt_einsum.backends import to_cupy, to_torch
from opt_einsum.contract import _einsum
from opt_einsum.parser import parse_einsum_input
from opt_einsum.sharing import count_cached_ops, currently_sharing, get_sharing_cache
from opt_einsum.testing import build_views
from opt_einsum.typing import BackendType

pytest.importorskip("numpy")

try:
    import numpy as np  # type: ignore

    numpy_if_found = "numpy"
except ImportError:
    numpy_if_found = pytest.param("numpy", marks=[pytest.mark.skip(reason="NumPy not installed.")])  # type: ignore

try:
    import cupy  # noqa

    cupy_if_found = "cupy"
except ImportError:
    cupy_if_found = pytest.param("cupy", marks=[pytest.mark.skip(reason="CuPy not installed.")])  # type: ignore

try:
    import torch  # type: ignore # noqa

    torch_if_found = "torch"
except ImportError:
    torch_if_found = pytest.param("torch", marks=[pytest.mark.skip(reason="PyTorch not installed.")])  # type: ignore

backends = [numpy_if_found, torch_if_found, cupy_if_found]
equations = [
    "ab,bc->ca",
    "abc,bcd,dea",
    "abc,def->fedcba",
    "abc,bcd,df->fa",
    # test 'prefer einsum' ops
    "ijk,ikj",
    "i,j->ij",
    "ijk,k->ij",
    "AB,BC->CA",
]
to_backend = {
    "numpy": lambda x: x,
    "torch": to_torch,
    "cupy": to_cupy,
}


@pytest.mark.parametrize("eq", equations)
@pytest.mark.parametrize("backend", backends)
def test_sharing_value(eq: str, backend: BackendType) -> None:
    views = build_views(eq)
    shapes = [v.shape for v in views]
    expr = contract_expression(eq, *shapes)

    expected = expr(*views, backend=backend)
    with shared_intermediates():
        actual = expr(*views, backend=backend)

    assert (actual == expected).all()


@pytest.mark.parametrize("backend", backends)
def test_complete_sharing(backend: BackendType) -> None:
    eq = "ab,bc,cd->"
    views = build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print("-" * 40)
    print("Without sharing:")
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)

    print("-" * 40)
    print("With sharing:")
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expr(*views, backend=backend)
        actual = count_cached_ops(cache)

    print("-" * 40)
    print(f"Without sharing: {expected} expressions")
    print(f"With sharing: {actual} expressions")
    assert actual == expected

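# Editor's note: a hedged illustration (not in the original file) of the assertion
# pattern above -- count_cached_ops returns a Counter of cached operation names
# (e.g. "einsum", "tensordot"), so running the same expression twice inside one
# shared_intermediates() block should add no new cache entries.
def _example_sharing_counts() -> None:
    views = build_views("ab,bc,cd->")
    expr = contract_expression("ab,bc,cd->", *(v.shape for v in views))
    with shared_intermediates() as cache:
        expr(*views)
        once = count_cached_ops(cache)
        expr(*views)
        twice = count_cached_ops(cache)
    assert once == twice  # the second run was served entirely from the cache
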
@pytest.mark.parametrize("backend", backends)
def test_sharing_reused_cache(backend: BackendType) -> None:
    eq = "ab,bc,cd->"
    views = build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print("-" * 40)
    print("Without sharing:")
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)

    print("-" * 40)
    print("With sharing:")
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        with shared_intermediates(cache):
            expr(*views, backend=backend)
        actual = count_cached_ops(cache)

    print("-" * 40)
    print(f"Without sharing: {expected} expressions")
    print(f"With sharing: {actual} expressions")
    assert actual == expected


@pytest.mark.parametrize("backend", backends)
def test_no_sharing_separate_cache(backend: BackendType) -> None:
    eq = "ab,bc,cd->"
    views = build_views(eq)
    expr = contract_expression(eq, *(v.shape for v in views))

    print("-" * 40)
    print("Without sharing:")
    with shared_intermediates() as cache:
        expr(*views, backend=backend)
        expected = count_cached_ops(cache)
        expected.update(count_cached_ops(cache))  # we expect double

    print("-" * 40)
    print("With sharing:")
    with shared_intermediates() as cache1:
        expr(*views, backend=backend)
        actual = count_cached_ops(cache1)
    with shared_intermediates() as cache2:
        expr(*views, backend=backend)
        actual.update(count_cached_ops(cache2))

    print("-" * 40)
    print(f"Without sharing: {expected} expressions")
    print(f"With sharing: {actual} expressions")
    assert actual == expected


@pytest.mark.parametrize("backend", backends)
def test_sharing_nesting(backend: BackendType) -> None:
    eqs = ["ab,bc,cd->a", "ab,bc,cd->b", "ab,bc,cd->c", "ab,bc,cd->c"]
    views = build_views(eqs[0])
    shapes = [v.shape for v in views]
    refs: Any = weakref.WeakValueDictionary()

    def method1(views):
        with shared_intermediates():
            w = contract_expression(eqs[0], *shapes)(*views, backend=backend)
            x = contract_expression(eqs[2], *shapes)(*views, backend=backend)
            result = contract_expression("a,b->", w.shape, x.shape)(w, x, backend=backend)
            refs["w"] = w
            refs["x"] = x
            del w, x
            # the shared cache still holds the intermediates here
            assert "w" in refs
            assert "x" in refs
        # once the sharing context closes, the cache must release them
        assert "w" not in refs, "cache leakage"
        assert "x" not in refs, "cache leakage"
        return result

    def method2(views):
        with shared_intermediates():
            y = contract_expression(eqs[2], *shapes)(*views, backend=backend)
            z = contract_expression(eqs[3], *shapes)(*views, backend=backend)
            refs["y"] = y
            refs["z"] = z
            result = contract_expression("c,d->", y.shape, z.shape)(y, z, backend=backend)
            result = result + method1(views)  # nest method1 in method2
            del y, z
            assert "y" in refs
            assert "z" in refs
        assert "y" not in refs
        assert "z" not in refs

    method1(views)
    method2(views)


@pytest.mark.parametrize("eq", equations)
@pytest.mark.parametrize("backend", backends)
def test_sharing_modulo_commutativity(eq: str, backend: BackendType) -> None:
    ops = tuple(to_backend[backend](x) for x in build_views(eq))
    inputs, output, _ = parse_einsum_input([eq] + list(ops))
    inputs_list = inputs.split(",")

    print("-" * 40)
    print("Without sharing:")
    with shared_intermediates() as cache:
        _einsum(eq, *ops, backend=backend)
        expected = count_cached_ops(cache)

    print("-" * 40)
    print("With sharing:")
    with shared_intermediates() as cache:
        for permuted in itertools.permutations(zip(inputs_list, ops)):
            permuted_inputs = [p[0] for p in permuted]
            permuted_ops = [p[1] for p in permuted]
            permuted_eq = "{}->{}".format(",".join(permuted_inputs), output)
            _einsum(permuted_eq, *permuted_ops, backend=backend)
        actual = count_cached_ops(cache)

    print("-" * 40)
    print(f"Without sharing: {expected} expressions")
    print(f"With sharing: {actual} expressions")
    assert actual == expected


@pytest.mark.parametrize("backend", backends)
def test_partial_sharing(backend: BackendType) -> None:
    eq = "ab,bc,de->"
    x, y, z1 = build_views(eq)  # type: ignore
    z2 = 2.0 * z1 - 1.0
    expr = contract_expression(eq, x.shape, y.shape, z1.shape)

    print("-" * 40)
    print("Without sharing:")
    num_exprs_nosharing: Any = Counter()
    with shared_intermediates() as cache:
        expr(x, y, z1, backend=backend)
        num_exprs_nosharing.update(count_cached_ops(cache))
    with shared_intermediates() as cache:
        expr(x, y, z2, backend=backend)
        num_exprs_nosharing.update(count_cached_ops(cache))

    print("-" * 40)
    print("With sharing:")
    with shared_intermediates() as cache:
        expr(x, y, z1, backend=backend)
        expr(x, y, z2, backend=backend)
        num_exprs_sharing = count_cached_ops(cache)

    print("-" * 40)
    print(f"Without sharing: {num_exprs_nosharing} expressions")
    print(f"With sharing: {num_exprs_sharing} expressions")
    assert num_exprs_nosharing["einsum"] > num_exprs_sharing["einsum"]


@pytest.mark.parametrize("backend", backends)
def test_sharing_with_constants(backend: BackendType) -> None:
    inputs = "ij,jk,kl"
    outputs = "ijkl"
    equations = [f"{inputs}->{output}" for output in outputs]
    shapes = (2, 3), (3, 4), (4, 5)
    constants = {0, 2}
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[1])

    expected = [contract_expression(eq, *shapes)(ops[0], var, ops[2]) for eq in equations]

    with shared_intermediates():
        actual = [contract_expression(eq, *ops, constants=constants)(var) for eq in equations]

    for dim, expected_dim, actual_dim in zip(outputs, expected, actual):
        assert np.allclose(expected_dim, actual_dim), f"error at output {dim}"


@pytest.mark.parametrize("size", [3, 4, 5])
@pytest.mark.parametrize("backend", backends)
def test_chain(size: int, backend: BackendType) -> None:
    xs = [np.random.rand(2, 2) for _ in range(size)]
    shapes = [x.shape for x in xs]
    alphabet = "".join(get_symbol(i) for i in range(size + 1))
    names = [alphabet[i : i + 2] for i in range(size)]
    inputs = ",".join(names)

    with shared_intermediates():
        print(inputs)
        for i in range(size + 1):
            target = alphabet[i]
            eq = f"{inputs}->{target}"
            path_info = contract_path(eq, *xs)
            print(path_info[1])
            expr = contract_expression(eq, *shapes)
            expr(*xs, backend=backend)
        print("-" * 40)


@pytest.mark.parametrize("size", [3, 4, 5, 10])
@pytest.mark.parametrize("backend", backends)
def test_chain_2(size: int, backend: BackendType) -> None:
    xs = [np.random.rand(2, 2) for _ in range(size)]
    shapes = [x.shape for x in xs]
    alphabet = "".join(get_symbol(i) for i in range(size + 1))
    names = [alphabet[i : i + 2] for i in range(size)]
    inputs = ",".join(names)

    with shared_intermediates():
        print(inputs)
        for i in range(size):
            target = alphabet[i : i + 2]
            eq = f"{inputs}->{target}"
            path_info = contract_path(eq, *xs)
            print(path_info[1])
            expr = contract_expression(eq, *shapes)
            expr(*xs, backend=backend)
        print("-" * 40)


def _compute_cost(cache):
    counts = count_cached_ops(cache)
    return counts["einsum"] + counts["tensordot"]


@pytest.mark.parametrize("backend", backends)
def test_chain_2_growth(backend: BackendType) -> None:
    sizes = list(range(1, 21))
    costs = []
    for size in sizes:
        xs = [np.random.rand(2, 2) for _ in range(size)]
        alphabet = "".join(get_symbol(i) for i in range(size + 1))
        names = [alphabet[i : i + 2] for i in range(size)]
        inputs = ",".join(names)

        with shared_intermediates() as cache:
            for i in range(size):
                target = alphabet[i : i + 2]
                eq = f"{inputs}->{target}"
                expr = contract_expression(eq, *(x.shape for x in xs))
                expr(*xs, backend=backend)
            costs.append(_compute_cost(cache))

    print(f"sizes = {repr(sizes)}")
    print(f"costs = {repr(costs)}")
    for size, cost in zip(sizes, costs):
        print(f"{size}\t{cost}")


@pytest.mark.parametrize("size", [3, 4, 5])
@pytest.mark.parametrize("backend", backends)
def test_chain_sharing(size: int, backend: BackendType) -> None:
    xs = [np.random.rand(2, 2) for _ in range(size)]
    alphabet = "".join(get_symbol(i) for i in range(size + 1))
    names = [alphabet[i : i + 2] for i in range(size)]
    inputs = ",".join(names)

    num_exprs_nosharing = 0
    for i in range(size + 1):
        with shared_intermediates() as cache:
            target = alphabet[i]
            eq = f"{inputs}->{target}"
            expr = contract_expression(eq, *tuple(x.shape for x in xs))
            expr(*xs, backend=backend)
            num_exprs_nosharing += _compute_cost(cache)

    with shared_intermediates() as cache:
        print(inputs)
        for i in range(size + 1):
            target = alphabet[i]
            eq = f"{inputs}->{target}"
            path_info = contract_path(eq, *xs)
            print(path_info[1])
            expr = contract_expression(eq, *[x.shape for x in xs])
            expr(*xs, backend=backend)
        num_exprs_sharing = _compute_cost(cache)

    print("-" * 40)
    print(f"Without sharing: {num_exprs_nosharing} expressions")
    print(f"With sharing: {num_exprs_sharing} expressions")
    assert num_exprs_nosharing > num_exprs_sharing


def test_multithreaded_sharing() -> None:
    from multiprocessing.pool import ThreadPool

    def fn():
        x, y, z = build_views("ab,bc,cd")

        with shared_intermediates():
            contract("ab,bc,cd->a", x, y, z)
            contract("ab,bc,cd->b", x, y, z)

            return len(get_sharing_cache())

    expected = fn()
    pool = ThreadPool(8)
    fs = [pool.apply_async(fn) for _ in range(16)]
    assert not currently_sharing()
    assert [f.get() for f in fs] == [expected] * 16
    pool.close()
@@ -0,0 +1,27 @@
"""Types used in the opt_einsum package."""

from collections import namedtuple
from typing import Any, Callable, Collection, Dict, FrozenSet, List, Literal, Optional, Tuple, Union

TensorShapeType = Tuple[int, ...]
PathType = Collection[TensorShapeType]

ArrayType = Any

ArrayIndexType = FrozenSet[str]
ArrayShaped = namedtuple("ArrayShaped", ["shape"])

ContractionListType = List[Tuple[Any, ArrayIndexType, str, Optional[Tuple[str, ...]], Union[str, bool]]]
PathSearchFunctionType = Callable[[List[ArrayIndexType], ArrayIndexType, Dict[str, int], Optional[int]], PathType]

# Contract kwargs
OptimizeKind = Union[
    None,
    bool,
    Literal[
        "optimal", "dp", "greedy", "random-greedy", "random-greedy-128", "branch-all", "branch-2", "auto", "auto-hq"
    ],
    PathType,
    PathSearchFunctionType,
]
BackendType = Literal["auto", "object", "autograd", "cupy", "dask", "jax", "theano", "tensorflow", "torch", "libjax"]
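# Editor's note: a hedged illustration (not in the original file) of values that
# satisfy OptimizeKind -- a bool, a strategy name, or an explicit contraction path,
# as exercised throughout the tests in this commit.
_OPTIMIZE_EXAMPLES = (
    True,  # bool: let opt_einsum choose a strategy
    "greedy",  # Literal strategy name
    [(0, 1), (0, 1)],  # PathType: an explicit contraction path
)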
Block a user