hand
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
"""Compute backends for opt_einsum."""
|
||||
|
||||
# Backends
|
||||
from opt_einsum.backends.cupy import to_cupy
|
||||
from opt_einsum.backends.dispatch import (
|
||||
build_expression,
|
||||
evaluate_constants,
|
||||
get_func,
|
||||
has_backend,
|
||||
has_einsum,
|
||||
has_tensordot,
|
||||
)
|
||||
from opt_einsum.backends.tensorflow import to_tensorflow
|
||||
from opt_einsum.backends.theano import to_theano
|
||||
from opt_einsum.backends.torch import to_torch
|
||||
|
||||
__all__ = [
|
||||
"get_func",
|
||||
"has_einsum",
|
||||
"has_tensordot",
|
||||
"build_expression",
|
||||
"evaluate_constants",
|
||||
"has_backend",
|
||||
"to_tensorflow",
|
||||
"to_theano",
|
||||
"to_cupy",
|
||||
"to_torch",
|
||||
]
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,32 @@
|
||||
"""Required functions for optimized contractions of numpy arrays using cupy."""
|
||||
|
||||
from opt_einsum.helpers import has_array_interface
|
||||
from opt_einsum.sharing import to_backend_cache_wrap
|
||||
|
||||
__all__ = ["to_cupy", "build_expression", "evaluate_constants"]
|
||||
|
||||
|
||||
@to_backend_cache_wrap
|
||||
def to_cupy(array): # pragma: no cover
|
||||
import cupy
|
||||
|
||||
if has_array_interface(array):
|
||||
return cupy.asarray(array)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
def build_expression(_, expr): # pragma: no cover
|
||||
"""Build a cupy function based on ``arrays`` and ``expr``."""
|
||||
|
||||
def cupy_contract(*arrays):
|
||||
return expr._contract([to_cupy(x) for x in arrays], backend="cupy").get()
|
||||
|
||||
return cupy_contract
|
||||
|
||||
|
||||
def evaluate_constants(const_arrays, expr): # pragma: no cover
|
||||
"""Convert constant arguments to cupy arrays, and perform any possible
|
||||
constant contractions.
|
||||
"""
|
||||
return expr(*[to_cupy(x) for x in const_arrays], backend="cupy", evaluate_constants=True)
|
||||
@@ -0,0 +1,156 @@
|
||||
"""Handles dispatching array operations to the correct backend library, as well
|
||||
as converting arrays to backend formats and then potentially storing them as
|
||||
constants.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from opt_einsum.backends import cupy as _cupy
|
||||
from opt_einsum.backends import jax as _jax
|
||||
from opt_einsum.backends import object_arrays
|
||||
from opt_einsum.backends import tensorflow as _tensorflow
|
||||
from opt_einsum.backends import theano as _theano
|
||||
from opt_einsum.backends import torch as _torch
|
||||
|
||||
__all__ = [
|
||||
"get_func",
|
||||
"has_einsum",
|
||||
"has_tensordot",
|
||||
"build_expression",
|
||||
"evaluate_constants",
|
||||
"has_backend",
|
||||
]
|
||||
|
||||
# known non top-level imports
|
||||
_aliases = {
|
||||
"dask": "dask.array",
|
||||
"theano": "theano.tensor",
|
||||
"torch": "opt_einsum.backends.torch",
|
||||
"jax": "jax.numpy",
|
||||
"jaxlib": "jax.numpy",
|
||||
"autograd": "autograd.numpy",
|
||||
"mars": "mars.tensor",
|
||||
}
|
||||
|
||||
|
||||
def _import_func(func: str, backend: str, default: Any = None) -> Any:
|
||||
"""Try and import ``{backend}.{func}``.
|
||||
If library is installed and func is found, return the func;
|
||||
otherwise if default is provided, return default;
|
||||
otherwise raise an error.
|
||||
"""
|
||||
try:
|
||||
lib = importlib.import_module(_aliases.get(backend, backend))
|
||||
return getattr(lib, func) if default is None else getattr(lib, func, default)
|
||||
except AttributeError:
|
||||
error_msg = (
|
||||
"{} doesn't seem to provide the function {} - see "
|
||||
"https://optimized-einsum.readthedocs.io/en/latest/backends.html "
|
||||
"for details on which functions are required for which contractions."
|
||||
)
|
||||
raise AttributeError(error_msg.format(backend, func))
|
||||
|
||||
|
||||
# manually cache functions as python2 doesn't support functools.lru_cache
|
||||
# other libs will be added to this if needed, but pre-populate with numpy
|
||||
_cached_funcs: Dict[Tuple[str, str], Any] = {
|
||||
("einsum", "object"): object_arrays.object_einsum,
|
||||
}
|
||||
|
||||
try:
|
||||
import numpy as np # type: ignore
|
||||
|
||||
_cached_funcs[("tensordot", "numpy")] = np.tensordot
|
||||
_cached_funcs[("transpose", "numpy")] = np.transpose
|
||||
_cached_funcs[("einsum", "numpy")] = np.einsum
|
||||
# also pre-populate with the arbitrary object backend
|
||||
_cached_funcs[("tensordot", "object")] = np.tensordot
|
||||
_cached_funcs[("transpose", "object")] = np.transpose
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def get_func(func: str, backend: str = "numpy", default: Any = None) -> Any:
|
||||
"""Return ``{backend}.{func}``, e.g. ``numpy.einsum``,
|
||||
or a default func if provided. Cache result.
|
||||
"""
|
||||
try:
|
||||
return _cached_funcs[func, backend]
|
||||
except KeyError:
|
||||
fn = _import_func(func, backend, default)
|
||||
_cached_funcs[func, backend] = fn
|
||||
return fn
|
||||
|
||||
|
||||
# mark libs with einsum, else try to use tensordot/transpose as much as possible
|
||||
_has_einsum: Dict[str, bool] = {}
|
||||
|
||||
|
||||
def has_einsum(backend: str) -> bool:
|
||||
"""Check if ``{backend}.einsum`` exists, cache result for performance."""
|
||||
try:
|
||||
return _has_einsum[backend]
|
||||
except KeyError:
|
||||
try:
|
||||
get_func("einsum", backend)
|
||||
_has_einsum[backend] = True
|
||||
except AttributeError:
|
||||
_has_einsum[backend] = False
|
||||
|
||||
return _has_einsum[backend]
|
||||
|
||||
|
||||
_has_tensordot: Dict[str, bool] = {}
|
||||
|
||||
|
||||
def has_tensordot(backend: str) -> bool:
|
||||
"""Check if ``{backend}.tensordot`` exists, cache result for performance."""
|
||||
try:
|
||||
return _has_tensordot[backend]
|
||||
except KeyError:
|
||||
try:
|
||||
get_func("tensordot", backend)
|
||||
_has_tensordot[backend] = True
|
||||
except AttributeError:
|
||||
_has_tensordot[backend] = False
|
||||
|
||||
return _has_tensordot[backend]
|
||||
|
||||
|
||||
# Dispatch to correct expression backend
|
||||
# these are the backends which support explicit to-and-from numpy conversion
|
||||
CONVERT_BACKENDS = {
|
||||
"tensorflow": _tensorflow.build_expression,
|
||||
"theano": _theano.build_expression,
|
||||
"cupy": _cupy.build_expression,
|
||||
"torch": _torch.build_expression,
|
||||
"jax": _jax.build_expression,
|
||||
}
|
||||
|
||||
EVAL_CONSTS_BACKENDS = {
|
||||
"tensorflow": _tensorflow.evaluate_constants,
|
||||
"theano": _theano.evaluate_constants,
|
||||
"cupy": _cupy.evaluate_constants,
|
||||
"torch": _torch.evaluate_constants,
|
||||
"jax": _jax.evaluate_constants,
|
||||
}
|
||||
|
||||
|
||||
def build_expression(backend, arrays, expr):
|
||||
"""Build an expression, based on ``expr`` and initial arrays ``arrays``,
|
||||
that evaluates using backend ``backend``.
|
||||
"""
|
||||
return CONVERT_BACKENDS[backend](arrays, expr)
|
||||
|
||||
|
||||
def evaluate_constants(backend, arrays, expr):
|
||||
"""Convert constant arrays to the correct backend, and perform as much of
|
||||
the contraction of ``expr`` with these as possible.
|
||||
"""
|
||||
return EVAL_CONSTS_BACKENDS[backend](arrays, expr)
|
||||
|
||||
|
||||
def has_backend(backend: str) -> bool:
|
||||
"""Checks if the backend is known."""
|
||||
return backend.lower() in CONVERT_BACKENDS
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Required functions for optimized contractions of numpy arrays using jax."""
|
||||
|
||||
from opt_einsum.sharing import to_backend_cache_wrap
|
||||
|
||||
__all__ = ["build_expression", "evaluate_constants"]
|
||||
|
||||
_JAX = None
|
||||
|
||||
|
||||
def _get_jax_and_to_jax():
|
||||
global _JAX
|
||||
if _JAX is None:
|
||||
import jax # type: ignore
|
||||
|
||||
@to_backend_cache_wrap
|
||||
@jax.jit
|
||||
def to_jax(x):
|
||||
return x
|
||||
|
||||
_JAX = jax, to_jax
|
||||
|
||||
return _JAX
|
||||
|
||||
|
||||
def build_expression(_, expr): # pragma: no cover
|
||||
"""Build a jax function based on ``arrays`` and ``expr``."""
|
||||
jax, _ = _get_jax_and_to_jax()
|
||||
|
||||
jax_expr = jax.jit(expr._contract)
|
||||
|
||||
def jax_contract(*arrays):
|
||||
import numpy as np # type: ignore
|
||||
|
||||
return np.asarray(jax_expr(arrays))
|
||||
|
||||
return jax_contract
|
||||
|
||||
|
||||
def evaluate_constants(const_arrays, expr): # pragma: no cover
|
||||
"""Convert constant arguments to jax arrays, and perform any possible
|
||||
constant contractions.
|
||||
"""
|
||||
jax, to_jax = _get_jax_and_to_jax()
|
||||
|
||||
return expr(*[to_jax(x) for x in const_arrays], backend="jax", evaluate_constants=True)
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Functions for performing contractions with array elements which are objects."""
|
||||
|
||||
import functools
|
||||
import operator
|
||||
|
||||
from opt_einsum.typing import ArrayType
|
||||
|
||||
|
||||
def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType:
|
||||
"""A ``einsum`` implementation for ``numpy`` arrays with object dtype.
|
||||
The loop is performed in python, meaning the objects themselves need
|
||||
only to implement ``__mul__`` and ``__add__`` for the contraction to be
|
||||
computed. This may be useful when, for example, computing expressions of
|
||||
tensors with symbolic elements, but note it will be very slow when compared
|
||||
to ``numpy.einsum`` and numeric data types!
|
||||
|
||||
Parameters
|
||||
----------
|
||||
eq : str
|
||||
The contraction string, should specify output.
|
||||
arrays : sequence of arrays
|
||||
These can be any indexable arrays as long as addition and
|
||||
multiplication is defined on the elements.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
out : numpy.ndarray
|
||||
The output tensor, with ``dtype=object``.
|
||||
"""
|
||||
import numpy as np # type: ignore
|
||||
|
||||
# when called by ``opt_einsum`` we will always be given a full eq
|
||||
lhs, output = eq.split("->")
|
||||
inputs = lhs.split(",")
|
||||
|
||||
sizes = {}
|
||||
for term, array in zip(inputs, arrays):
|
||||
for k, d in zip(term, array.shape):
|
||||
sizes[k] = d
|
||||
|
||||
out_size = tuple(sizes[k] for k in output)
|
||||
out = np.empty(out_size, dtype=object)
|
||||
|
||||
inner = tuple(k for k in sizes if k not in output)
|
||||
inner_size = tuple(sizes[k] for k in inner)
|
||||
|
||||
for coo_o in np.ndindex(*out_size):
|
||||
coord = dict(zip(output, coo_o))
|
||||
|
||||
def gen_inner_sum():
|
||||
for coo_i in np.ndindex(*inner_size):
|
||||
coord.update(dict(zip(inner, coo_i)))
|
||||
locs = (tuple(coord[k] for k in term) for term in inputs)
|
||||
elements = (array[loc] for array, loc in zip(arrays, locs))
|
||||
yield functools.reduce(operator.mul, elements)
|
||||
|
||||
out[coo_o] = functools.reduce(operator.add, gen_inner_sum())
|
||||
|
||||
return out
|
||||
@@ -0,0 +1,123 @@
|
||||
"""Required functions for optimized contractions of numpy arrays using tensorflow."""
|
||||
|
||||
from opt_einsum.helpers import has_array_interface
|
||||
from opt_einsum.sharing import to_backend_cache_wrap
|
||||
|
||||
__all__ = ["to_tensorflow", "build_expression", "evaluate_constants"]
|
||||
|
||||
_CACHED_TF_DEVICE = None
|
||||
|
||||
|
||||
def _get_tensorflow_and_device():
|
||||
global _CACHED_TF_DEVICE
|
||||
|
||||
if _CACHED_TF_DEVICE is None:
|
||||
import tensorflow as tf # type: ignore
|
||||
|
||||
try:
|
||||
eager = tf.executing_eagerly()
|
||||
except AttributeError:
|
||||
try:
|
||||
eager = tf.contrib.eager.in_eager_mode()
|
||||
except AttributeError:
|
||||
eager = False
|
||||
|
||||
device = tf.test.gpu_device_name()
|
||||
if not device:
|
||||
device = "cpu"
|
||||
|
||||
_CACHED_TF_DEVICE = tf, device, eager
|
||||
|
||||
return _CACHED_TF_DEVICE
|
||||
|
||||
|
||||
@to_backend_cache_wrap(constants=True)
|
||||
def to_tensorflow(array, constant=False):
|
||||
"""Convert a numpy array to a ``tensorflow.placeholder`` instance."""
|
||||
tf, device, eager = _get_tensorflow_and_device()
|
||||
|
||||
if eager:
|
||||
if has_array_interface(array):
|
||||
with tf.device(device):
|
||||
return tf.convert_to_tensor(array)
|
||||
|
||||
return array
|
||||
|
||||
if has_array_interface(array):
|
||||
if constant:
|
||||
return tf.convert_to_tensor(array)
|
||||
|
||||
return tf.placeholder(array.dtype, array.shape)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
# Standard graph mode
|
||||
|
||||
|
||||
def build_expression_graph(arrays, expr):
|
||||
"""Build a tensorflow function based on ``arrays`` and ``expr``."""
|
||||
tf, _, _ = _get_tensorflow_and_device()
|
||||
|
||||
placeholders = [to_tensorflow(array) for array in arrays]
|
||||
graph = expr._contract(placeholders, backend="tensorflow")
|
||||
|
||||
def tensorflow_contract(*arrays):
|
||||
session = tf.get_default_session()
|
||||
# only want to feed placeholders - constant tensors already have values
|
||||
feed_dict = {p: a for p, a in zip(placeholders, arrays) if p.op.type == "Placeholder"}
|
||||
return session.run(graph, feed_dict=feed_dict)
|
||||
|
||||
return tensorflow_contract
|
||||
|
||||
|
||||
def evaluate_constants_graph(const_arrays, expr):
|
||||
"""Convert constant arguments to tensorflow constants, and perform any
|
||||
possible constant contractions. Requires evaluating a tensorflow graph.
|
||||
"""
|
||||
tf, _, _ = _get_tensorflow_and_device()
|
||||
|
||||
# compute the partial graph of new inputs
|
||||
const_arrays = [to_tensorflow(x, constant=True) for x in const_arrays]
|
||||
new_ops, new_contraction_list = expr(*const_arrays, backend="tensorflow", evaluate_constants=True)
|
||||
|
||||
# evaluate the new inputs and convert back to tensorflow, maintaining None as non-consts
|
||||
session = tf.get_default_session()
|
||||
new_consts = iter(session.run([x for x in new_ops if x is not None]))
|
||||
new_ops = [None if x is None else to_tensorflow(next(new_consts), constant=True) for x in new_ops]
|
||||
|
||||
return new_ops, new_contraction_list
|
||||
|
||||
|
||||
# Eager execution mode
|
||||
|
||||
|
||||
def build_expression_eager(_, expr):
|
||||
"""Build a eager tensorflow function based on ``arrays`` and ``expr``."""
|
||||
|
||||
def tensorflow_eager_contract(*arrays):
|
||||
return expr._contract([to_tensorflow(x) for x in arrays], backend="tensorflow").numpy()
|
||||
|
||||
return tensorflow_eager_contract
|
||||
|
||||
|
||||
def evaluate_constants_eager(const_arrays, expr):
|
||||
"""Convert constant arguments to tensorflow_eager arrays, and perform any
|
||||
possible constant contractions.
|
||||
"""
|
||||
return expr(*[to_tensorflow(x) for x in const_arrays], backend="tensorflow", evaluate_constants=True)
|
||||
|
||||
|
||||
# Dispatch to eager or graph mode
|
||||
|
||||
|
||||
def build_expression(arrays, expr):
|
||||
_, _, eager = _get_tensorflow_and_device()
|
||||
fn = build_expression_eager if eager else build_expression_graph
|
||||
return fn(arrays, expr)
|
||||
|
||||
|
||||
def evaluate_constants(const_arrays, expr):
|
||||
_, _, eager = _get_tensorflow_and_device()
|
||||
fn = evaluate_constants_eager if eager else evaluate_constants_graph
|
||||
return fn(const_arrays, expr)
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Required functions for optimized contractions of numpy arrays using theano."""
|
||||
|
||||
from opt_einsum.helpers import has_array_interface
|
||||
from opt_einsum.sharing import to_backend_cache_wrap
|
||||
|
||||
__all__ = ["to_theano", "build_expression", "evaluate_constants"]
|
||||
|
||||
|
||||
@to_backend_cache_wrap(constants=True)
|
||||
def to_theano(array, constant=False):
|
||||
"""Convert a numpy array to ``theano.tensor.TensorType`` instance."""
|
||||
import theano # type: ignore
|
||||
|
||||
if has_array_interface(array):
|
||||
if constant:
|
||||
return theano.tensor.constant(array)
|
||||
|
||||
return theano.tensor.TensorType(dtype=array.dtype, broadcastable=[False] * len(array.shape))()
|
||||
|
||||
return array
|
||||
|
||||
|
||||
def build_expression(arrays, expr):
|
||||
"""Build a theano function based on ``arrays`` and ``expr``."""
|
||||
import theano
|
||||
|
||||
in_vars = [to_theano(array) for array in arrays]
|
||||
out_var = expr._contract(in_vars, backend="theano")
|
||||
|
||||
# don't supply constants to graph
|
||||
graph_ins = [x for x in in_vars if not isinstance(x, theano.tensor.TensorConstant)]
|
||||
graph = theano.function(graph_ins, out_var)
|
||||
|
||||
def theano_contract(*arrays):
|
||||
return graph(*[x for x in arrays if not isinstance(x, theano.tensor.TensorConstant)])
|
||||
|
||||
return theano_contract
|
||||
|
||||
|
||||
def evaluate_constants(const_arrays, expr):
|
||||
# compute the partial graph of new inputs
|
||||
const_arrays = [to_theano(x, constant=True) for x in const_arrays]
|
||||
new_ops, new_contraction_list = expr(*const_arrays, backend="theano", evaluate_constants=True)
|
||||
|
||||
# evaluate the new inputs and convert to theano shared tensors
|
||||
new_ops = [None if x is None else to_theano(x.eval(), constant=True) for x in new_ops]
|
||||
|
||||
return new_ops, new_contraction_list
|
||||
@@ -0,0 +1,130 @@
|
||||
"""Required functions for optimized contractions of numpy arrays using pytorch."""
|
||||
|
||||
from opt_einsum.helpers import has_array_interface
|
||||
from opt_einsum.parser import convert_to_valid_einsum_chars
|
||||
from opt_einsum.sharing import to_backend_cache_wrap
|
||||
|
||||
__all__ = [
|
||||
"transpose",
|
||||
"einsum",
|
||||
"tensordot",
|
||||
"to_torch",
|
||||
"build_expression",
|
||||
"evaluate_constants",
|
||||
]
|
||||
|
||||
_TORCH_DEVICE = None
|
||||
_TORCH_HAS_TENSORDOT = None
|
||||
|
||||
_torch_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
|
||||
def _get_torch_and_device():
|
||||
global _TORCH_DEVICE
|
||||
global _TORCH_HAS_TENSORDOT
|
||||
|
||||
if _TORCH_DEVICE is None:
|
||||
import torch # type: ignore
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
_TORCH_DEVICE = torch, device
|
||||
_TORCH_HAS_TENSORDOT = hasattr(torch, "tensordot")
|
||||
|
||||
return _TORCH_DEVICE
|
||||
|
||||
|
||||
def transpose(a, axes):
|
||||
"""Normal torch transpose is only valid for 2D matrices."""
|
||||
return a.permute(*axes)
|
||||
|
||||
|
||||
def einsum(equation, *operands, **kwargs):
|
||||
"""Variadic version of torch.einsum to match numpy api."""
|
||||
# rename symbols to support PyTorch 0.4.1 and earlier,
|
||||
# which allow only symbols a-z.
|
||||
equation = convert_to_valid_einsum_chars(equation)
|
||||
|
||||
torch, _ = _get_torch_and_device()
|
||||
return torch.einsum(equation, operands)
|
||||
|
||||
|
||||
def tensordot(x, y, axes=2):
|
||||
"""Simple translation of tensordot syntax to einsum."""
|
||||
torch, _ = _get_torch_and_device()
|
||||
|
||||
if _TORCH_HAS_TENSORDOT:
|
||||
return torch.tensordot(x, y, dims=axes)
|
||||
|
||||
xnd = x.ndimension()
|
||||
ynd = y.ndimension()
|
||||
|
||||
# convert int argument to (list[int], list[int])
|
||||
if isinstance(axes, int):
|
||||
axes = range(xnd - axes, xnd), range(axes)
|
||||
|
||||
# convert (int, int) to (list[int], list[int])
|
||||
if isinstance(axes[0], int):
|
||||
axes = (axes[0],), axes[1]
|
||||
if isinstance(axes[1], int):
|
||||
axes = axes[0], (axes[1],)
|
||||
|
||||
# initialize empty indices
|
||||
x_ix = [None] * xnd
|
||||
y_ix = [None] * ynd
|
||||
out_ix = []
|
||||
|
||||
# fill in repeated indices
|
||||
available_ix = iter(_torch_symbols_base)
|
||||
for ax1, ax2 in zip(*axes):
|
||||
repeat = next(available_ix)
|
||||
x_ix[ax1] = repeat
|
||||
y_ix[ax2] = repeat
|
||||
|
||||
# fill in the rest, and maintain output order
|
||||
for i in range(xnd):
|
||||
if x_ix[i] is None:
|
||||
leave = next(available_ix)
|
||||
x_ix[i] = leave
|
||||
out_ix.append(leave)
|
||||
for i in range(ynd):
|
||||
if y_ix[i] is None:
|
||||
leave = next(available_ix)
|
||||
y_ix[i] = leave
|
||||
out_ix.append(leave)
|
||||
|
||||
# form full string and contract!
|
||||
einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
|
||||
return einsum(einsum_str, x, y)
|
||||
|
||||
|
||||
@to_backend_cache_wrap
|
||||
def to_torch(array):
|
||||
torch, device = _get_torch_and_device()
|
||||
|
||||
if has_array_interface(array):
|
||||
return torch.from_numpy(array).to(device)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
def build_expression(_, expr): # pragma: no cover
|
||||
"""Build a torch function based on ``arrays`` and ``expr``."""
|
||||
|
||||
def torch_contract(*arrays):
|
||||
torch_arrays = [to_torch(x) for x in arrays]
|
||||
torch_out = expr._contract(torch_arrays, backend="torch")
|
||||
|
||||
if torch_out.device.type == "cpu":
|
||||
return torch_out.numpy()
|
||||
|
||||
return torch_out.cpu().numpy()
|
||||
|
||||
return torch_contract
|
||||
|
||||
|
||||
def evaluate_constants(const_arrays, expr):
|
||||
"""Convert constant arguments to torch, and perform any possible constant
|
||||
contractions.
|
||||
"""
|
||||
const_arrays = [to_torch(x) for x in const_arrays]
|
||||
return expr(*const_arrays, backend="torch", evaluate_constants=True)
|
||||
Reference in New Issue
Block a user