# Copyright 2026 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""APIs for defining MPMD kernels in Pallas."""

from __future__ import annotations

from collections.abc import Callable, Mapping, Sequence
import contextlib
import functools
from typing import cast, Any, ParamSpec, TypeVar

from jax._src import api
from jax._src import api_util
from jax._src import config
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.frozen_dict import FrozenDict
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call

_P = ParamSpec("_P")
_T = TypeVar("_T")

# The primitive backing ``mpmd_map``. It carries one (mesh, jaxpr,
# grid_mapping) triple per mesh/function pair in its params.
mpmd_map_p = jax_core.Primitive("mpmd_map")
mpmd_map_p.multiple_results = True


@mpmd_map_p.def_impl
def _mpmd_map_impl(*args, **params):
  """Eager implementation: re-binds the primitive under ``jit``.

  ``config.disable_jit(False)`` forces jitting even when jit is globally
  disabled, so the lowering rules below are always used.
  """
  jit_impl = api.jit(functools.partial(mpmd_map_p.bind, **params))
  with config.disable_jit(False):
    return jit_impl(*args)


@mpmd_map_p.def_effectful_abstract_eval
def _mpmd_map_abstract_eval(
    *in_avals,
    jaxprs,
    out_avals,
    input_output_aliases,
    interpret,
    compiler_params,
    **params,
):
  """Abstract evaluation rule: validates kernels and computes output avals.

  Returns ``(out_avals, effects)``. Aliased outputs take the aval of the
  input they alias; any remaining Pallas memory-space annotations are
  stripped from the returned avals.
  """
  del params  # Unused.
  effs = {*pallas_core.get_interpret_effects(interpret)}
  if getattr(compiler_params, "has_side_effects", False):
    # TODO(slebedev): Fix internal breakages and add
    # ``jax_core.GenericEffect(pallas_call_p)`` here.
    effs = jax_core.no_effects
  # Every kernel's inputs must already be Refs (state is passed by
  # reference, as for any Pallas kernel body).
  for jaxpr in jaxprs:
    if not all(isinstance(aval, state.AbstractRef) for aval in jaxpr.in_avals):
      raise TypeError("MPMD kernels must only have Ref inputs")
  # TODO(slebedev): Handle pinned buffers as in ``pallas_call``.
  # Invert the in->out alias mapping so we can look up, per output index,
  # which input aval (if any) it should inherit.
  outin_aliases = {
      out_idx: in_idx for in_idx, out_idx in input_output_aliases.items()
  }
  out_avals = [
      in_avals[outin_aliases[out_idx]] if out_idx in outin_aliases else a
      for out_idx, a in enumerate(out_avals)
  ]
  # Make sure we don't return ShapedArray with pallas memory space to the
  # outside world.
  out_avals = tuple(
      a.update(memory_space=jax_core.MemorySpace.Device)
      if isinstance(a, jax_core.ShapedArray)
      else a
      for a in out_avals
  )
  return out_avals, effs


def _mpmd_map_typecheck_rule(ctx_factory, *in_atoms, meshes, **params):
  """Jaxpr typecheck rule: abstract eval under every mesh's axis env.

  All mesh axes are pushed onto the axis environment (via one ExitStack)
  before delegating to ``_mpmd_map_abstract_eval``.
  """
  del ctx_factory  # Unused.
  ctx = contextlib.ExitStack()
  for mesh in meshes:
    ctx.enter_context(jax_core.extend_axis_env_nd(mesh.shape.items()))
  # ``with ctx`` guarantees the axis-env extensions are popped on exit.
  with ctx:
    return _mpmd_map_abstract_eval(
        *(x.aval for x in in_atoms), meshes=meshes, **params
    )


jax_core.custom_typechecks[mpmd_map_p] = _mpmd_map_typecheck_rule


