# Copyright 2026 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Einshape primitive implementation. Einshape (see https://github.com/google-deepmind/einshape) is a DSL for various array transformation operations including reshape, squeeze, expand_dims, and transpose, using an einsum-like notation. The DSL consists of an LHS equation and an RHS equation, separated by `->`. Each side assigns names to dimensions, e.g. `ij` would assign `i` and `j` to the first and second dimensions of an array. The DSL uses parentheses `()` to indicate grouping of dimensions: - On the left-hand side (LHS), parentheses indicate that an existing dimension should be split into multiple dimensions. - On the right-hand side (RHS), parentheses indicate that multiple dimensions should be merged into a single dimension. Dimension reordering in the RHS string (relative to the LHS) specifies a transpose. Example equations: - "n->n": Identity - "ab->ba": Transposes the 0th and 1st dimensions. - "nhwc->nchw": Transposes dimensions from (N, H, W, C) to (N, C, H, W). - "(ab)c->abc": Splits the first dimension into two dimensions (a, b). - "abc->(ab)c": Merges the first two dimensions (a, b). - "a(bc)->(ba)c": Splits the second dimension into (b, c), transposes 0 and 1, then merges dimensions 0 and 1. When used inside a Pallas kernel on TPU, `einshape` will attempt to perform a "tile-preserving" transformation. This is a more efficient implementation that avoids the overhead of general reshapes or transposes by logically reordering the underlying TPU vector registers. This is possible if the transformation preserves the data within the vector registers. If this is not possible, `einshape` will fall back to a general implementation that will likely involve vector register relayouts. As an example, for the equation `a(bc)->bac`, the first step which involves splitting a(bc)->abc is *not* tile preserving (it changes the sublane dimension from a to b), but after the transpose to `bac` it is tile preserving. Therefore the overall `einshape` operation is tile preserving. Not currently supported: - Expand dimensions: `a->1a` - Squeeze dimensions: `a1->a` """ import collections from collections.abc import Sequence import dataclasses import functools import math from typing import Literal, NamedTuple from jax._src import api from jax._src import core as jax_core from jax._src import dispatch from jax._src import typing as jax_typing from jax._src import hijax from jax._src.frozen_dict import FrozenDict from jax._src.interpreters import mlir from jax._src.lax import lax from jax._src.pallas.mosaic import lowering as tpu_lowering from jax._src.pallas.mosaic import tpu_info import jax.numpy as jnp import numpy as np @dataclasses.dataclass(frozen=True) class SplitDims: index: int sizes: tuple[int, ...] def transform_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return (*shape[: self.index], *self.sizes, *shape[self.index + 1 :]) @dataclasses.dataclass(frozen=True) class MergeDims: index: int count: int def transform_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return ( *shape[: self.index], math.prod(shape[self.index : self.index + self.count]), *shape[self.index + self.count :], ) @dataclasses.dataclass(frozen=True) class Transpose: permutation: tuple[int, ...] def transform_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: return tuple(shape[i] for i in self.permutation) # TODO(sharadmv): unify this with other Pallas Transforms Transform = SplitDims | MergeDims | Transpose def _parse_side(s: str) -> list[list[str]]: """Parses one side of an einshape equation into groups of named dimensions. Groups are indicated by parentheses. Dimensions outside of parentheses are treated as groups of size 1. For example: "a(bc)d" -> [['a'], ['b', 'c'], ['d']] "(ab)c" -> [['a', 'b'], ['c']] Args: s: One side of an einshape equation string. Returns: A list of lists of characters, where each inner list represents a group of dimensions. """ # Remove spaces s = s.replace(" ", "") groups = [] i = 0 while i < len(s): if s[i] == "(": # Start of a group j = s.find(")", i) if j == -1: raise ValueError(f"Unmatched parenthesis in {s!r}") group = list(s[i + 1 : j]) groups.append(group) i = j + 1 elif s[i] == ")": raise ValueError(f"Unmatched parenthesis in {s!r}") else: # distinct dimension groups.append([s[i]]) i += 1 return groups def _parse_equation(equation: str) -> tuple[list[list[str]], list[list[str]]]: """Parses an einshape equation.""" if equation.count("->") != 1: raise ValueError("Equation must contain exactly one '->'") lhs_str, rhs_str = equation.split("->") return _parse_side(lhs_str), _parse_side(rhs_str) def get_einshape_transforms( equation: str, input_shape: tuple[int, ...], **sizes: int, ) -> list[Transform]: """Parses an einshape equation into a sequence of transforms. Args: equation: String of the form "ab(cd)->cabd". input_shape: The shape of the input array. **sizes: Integer sizes for dimensions that are split and cannot be inferred. Returns: A list of Split, Transpose, and Merge transforms. """ lhs, rhs = _parse_equation(equation) # Validate LHS against input shape if len(lhs) != len(input_shape): raise ValueError( f"Equation LHS has {len(lhs)} groups but input has {len(input_shape)}" f" dims. LHS: {lhs}, Input shape: {input_shape}" ) dim_sizes: dict[str, int] = {} # Populate known sizes from input for i, group in enumerate(lhs): shape_val = input_shape[i] if len(group) == 1: name = group[0] if name in dim_sizes and dim_sizes[name] != shape_val: raise ValueError( f"Inconsistent size for {name}: {dim_sizes[name]} vs {shape_val}" ) dim_sizes[name] = shape_val else: # We have a merged dimension on LHS, need to split known_product = 1 unknown_dims = [] for name in group: if name in sizes: dim_sizes[name] = sizes[name] known_product *= sizes[name] elif name in dim_sizes: known_product *= dim_sizes[name] else: unknown_dims.append(name) if not unknown_dims: if known_product != shape_val: raise ValueError( f"Size mismatch for group {group}: expected {shape_val}, got" f" {known_product}" ) elif len(unknown_dims) == 1: if shape_val % known_product != 0: raise ValueError( f"Cannot split size {shape_val} with known sizes {known_product}" ) inferred_size = shape_val // known_product dim_sizes[unknown_dims[0]] = inferred_size else: raise ValueError( f"Ambiguous split for {group} with size {shape_val}. Unknowns:" f" {unknown_dims}. Provide sizes via kwargs." ) # Check if all RHS dims are known flat_rhs = [name for group in rhs for name in group] for name in flat_rhs: if name not in dim_sizes: if name in sizes: dim_sizes[name] = sizes[name] else: raise ValueError(f"Unknown dimension {name} in RHS") ops: list[Transform] = [] # 1. Decompose LHS current_idx = 0 for group in lhs: if len(group) > 1: atomic_sizes = tuple(dim_sizes[name] for name in group) ops.append(SplitDims(current_idx, atomic_sizes)) current_idx += len(group) else: current_idx += 1 # 2. Transpose lhs_atomic_order = [name for group in lhs for name in group] rhs_atomic_order = [name for group in rhs for name in group] if set(lhs_atomic_order) != set(rhs_atomic_order): raise NotImplementedError( "Only reordering/splitting/merging supported (no broadcast yet)." ) if lhs_atomic_order != rhs_atomic_order: perm = tuple(lhs_atomic_order.index(name) for name in rhs_atomic_order) ops.append(Transpose(perm)) # 3. Compose RHS current_idx = 0 for group in rhs: if len(group) > 1: ops.append(MergeDims(current_idx, len(group))) current_idx += 1 else: current_idx += 1 return ops def _einshape( equation: str, value: jax_typing.Array, **sizes: int, ) -> jax_typing.Array: """Reshapes and transposes an array according to an einshape equation. Args: equation: String of the form "ab(cd)->cabd". Parentheses indicate grouping of dimensions. On the LHS, grouped dimensions are split. On the RHS, dimensions are merged. value: The array to reshape. **sizes: Integer sizes for dimensions that are split and cannot be inferred. Returns: The reshaped and transposed array. """ transforms = get_einshape_transforms(equation, value.shape, **sizes) for transform in transforms: match transform: case SplitDims(_, _): new_shape = transform.transform_shape(value.shape) value = value.reshape(new_shape) case MergeDims(_, _): new_shape = transform.transform_shape(value.shape) value = value.reshape(new_shape) case Transpose(permutation): value = lax.transpose(value, permutation) return value einshape_lo_p = jax_core.Primitive("einshape_lo") def einshape_lo( equation: str, x: jax_typing.Array, assert_is_tile_preserving: bool, **sizes: int ) -> jax_typing.Array: return einshape_lo_p.bind( x, equation=equation, sizes=tuple(sizes.items()), assert_is_tile_preserving=assert_is_tile_preserving, ) @einshape_lo_p.def_abstract_eval def _einshape_lo_abstract_eval( x_aval: jax_core.ShapedArray, *, equation: str, sizes: tuple[tuple[str, int], ...], assert_is_tile_preserving: bool, ): del assert_is_tile_preserving out_sds = api.eval_shape( functools.partial(_einshape, equation, **dict(sizes)), x_aval ) return x_aval.update(shape=out_sds.shape, dtype=out_sds.dtype) def _einshape_lo_lowering( ctx: mlir.LoweringRuleContext, x, *, equation: str, sizes: tuple[tuple[str, int], ...], assert_is_tile_preserving: bool, ): del assert_is_tile_preserving def f(x): return _einshape(equation, x, **dict(sizes)) return mlir.lower_fun(f, multiple_results=False)(ctx, x) mlir.register_lowering(einshape_lo_p, _einshape_lo_lowering) dispatch.simple_impl(einshape_lo_p) class Einshape(hijax.VJPHiPrimitive): """Einshape primitive.""" def __init__( self, x_aval: jax_core.ShapedArray, *, equation: str, assert_is_tile_preserving: bool, sizes: dict[str, int], ): self.in_avals = (x_aval,) out_type = api.eval_shape( functools.partial(_einshape, equation, **sizes), x_aval ) self.out_aval = hijax.ShapedArray(out_type.shape, out_type.dtype) self.equation = equation self.sizes = sizes self.assert_is_tile_preserving = assert_is_tile_preserving self.params = dict( x_aval=x_aval, equation=equation, sizes=FrozenDict(sizes), assert_is_tile_preserving=assert_is_tile_preserving, ) super().__init__() def expand(self, x: jax_typing.Array) -> jax_typing.Array: # pyrefly: ignore[bad-override] return einshape_lo( self.equation, x, assert_is_tile_preserving=self.assert_is_tile_preserving, **self.sizes, ) def einshape( equation: str, x: jax_typing.Array, assert_is_tile_preserving: bool = False, **sizes: int, ) -> jax_typing.Array: """Reshapes and transposes an array according to an einshape equation. Args: equation: A string defining the transformation, e.g., "ab(cd)->cabd". - Names (e.g., 'a', 'b') represent dimensions. - Parentheses on the LHS, like `(cd)`, indicate a dimension that will be split into dimensions `c` and `d`. - Parentheses on the RHS, like `(ab)`, indicate dimensions `a` and `b` that will be merged. x: The input jax_typing.Array to transform. assert_is_tile_preserving: If True, assert that the transformation is tile preserving. Note that this check only applies inside of Pallas kernels. **sizes: Dimension sizes that cannot be inferred from the input shape. Required when splitting dimensions unless all but one sub-dimension size is known. Returns: The transformed jax_typing.Array. Examples: >>> import jax.numpy as jnp >>> x = jnp.zeros((10, 20)) >>> # Split the second dimension (20) into (4, 5) >>> y = einshape("a(bc)->abc", x, b=4) >>> y.shape (10, 4, 5) >>> # Transpose and merge the first two dimensions. >>> z = einshape("abc->(ba)c", y) >>> z.shape (40, 5) """ return Einshape( jax_core.typeof(x), equation=equation, sizes=sizes, assert_is_tile_preserving=assert_is_tile_preserving, )(x) def _default_einshape_kernel(equation: str, x: jax_typing.Array, **sizes: int): return _einshape(equation, x, **sizes) class Factor(NamedTuple): size: int kind: Literal["outer", "sublane", "lane"] def _array_to_2d_tile_array( x: jax_typing.Array, tiling: tuple[int, ...] ) -> np.ndarray: t1, t2 = tiling[-2:] tiled_shape = tuple(x.shape[i] // tiling[i] for i in range(len(x.shape))) # Allocate an empty object array to ensure Numpy doesn't coerce JAX tracers tiles = np.empty(tiled_shape, dtype=object) for idx in np.ndindex(*tiled_shape): *leading, i1, i2 = idx slices = tuple(leading) + ( slice(i1 * t1, (i1 + 1) * t1), slice(i2 * t2, (i2 + 1) * t2), ) # Standard Integer indexing inherently drops the outer dims -> returns strict 2D array tiles[idx] = x[slices] return tiles def _2d_tile_array_to_array(tiles: np.ndarray) -> jax_typing.Array: raw_arrays = np.empty(tiles.shape, dtype=object) for idx in np.ndindex(*tiles.shape): raw_arrays[idx] = tiles[idx] return jnp.block(raw_arrays.tolist()) def _consolidate(factors: list[Factor]) -> list[Factor]: """Merges contiguous 'outer' factors to allow valid arbitrary outer-dimension reshapes.""" res: list[Factor] = [] for f in factors: if f.kind == "outer" and res and res[-1].kind == "outer": res[-1] = Factor(res[-1].size * f.size, "outer") else: res.append(f) return res def _init_dims(shape: tuple[int, ...], t1: int, t2: int) -> list[list[Factor]]: dims: list[list[Factor]] = [] for i, s in enumerate(shape): if i == len(shape) - 2: kind, t_size = "sublane", t1 elif i == len(shape) - 1: kind, t_size = "lane", t2 else: kind, t_size = "outer", 1 current_dim = [] assert s % t_size == 0 if s // t_size > 1: current_dim.append(Factor(s // t_size, "outer")) if t_size > 1 or kind != "outer": current_dim.append(Factor(t_size, kind)) dims.append(_consolidate(current_dim)) return dims def _apply_split( factors: list[Factor], targets: tuple[int, ...] ) -> list[list[Factor]] | None: factors = _consolidate(factors) queue = collections.deque(factors) result = [] for i, needed in enumerate(targets): new_dim = [] current_size = 1 # Consume factors iteratively until the required shape volume is met while current_size < needed: if not queue: return None b = queue.popleft() # Case A: Perfect match or consume smaller outer block if needed % (current_size * b.size) == 0: new_dim.append(b) current_size *= b.size # Case B: Split a larger block (only allowed over logical outer limits) elif (current_size * b.size) % needed == 0: if b.kind != "outer": return None # Illegal splitting of hardware tile limit take = needed // current_size new_dim.append(Factor(take, "outer")) queue.appendleft(Factor(b.size // take, "outer")) current_size *= take else: return None # Sweep any trailing physical size-1 markers exactly into the right-most split dimension if i == len(targets) - 1: while queue and queue[0].size == 1: new_dim.append(queue.popleft()) result.append(_consolidate(new_dim)) if queue: return None return result def _tile_preserving_einshape_kernel( equation: str, x: jax_typing.Array, **size_vars: int ): tiling = tpu_info.infer_tiling(jax_core.typeof(x)) assert tiling is not None t1, t2 = tiling[-2:] assert isinstance(t1, int) assert isinstance(t2, int) dims = _init_dims(x.shape, t1, t2) tiles = _array_to_2d_tile_array(x, tiling) # pyrefly: ignore[bad-argument-type] transforms = get_einshape_transforms(equation, x.shape, **size_vars) def get_outer_shape(dims_list: list[list[Factor]]) -> tuple[int, ...]: return tuple( math.prod([f.size for f in d if f.kind == "outer"]) for d in dims_list ) for t in transforms: match t: case Transpose(permutation): tiles = np.transpose(tiles, permutation) dims = [dims[i] for i in permutation] case SplitDims(index, sizes): new_dims = _apply_split(dims[index], sizes) assert ( new_dims is not None ), "Tile preserving check passed but split failed." dims = dims[:index] + new_dims + dims[index + 1 :] tiles = tiles.reshape(get_outer_shape(dims)) case MergeDims(index, count): merged = [b for d in dims[index : index + count] for b in d] dims = dims[:index] + [_consolidate(merged)] + dims[index + count :] tiles = tiles.reshape(get_outer_shape(dims)) return _2d_tile_array_to_array(tiles) def _is_tile_preserving( shape: tuple[int, ...], transforms: Sequence[Transform], tiling: tuple[int, int] | None = None, ) -> bool: if not tiling or len(shape) < 2: return False t1, t2 = tiling if shape[-2] % t1 != 0 or shape[-1] % t2 != 0: return False dims = _init_dims(shape, t1, t2) for t in transforms: match t: case SplitDims(index, sizes): if (new_dims := _apply_split(dims[index], sizes)) is None: return False dims[index : index + 1] = new_dims case MergeDims(index, count): merged = [b for d in dims[index : index + count] for b in d] dims[index : index + count] = [_consolidate(merged)] case Transpose(permutation): dims = [dims[i] for i in permutation] if len(dims) < 2: return False # Check that the last two dimensions are tiled along (sublane, lane). y_dim = dims[-2] if not y_dim or y_dim[-1] != Factor(t1, "sublane"): return False x_dim = dims[-1] if not x_dim or x_dim[-1] != Factor(t2, "lane"): return False return True def _einshape_kernel( equation: str, x: jax_typing.Array, assert_is_tile_preserving: bool, **size_vars: int, ): transforms = get_einshape_transforms(equation, x.shape, **dict(size_vars)) if len(transforms) <= 1: return _default_einshape_kernel(equation, x, **size_vars) tiling = tpu_info.infer_tiling(jax_core.ShapedArray(x.shape, x.dtype)) if _is_tile_preserving(x.shape, transforms, tiling[-2:]): # pyrefly: ignore[bad-argument-type] return _tile_preserving_einshape_kernel(equation, x, **size_vars) elif assert_is_tile_preserving: raise ValueError( "Tile preserving check failed for einshape kernel with equation:" f" {equation} and shape {x.shape} and tiling {tiling}." ) return _default_einshape_kernel(equation, x, **size_vars) @tpu_lowering.register_lowering_rule(einshape_lo_p) def _einshape_lo_lowering_rule( ctx: tpu_lowering.LoweringRuleContext, x, *, equation: str, sizes: tuple[tuple[str, int], ...], assert_is_tile_preserving: bool, ): return tpu_lowering.lower_fun( lambda x: _einshape_kernel( equation, x, assert_is_tile_preserving=assert_is_tile_preserving, **dict(sizes), ), )(ctx, x)