def _mpmd_map_tpu_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
    jaxprs,
    grid_mappings,
    meshes,
    input_output_aliases,
    debug,
    interpret,
    compiler_params,
    cost_estimate,
    out_avals,
    metadata,
    name,
):
  """TPU lowering: forwards to the Mosaic registration module.

  The Mosaic import is deferred so that non-TPU builds (where the import
  fails) raise a clear "unsupported lowering" error instead.
  """
  try:
    from jax._src.pallas.mosaic import pallas_call_registration
  except ImportError:
    raise pallas_call._unsupported_lowering_error("tpu")
  return pallas_call_registration.mpmd_map_tpu_lowering_rule(
      ctx,
      *in_nodes,
      jaxprs=jaxprs,
      grid_mappings=grid_mappings,
      meshes=meshes,
      input_output_aliases=input_output_aliases,
      debug=debug,
      interpret=interpret,
      compiler_params=compiler_params,
      cost_estimate=cost_estimate,
      out_avals=out_avals,
      metadata=metadata,
      name=name,
  )


def _mpmd_map_fallback_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
    meshes,
    jaxprs,
    grid_mappings,
    out_avals,
    input_output_aliases,
    compiler_params,
    interpret,
    debug,
    cost_estimate,
    metadata,
    name,
):
  """Non-TPU lowering: delegates a single mesh/kernel pair to ``pallas_call``.

  Only the single-mesh case is supported here; the multi-mesh (truly MPMD)
  case currently requires the TPU lowering path.
  """
  if len(jaxprs) != 1:
    raise NotImplementedError(
        "Lowering multiple mesh/function pairs is not currently supported"
    )
  [jaxpr] = jaxprs
  [mesh] = meshes
  [grid_mapping] = grid_mappings
  # Fold mesh-level attributes into the compiler params when the mesh type
  # carries them, so ``pallas_call`` lowering sees the same configuration.
  if hasattr(mesh, "dimension_semantics"):
    compiler_params = compiler_params.replace(
        dimension_semantics=mesh.dimension_semantics
    )
  if hasattr(mesh, "kernel_type"):
    compiler_params = compiler_params.replace(kernel_type=mesh.kernel_type)
  return pallas_call._pallas_call_lowering(
      ctx,
      *in_nodes,
      jaxpr=jaxpr,
      grid_mapping=grid_mapping,
      mesh=mesh,
      input_output_aliases=tuple(input_output_aliases.items()),
      debug=debug,
      interpret=interpret,
      compiler_params=compiler_params,
      cost_estimate=cost_estimate,
      out_avals=out_avals,
      metadata=metadata,
      name=name,
  )


@functools.partial(mlir.register_lowering, mpmd_map_p)
def _mpmd_map_lowering(ctx: mlir.LoweringRuleContext, *in_nodes, **params):
  """Platform dispatch: TPU gets the Mosaic rule, other platforms fall back."""
  return mlir.lower_per_platform(
      ctx,
      "mpmd_map",
      dict(
          cpu=_mpmd_map_fallback_lowering,
          tpu=_mpmd_map_tpu_lowering,
          cuda=_mpmd_map_fallback_lowering,
          rocm=_mpmd_map_fallback_lowering,
      ),
      None,  # default_rule
      jax_core.no_effects,
      *in_nodes,
      **params,
  )


def mpmd_map(
    meshes_and_fns: Sequence[tuple[pallas_core.Mesh, Callable[_P, _T]]],
    /,
    out_shapes: tree_util.PyTree,
    *,
    scratch_shapes: pallas_core.ScratchShapeTree = (),
    compiler_params: Any | None = None,
    interpret: bool | Any = False,
    debug: bool = False,
    cost_estimate: pallas_core.CostEstimate | None = None,
    name: str | None = None,
    metadata: dict[str, str] | None = None,
) -> Callable[_P, _T]:
  """Public entry point: ``_mpmd_map`` with no input/output aliasing.

  Args:
    meshes_and_fns: One ``(mesh, kernel_fn)`` pair per program; each kernel
      runs over its mesh's grid.
    out_shapes: PyTree of output shape/dtype descriptors.
    scratch_shapes: PyTree of scratch allocations passed to every kernel.
    compiler_params: Backend compiler parameters, forwarded to lowering.
    interpret: Run in interpreter mode (or an interpreter config object).
    debug: Forwarded to the ``pallas_call`` lowering machinery.
    cost_estimate: Optional cost estimate forwarded to lowering.
    name: Optional kernel name, used for debug info and a named scope.
    metadata: Optional string metadata attached to the call.

  Returns:
    A callable with the kernels' input signature returning ``out_shapes``.
  """
  return _mpmd_map(
      meshes_and_fns,
      out_shapes,
      input_output_aliases={},
      scratch_shapes=scratch_shapes,
      compiler_params=compiler_params,
      interpret=interpret,
      debug=debug,
      cost_estimate=cost_estimate,
      name=name,
      metadata=metadata,
  )


def _mpmd_map(
    meshes_and_fns: Sequence[tuple[pallas_core.Mesh, Callable[_P, _T]]],
    /,
    out_shapes: tree_util.PyTree,
    *,
    input_output_aliases: Mapping[int, int] = {},
    scratch_shapes: pallas_core.ScratchShapeTree = (),
    compiler_params: Any | None = None,
    interpret: bool | Any = False,
    debug: bool = False,
    cost_estimate: pallas_core.CostEstimate | None = None,
    name: str | None = None,
    metadata: dict[str, str] | None = None,
) -> Callable[_P, _T]:
  """Like ``pallas_call``, but MPMD and without pipelining."""
  if not meshes_and_fns:
    raise ValueError("At least one mesh/function pair is required")
  # Outputs can be flattened once, outside the returned wrapper: they do not
  # depend on the call-time arguments.
  flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(
      out_shapes
  )
  out_paths, flat_out_shapes = util.unzip2(flat_out_shapes_with_paths)
  flat_out_avals = tuple(
      map(pallas_core._convert_out_shape_to_aval, flat_out_shapes)
  )
  out_origins = tuple(f"outputs{tree_util.keystr(p)}" for p in out_paths)

  @functools.partial(api.jit, inline=True)
  def wrapper(*args):
    """Traces each kernel to a jaxpr and binds ``mpmd_map_p``."""
    flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
    in_paths, flat_args = util.unzip2(flat_args_with_paths)
    flat_avals = tuple(map(jax_core.typeof, flat_args))
    in_origins = tuple(f"args{tree_util.keystr(p)}" for p in in_paths)
    # NOTE: ``grid_mapping`` are only needed for us to reuse the ``pallas_call``
    # lowering machinery.
    meshes = []
    jaxprs = []
    grid_mappings = []
    flat_scratch_shapes, scratch_tree = tree_util.tree_flatten(scratch_shapes)
    if len(meshes_and_fns) > 1:
      # TODO(rdyro): For MPMD with more than one mesh, come up with a better
      # solution for how to enforce core_type presence in scratch_shape.
      # TODO(rdyro): Check if we need to have a similar check for in-kernel
      # allocations (e.g., run_scoped, empty_ref) or can we assume the
      # core_type is inherited from the caller (we then need the core_type in
      # the caller context during tracing).
      # TODO(rdyro): Also check inputs and outputs for core type.
      for scratch_shape in flat_scratch_shapes:
        from jax._src.pallas.mosaic import core as tpu_core

        if (
            not isinstance(scratch_shape.memory_space, tpu_core.CoreMemorySpace)
            and scratch_shape.memory_space
            not in (tpu_core.MemorySpace.HBM, tpu_core.MemorySpace.VMEM_SHARED)
        ):
          raise NotImplementedError(
              "MPMD map with more than one mesh requires scratch_shape to have"
              f" a `core_type` specified, but {scratch_shape=} is missing it."
          )
    # Check that meshes are compatible with each other (e.g, have a consistent
    # core axis name in the sparsecore).
    for i, (mesh, _) in enumerate(meshes_and_fns):
      for other_mesh, _ in list(meshes_and_fns)[i + 1 :]:
        mesh.check_is_compatible_with(other_mesh)
    # Union of all mesh axes; each kernel is traced under this combined
    # axis environment.
    super_mesh_shape = {}
    for mesh, _ in meshes_and_fns:
      for k, v in mesh.shape.items():
        # An extra check since `check_is_compatible_with` should catch it.
        assert k not in super_mesh_shape or super_mesh_shape[k] == v, (
            f"Conflicting size for axis {k}"
        )
        super_mesh_shape[k] = v
    # Trace each kernel separately: one GridSpec/GridMapping/jaxpr per mesh.
    for mesh, fn in meshes_and_fns:
      grid_spec = pallas_core.GridSpec(
          grid=tuple(mesh.shape.items()),  # pyrefly: ignore[bad-argument-type]
          # Keep an explicitly-annotated (non-default) Pallas memory space if
          # the aval carries one; otherwise use the mesh's default.
          in_specs=in_tree.unflatten(
              pallas_core.BlockSpec(
                  memory_space=aval.memory_space
                  if isinstance(aval, jax_core.ShapedArray)
                  and not isinstance(aval.memory_space, jax_core.MemorySpace)
                  else mesh.default_memory_space,
              )
              for aval in flat_avals
          ),
          out_specs=out_tree.unflatten(
              pallas_core.BlockSpec(
                  memory_space=aval.memory_space
                  if isinstance(aval, jax_core.ShapedArray)
                  and not isinstance(aval.memory_space, jax_core.MemorySpace)
                  else mesh.default_memory_space,
              )
              for aval in flat_out_avals
          ),
          scratch_shapes=flat_scratch_shapes,
      )
      kernel_args, grid_mapping = pallas_core.get_grid_mapping(
          grid_spec,
          flat_avals,
          in_tree,
          in_origins,
          flat_out_avals,
          out_tree,
          out_origins,
      )
      # ``get_grid_mapping`` returns scratch avals last; split them off and
      # re-nest so dict-shaped scratch trees are passed as kwargs.
      kernel_args, scratch_args = util.split_list(
          kernel_args, [len(kernel_args) - scratch_tree.num_leaves]
      )
      scratch_args = scratch_tree.unflatten(scratch_args)
      if isinstance(scratch_args, dict):
        kernel_args_kwargs = (kernel_args, scratch_args)
      else:
        kernel_args_kwargs = (kernel_args + list(scratch_args), {})
      flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(
          kernel_args_kwargs
      )
      debug_info = api_util.debug_info(
          "mpmd_map",
          fn,
          *kernel_args_kwargs,
      )
      if name is not None:
        debug_info = debug_info.replace_func_name(name)
      flat_fun, out_tree_thunk = api_util.flatten_fun(
          lu.wrap_init(fn, debug_info=debug_info), kernel_in_tree
      )
      # Trace under the combined axis env so kernels may reference any
      # mesh axis name; VMA checking is disabled during kernel tracing.
      with (
          jax_core.extend_axis_env_nd(super_mesh_shape.items()),
          config._check_vma(False),
      ):
        jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_kernel_avals)
      # Kernels communicate through Refs only: they must return None and
      # must not close over constants.
      fun_out_tree = out_tree_thunk()
      if fun_out_tree != tree_util.tree_structure(None):
        raise ValueError(
            f"The kernel function in mpmd_map {debug_info.func_src_info}"
            f" should return None. It returns a PyTree: {fun_out_tree}."
        )
      if consts:
        raise NotImplementedError("MPMD kernels cannot close over constants")
      meshes.append(mesh)
      jaxprs.append(jaxpr)
      grid_mappings.append(grid_mapping)
    # TODO(slebedev): The named scope should not be necessary here.
    ctx = (
        api.named_scope(name) if name is not None else contextlib.nullcontext()
    )
    with ctx:
      flat_outs = mpmd_map_p.bind(
          *flat_args,
          meshes=tuple(meshes),
          jaxprs=tuple(jaxprs),
          grid_mappings=tuple(grid_mappings),
          out_avals=flat_out_avals,
          # Params must be hashable, hence FrozenDict.
          input_output_aliases=FrozenDict(input_output_aliases),
          compiler_params=compiler_params,
          interpret=interpret,
          debug=debug,
          cost_estimate=cost_estimate,
          metadata=FrozenDict(metadata) if metadata is not None else None,
          name=name,
      )
    return out_tree.unflatten(flat_outs)

  return cast(Callable[_P, _T], wrapper)