Files
2026-05-06 19:47:31 +07:00

2499 lines
106 KiB
Python

# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Sequence, Iterable
import contextlib
from dataclasses import dataclass, replace
from functools import partial
import inspect
import itertools as it
import weakref
from typing import NamedTuple, Any, Union
import warnings
import numpy as np
from jax._src import api
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import profiler
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.core import typeof, cur_qdd
from jax._src.api_util import (
flatten_axes, donation_vector, check_callable, resolve_argnums, debug_info,
check_no_aliased_ref_args, _check_no_aliased_closed_over_refs,
flatten_axis_resources)
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import remat
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_client as xc
from jax._src.mesh import AbstractMesh
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
NamedSharding, GSPMDSharding,
make_single_device_sharding, AUTO, UNSPECIFIED, UnspecifiedValue,
prepare_axis_resources, parse_flatten_op_sharding, canonicalize_sharding,
_internal_use_concrete_mesh)
from jax._src.layout import Format, Layout, AutoLayout, get_layout_for_vmap
from jax._src.state.types import RefEffect
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import (
tree_flatten, tree_unflatten, tree_structure, treedef_children,
PyTreeDef, none_leaf_registry as none_lr, tree_map, FlatTree)
from jax._src.typing import Array, ArrayLike
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log,
split_list, weakref_lru_cache, merge_lists, subs_list, fun_name)
from jax._src.lib import jax_jit
# From here on, `map`/`zip` in this module are the jax._src.util "safe"
# variants (presumably length-checking); the builtins remain available as
# `unsafe_map` / `unsafe_zip`.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# Register this file with traceback_util's exclusion list (presumably so its
# frames are filtered out of user-facing tracebacks).
traceback_util.register_exclusion(__file__)
# Sharding types accepted in jit/pjit parameters; the `*MinusUnspecified`
# variants are the same unions without the UNSPECIFIED sentinel type.
PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTO]
PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTO]
MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTO]
MeshShardingMinusUnspecified = Union[NamedSharding, AUTO]
class PjitInfo(NamedTuple):
  """Things that we know about a jit instance before it is called.

  In other words, this structure contains arguments to jit()/pjit(),
  preprocessed and validated (by `_parse_jit_arguments`).

  Hashing and equality are by object identity, so every PjitInfo instance is
  a distinct cache key even if another instance has equal fields.
  """
  fun_sourceinfo: str
  fun_signature: inspect.Signature | None
  # Shardings, as specified by the user. These can either be UNSPECIFIED or they
  # can be a tree (prefix) of shardings or None. They are stored below as
  # flattened (treedef, leaves) pairs.
  # True iff in_shardings was given as something other than None/UNSPECIFIED.
  user_specified_in_shardings: bool
  in_shardings_treedef: PyTreeDef
  in_shardings_leaves: tuple[Any, ...]
  out_shardings_treedef: PyTreeDef
  out_shardings_leaves: tuple[Any, ...]
  # Layouts taken from `Format` entries of in/out_shardings (None for every
  # other entry), flattened the same way as the shardings.
  in_layouts_treedef: PyTreeDef
  in_layouts_leaves: tuple[Any, ...]
  out_layouts_treedef: PyTreeDef
  out_layouts_leaves: tuple[Any, ...]
  # Static/donated argument numbers and names as returned by
  # `api_util.resolve_argnums`.
  static_argnums: tuple[int, ...]
  static_argnames: tuple[str, ...]
  donate_argnums: tuple[int, ...]
  donate_argnames: tuple[str, ...]
  # Deprecated jit options; `_parse_jit_arguments` rejects setting both.
  device: xc.Device | None
  backend: str | None
  keep_unused: bool
  inline: bool
  use_resource_env: bool  # False for jit, True for pjit
  # `compiler_options.items()` as a tuple; () when compiler_options is None.
  compiler_options_kvs: tuple[tuple[str, Any], ...]

  # Hash and compare PjitInfo by identity when used as a cache key.
  def __hash__(self):
    return id(self)

  def __eq__(self, other):
    return self is other
def _qualname_for_errors(fun) -> str:
  """Returns a human-readable name for `fun` for use in error messages."""
  return getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))


def _run_python_pjit(p, args_flat, fun: Callable, args, kwargs):
  """Runs one jit call on the Python (non C++ fast path) dispatch path.

  Args:
    p: the PjitParams for this call, from `_infer_params`.
    args_flat: ``p.consts + dynargs``, i.e. the separated jaxpr consts
      followed by the flattened dynamic arguments.
    fun: the user function (used for error messages and NaN re-checking).
    args, kwargs: the original call arguments (used for NaN re-checking).

  Returns:
    ``(outs, out_flat, out_tree, args_flat, jaxpr, compiled, profiler,
    const_args)``. When the call goes through ``jit_p.bind`` (i.e. not the
    clean-trace-state path), ``compiled`` and ``profiler`` are None and
    ``const_args`` is empty.

  Raises:
    ValueError: on a device assignment mismatch among the arguments.
    TypeError: if an argument (or constant) is not a valid JAX type.
    FloatingPointError / RuntimeError: from the NaN/Inf debugging checks.
  """
  for arg in args_flat:
    dispatch.check_arg(arg)
  try:
    if (core.trace_state_clean() and not config.debug_key_reuse.value
        and not p.params['jaxpr'].jaxpr.is_high):
      # No active trace: compile and execute via the Python implementation.
      args_flat = map(core.full_lower, args_flat)
      core.check_eval_args(args_flat)
      out_flat, compiled, profiler, const_args = _pjit_call_impl_python(
          *args_flat, **p.params)
    else:
      out_flat = jit_p.bind(*args_flat, **p.params)
      compiled = None
      profiler = None
      const_args = []
  except stages.DeviceAssignmentMismatchError as e:
    fails, = e.args
    arg_types = map(convert_to_metaty, args_flat)
    msg = stages._device_assignment_mismatch_error(
        _qualname_for_errors(fun), fails, arg_types, 'jit', p.arg_names)
    raise ValueError(msg) from None
  except dtypes.InvalidInputException as e:
    if p.params['jaxpr'].consts:
      # The bad value may be a closed-over constant we cannot name.
      raise TypeError(e.args[0]) from e
    # `args_flat` is `p.consts + dynargs`, while `p.arg_names` and
    # `p.in_avals` exclude the consts. Only the dynamic tail lines up with
    # them; zipping the whole list would misalign names with values (and
    # trip safe_zip), masking the real error.
    dyn_args = args_flat[len(p.consts):]
    arg_names = [''] * len(dyn_args) if p.arg_names is None else p.arg_names
    # Run canonicalization again to figure out which arg failed.
    for arg, name, aval in zip(dyn_args, arg_names, p.in_avals):
      try:
        dtypes.canonicalize_value(arg)
      except dtypes.InvalidInputException:
        # Reraise as TypeError with the new message.
        raise TypeError(
            f"Argument '{name}' of shape {aval.str_short()} of type"
            f' {type(arg)} is not a valid JAX type.') from e
    if p.consts:
      # No dynamic argument failed, so the bad value is a separated const.
      raise TypeError(e.args[0]) from e
    raise AssertionError("Unreachable") from e
  except api_util.InternalFloatingPointError as e:
    if getattr(fun, '_apply_primitive', False):
      raise FloatingPointError(
          f"invalid value ({e.ty}) encountered in {_qualname_for_errors(fun)}"
      ) from None
    api_util.maybe_recursive_nan_check(e, fun, args, kwargs)  # should always raise.
    raise RuntimeError("Internal error") from e  # fall-back error to be safe.
  outs = tree_unflatten(p.out_tree, out_flat)
  return (outs, out_flat, p.out_tree, args_flat,
          p.params['jaxpr'], compiled, profiler, const_args)
def _need_to_rebuild_with_fdo(pgle_profiler):
return (pgle_profiler is not None and pgle_profiler.is_enabled()
and not pgle_profiler.is_fdo_consumed())
def _get_fastpath_data(
    executable, out_tree, args_flat, out_flat, effects, consts_for_constvars,
    pgle_profiler, const_args: Sequence[ArrayLike]
) -> pxla.MeshExecutableFastpathData | None:
  """Builds the data the C++ dispatcher needs to reuse `executable`.

  Returns None (no fast path) unless all of these hold:
  - `executable` is a MeshExecutable running via ExecuteReplicated;
  - it has no ordered or unordered effects;
  - `effects` contains no RefEffect;
  - no FDO rebuild is pending;
  - `config.no_execution` is off;
  - every reflattened output is an `xc.ArrayImpl`.

  `consts_for_constvars` is accepted for call-site compatibility but unused.
  """
  if executable is None or not isinstance(executable, pxla.MeshExecutable):
    return None
  if not isinstance(executable.unsafe_call, pxla.ExecuteReplicated):
    return None
  # No effects in computation
  if (executable.unsafe_call.ordered_effects or
      executable.unsafe_call.has_unordered_effects):
    return None
  # no ref state effects
  if any(isinstance(eff, RefEffect) for eff in effects):
    return None
  if _need_to_rebuild_with_fdo(pgle_profiler) or config.no_execution.value:
    return None

  flat_outs, dispatch_tree = pxla.reflatten_outputs_for_dispatch(
      out_tree, out_flat)
  if any(not isinstance(o, xc.ArrayImpl) for o in flat_outs):
    return None
  out_avals = [o.aval for o in flat_outs]
  out_committed = [o._committed for o in flat_outs]

  num_inputs = len(const_args) + len(args_flat)
  kept_var_bitvec = [idx in executable._kept_var_idx
                     for idx in range(num_inputs)]

  # Extended-dtype inputs are described by their physical sharding.
  in_shardings = []
  for sharding, aval in zip(executable._in_shardings, executable.in_avals):
    is_extended = (aval is not core.abstract_token and
                   dtypes.issubdtype(aval.dtype, dtypes.extended))
    in_shardings.append(sharding_impls.physical_sharding(aval, sharding)
                        if is_extended else sharding)

  return pxla.MeshExecutableFastpathData(
      executable.xla_executable, dispatch_tree, in_shardings,
      executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
      executable._dispatch_in_layouts, const_args)
# The entries are doubled here from the default 4096 because _pjit_call_impl
# also has a cpp dispatch path and that would double the number of entries in
# the global shared cache.
# This cache is only used for jits that are given nothing but the function.
# For example: jax.jit(f)
_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192)
# This cache is used for jit where extra arguments are defined other than the
# fun. For example: jax.jit(f, donate_argnums=...) OR
# jax.jit(f, out_shardings=...), etc. We don't use the same cache because the
# capacity might fill up very quickly with all the jitted functions inside JAX,
# which might evict, for example, a user's train_step.
_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192)
def _get_cpp_global_cache(contains_explicit_attributes: bool):
  """Picks the shared C++ pjit cache for a jitted function.

  Functions created with extra jit options (donation, shardings, ...) use a
  separate cache from plain ``jax.jit(f)`` functions.
  """
  return (_cpp_pjit_cache_explicit_attributes if contains_explicit_attributes
          else _cpp_pjit_cache_fun_only)
def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
  """Builds the C++-dispatched jitted wrapper for `fun`.

  The C++ object is given `cache_miss` as its fallback (presumably invoked
  when it has no fast-path entry for a call). `cache_miss` infers params,
  runs the call in Python, and returns the outputs together with fast-path
  data (or None) and whether an FDO rebuild is pending.

  Returns:
    The wrapper, with `_fun` and `_jit_info` attributes set and
    `clear_cache`/`lower`/`trace`/`eval_shape` assigned on its type.
  """
  @api_boundary
  def cache_miss(*args, **kwargs):
    # args do not include the const args
    # See https://docs.jax.dev/en/latest/internals/constants.html.
    if config.no_tracing.value:
      raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
                         "`jit`, but 'no_tracing' is set")

    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
    (outs, out_flat, out_tree, args_flat, jaxpr,
     executable, pgle_profiler, const_args) = _run_python_pjit(
         p, args_flat, fun, args, kwargs)

    maybe_fastpath_data = _get_fastpath_data(
        executable, out_tree, args_flat, out_flat, jaxpr.effects, jaxpr.consts,
        pgle_profiler, const_args)

    # (outputs, fast-path data or None, whether an FDO rebuild is pending)
    return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)

  # All jit options for this instance; `contains_explicit_attributes` also
  # selects which global C++ cache is used.
  cache_key = pxla.JitGlobalCppCacheKeys(
      donate_argnums=jit_info.donate_argnums,
      donate_argnames=jit_info.donate_argnames,
      device=jit_info.device, backend=jit_info.backend,
      in_shardings_treedef=jit_info.in_shardings_treedef,
      in_shardings_leaves=jit_info.in_shardings_leaves,
      out_shardings_treedef=jit_info.out_shardings_treedef,
      out_shardings_leaves=jit_info.out_shardings_leaves,
      in_layouts_treedef=jit_info.in_layouts_treedef,
      in_layouts_leaves=jit_info.in_layouts_leaves,
      out_layouts_treedef=jit_info.out_layouts_treedef,
      out_layouts_leaves=jit_info.out_layouts_leaves,
      compiler_options_kvs=jit_info.compiler_options_kvs)

  cpp_pjit_f = xc._xla.pjit(
      fun_name(fun), fun, cache_miss, jit_info.static_argnums,
      jit_info.static_argnames, cache_key, tree_util.dispatch_registry,
      pxla.cc_shard_arg,
      _get_cpp_global_cache(cache_key.contains_explicit_attributes))

  cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
  # Read back by jit_trace / jit_evict_fn.
  cpp_pjitted_f._fun = fun  # pyrefly: ignore[missing-attribute]
  cpp_pjitted_f._jit_info = jit_info  # pyrefly: ignore[missing-attribute]
  # Assigned on the wrapper's type rather than the instance, so they apply to
  # every object of that type; re-assigned on each _cpp_pjit call.
  cpp_jitted_f_class = type(cpp_pjitted_f)
  cpp_jitted_f_class.clear_cache = jit_evict_fn
  cpp_jitted_f_class.lower = jit_lower
  cpp_jitted_f_class.trace = jit_trace
  cpp_jitted_f_class.eval_shape = jit_eval_shape
  return cpp_pjitted_f
@api_boundary
def jit_trace(jit_func, *args, **kwargs) -> stages.Traced:
  """Implements ``.trace`` on jitted functions.

  Infers the jit params for this call and wraps them in a `stages.Traced`.
  """
  pjit_params, flat_args = _infer_params(
      jit_func._fun, jit_func._jit_info, args, kwargs)
  return stages.Traced(
      map(convert_to_metaty, flat_args), pjit_params.params,
      pjit_params.in_tree, pjit_params.out_tree, pjit_params.consts)
@api_boundary
def jit_lower(jit_func, *args, **kwargs):
  """Implements ``.lower`` on jitted functions: trace, then lower."""
  traced = jit_trace(jit_func, *args, **kwargs)
  return traced.lower()
@api_boundary
def jit_eval_shape(jit_func, *args, **kwargs):
  """Implements ``.eval_shape`` on jitted functions via the traced out_info."""
  traced = jit_trace(jit_func, *args, **kwargs)
  return traced.out_info
def jit_evict_fn(self):
  """Implements ``.clear_cache`` on jitted functions.

  Clears the wrapper's own C++ cache and evicts `self._fun` from the
  `pe.trace_to_jaxpr` cache. Note that `_infer_params_cached` is cleared
  entirely, i.e. for every jitted function, not just this one.
  """
  self._clear_cache()
  pe.trace_to_jaxpr.evict_weakref(self._fun)
  _infer_params_cached.cache_clear()
def _split_layout_and_sharding(entries):
  """Splits `Format` entries into parallel layout and sharding trees.

  Every leaf of `entries` (None counts as a leaf) becomes a (layout,
  sharding) pair: a `Format` contributes its `.layout` and `.sharding`, and
  any other value contributes `(None, value)`.

  Returns:
    ``(layouts_tree, shardings_tree)``, both shaped like `entries`.

  Raises:
    ValueError: if a leaf is a bare `Layout` or `AutoLayout`.
  """
  leaves, treedef = tree_flatten(entries, is_leaf=lambda x: x is None)

  def split_leaf(leaf):
    if isinstance(leaf, Format):
      return leaf.layout, leaf.sharding
    if isinstance(leaf, (Layout, AutoLayout)):
      raise ValueError(
          '`jax.jit` does not accept device-local layouts directly. Create '
          'a `Format` instance wrapping this device-local layout and pass '
          f'that to `jit` instead. Got {leaf}')
    return None, leaf

  pairs = [split_leaf(leaf) for leaf in leaves]
  layouts = [layout for layout, _ in pairs]
  shardings = [sharding for _, sharding in pairs]
  return tree_unflatten(treedef, layouts), tree_unflatten(treedef, shardings)
def _parse_jit_arguments(fun: Callable, *, in_shardings: Any,
                         out_shardings: Any,
                         static_argnums: int | Sequence[int] | None,
                         static_argnames: str | Iterable[str] | None,
                         donate_argnums: int | Sequence[int] | None,
                         donate_argnames: str | Iterable[str] | None,
                         keep_unused: bool, device: xc.Device | None,
                         backend: str | None, inline: bool,
                         compiler_options: dict[str, Any] | None,
                         use_resource_env: bool) -> PjitInfo:
  """Parses the arguments to jit/pjit.

  Performs any preprocessing and validation of the arguments that we can do
  ahead of time before the jit()-ed function is invoked.

  Returns:
    A PjitInfo holding:
    - the flattened shardings and layouts;
    - the resolved static/donated argnums and argnames;
    - the remaining options.

  Raises:
    ValueError: in any of these cases:
      - `device` and `backend` are both given;
      - either is given together with in_shardings or out_shardings other
        than None/UNSPECIFIED;
      - a bare `Layout`/`AutoLayout` is passed instead of a `Format`.
  """
  check_callable(fun)

  if backend is not None or device is not None:
    # Deprecated options: warn, then reject combining them with each other
    # or with explicit shardings.
    warnings.warn(
        'backend and device argument on jit is deprecated. You can use'
        ' `jax.device_put(..., jax.local_devices(backend="cpu")[0])` on the'
        ' inputs to the jitted function to get the same behavior.',
        DeprecationWarning,
    )
    if device is not None and backend is not None:
      raise ValueError("can't specify both a device and a backend for jit, "
                       f"got {device=} and {backend=}")
    if in_shardings is not None and not isinstance(in_shardings, UnspecifiedValue):
      raise ValueError('If backend or device is specified on jit, then '
                       'in_shardings should not be specified.')
    if out_shardings is not None and not isinstance(out_shardings, UnspecifiedValue):
      raise ValueError('If backend or device is specified on jit, then '
                       'out_shardings should not be specified.')

  if isinstance(in_shardings, list):
    # To be a tree prefix of the positional args tuple, in_shardings can never
    # be a list: if in_shardings is not a leaf, it must be a tuple of trees.
    # However, in cases like these users expect tuples and lists to be treated
    # essentially interchangeably, so we canonicalize lists to tuples here
    # rather than raising an error. https://github.com/jax-ml/jax/issues/2367
    in_shardings = tuple(in_shardings)

  # Peel layouts out of any `Format` entries; everything else is a sharding.
  in_layouts, in_shardings = _split_layout_and_sharding(in_shardings)
  out_layouts, out_shardings = _split_layout_and_sharding(out_shardings)

  in_shardings = prepare_axis_resources(in_shardings, 'in_shardings')
  out_shardings = prepare_axis_resources(out_shardings, 'out_shardings',
                                         allow_unconstrained_dims=True)

  user_specified_in_shardings = (in_shardings is not None and
                                 not isinstance(in_shardings, UnspecifiedValue))

  # Flatten to (leaves, treedef) with `none_leaf_registry`, which presumably
  # keeps None entries as leaves instead of dropping them.
  in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings)
  out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings)
  in_layouts_leaves, in_layouts_treedef = none_lr.flatten(in_layouts)
  out_layouts_leaves, out_layouts_treedef = none_lr.flatten(out_layouts)

  fun_sourceinfo = api_util.fun_sourceinfo(fun)
  fun_signature = api_util.fun_signature(fun)

  donate_argnums, donate_argnames, static_argnums, static_argnames = resolve_argnums(
      fun, fun_signature, donate_argnums, donate_argnames, static_argnums,
      static_argnames)

  compiler_options_kvs = (() if compiler_options is None else
                          tuple(compiler_options.items()))
  return PjitInfo(
      fun_sourceinfo=fun_sourceinfo,
      fun_signature=fun_signature,
      user_specified_in_shardings=user_specified_in_shardings,
      in_shardings_treedef=in_shardings_treedef,
      in_shardings_leaves=tuple(in_shardings_leaves),
      out_shardings_treedef=out_shardings_treedef,
      out_shardings_leaves=tuple(out_shardings_leaves),
      in_layouts_treedef=in_layouts_treedef,
      in_layouts_leaves=tuple(in_layouts_leaves),
      out_layouts_treedef=out_layouts_treedef,
      out_layouts_leaves=tuple(out_layouts_leaves),
      static_argnums=static_argnums,
      static_argnames=static_argnames, donate_argnums=donate_argnums,
      donate_argnames=donate_argnames, device=device, backend=backend,
      keep_unused=keep_unused, inline=inline,
      use_resource_env=use_resource_env,
      compiler_options_kvs=compiler_options_kvs)
def make_jit(fun: Callable,
             *,
             in_shardings: Any,
             out_shardings: Any,
             static_argnums: int | Sequence[int] | None,
             static_argnames: str | Iterable[str] | None,
             donate_argnums: int | Sequence[int] | None,
             donate_argnames: str | Iterable[str] | None,
             keep_unused: bool,
             device: xc.Device | None,
             backend: str | None,
             inline: bool,
             compiler_options: dict[str, Any] | None,
             use_resource_env: bool) -> Any:
  """jit() and pjit() are thin wrappers around this function.

  The options are parsed and validated once into a `PjitInfo`, which is then
  used to build the C++-dispatched callable for `fun`.
  """
  info = _parse_jit_arguments(
      fun,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      static_argnums=static_argnums,
      static_argnames=static_argnames,
      donate_argnums=donate_argnums,
      donate_argnames=donate_argnames,
      keep_unused=keep_unused,
      device=device,
      backend=backend,
      inline=inline,
      compiler_options=compiler_options,
      use_resource_env=use_resource_env,
  )
  return _cpp_pjit(fun, info)
class PjitParams(NamedTuple):
  """Result of tracing one jit call (built by `_trace_for_jit`)."""
  # Only jaxpr constants, we can't keep other arguments alive. These go as
  # first arguments for `params['jaxpr']`.
  consts: list[ArrayLike]  # Corresponding to jaxpr.constvars
  # Everything we need to trace, lower, and compile the jit function; passed
  # to `_pjit_call_impl_python` (or `jit_p.bind`), along with the `args_flat`
  params: dict[str, Any]
  in_avals: tuple[core.AbstractValue, ...]  # Not including the const_args
  in_tree: PyTreeDef  # Not including the const_args
  out_tree: PyTreeDef
  arg_names: tuple[str, ...]  # Not including the const_args
def _trace_for_jit(
    fun: Callable, ji: PjitInfo, ctx_mesh: mesh_lib.Mesh,
    dbg: core.DebugInfo, avals, args, kwargs) -> PjitParams:
  """Traces `fun` to a jaxpr and assembles the jit_p params for the call.

  Args:
    fun: the user function.
    ji: the preprocessed jit options.
    ctx_mesh: the mesh from `get_ctx_mesh`, used to resolve None and
      PartitionSpec shardings.
    dbg: debug info for `fun` and this call's arguments.
    avals: abstract values of the flattened dynamic arguments.
    args, kwargs: the original call arguments.

  Returns:
    A PjitParams. The per-input entries of `params` start with `len(consts)`
    entries for the separated constants.

  Raises:
    ValueError: if kwargs are combined with in_shardings, or if a context
      mesh is combined with `device`/`backend`.
  """
  args_ft = FlatTree.flatten_static_argnums_argnames(
      args, kwargs, ji.static_argnums, ji.static_argnames)
  # Presumably the same tree as args_ft with `avals` as its leaves.
  avals_ft = args_ft.update(avals)
  has_kwargs = bool(kwargs)

  if has_kwargs and ji.user_specified_in_shardings:
    raise ValueError(
        "pjit does not support kwargs when in_shardings is specified.")

  if not ctx_mesh.empty and (ji.backend or ji.device):
    raise ValueError(
        "Mesh context manager should not be used with jit when backend or "
        "device is also specified as an argument to jit.")

  # Donation is turned off entirely while debug_nans is enabled.
  if (ji.donate_argnums or ji.donate_argnames) and not config.debug_nans.value:
    donated_invars = donation_vector(ji.donate_argnums, ji.donate_argnames,
                                     avals_ft.tree)
  else:
    donated_invars = (False,) * len(avals_ft)

  # If backend or device is set as an arg on jit, then resolve them to
  # in_shardings and out_shardings as if user passed in in_shardings
  # and out_shardings.
  device_or_backend_set = bool(ji.backend or ji.device)
  if device_or_backend_set:
    sharding = _create_sharding_with_device_backend(ji.device, ji.backend)
    leaves, treedef = tree_flatten(sharding)
    in_shardings_leaves = out_shardings_leaves = tuple(leaves)
    in_shardings_treedef = out_shardings_treedef = treedef
  else:
    # Resolve None / PartitionSpec leaves against the context mesh.
    api_name = 'pjit' if ji.use_resource_env else 'jit'
    in_shardings_leaves = tuple(
        _create_sharding_for_array(ctx_mesh, x, 'in_shardings', api_name)
        for x in ji.in_shardings_leaves)
    out_shardings_leaves = tuple(
        _create_sharding_for_array(ctx_mesh, x, 'out_shardings', api_name)
        for x in ji.out_shardings_leaves)
    in_shardings_treedef = ji.in_shardings_treedef
    out_shardings_treedef = ji.out_shardings_treedef

  assert None not in in_shardings_leaves
  assert None not in out_shardings_leaves

  # Tracing input types: avals that carry QDD are paired with the current
  # QDD of the corresponding argument.
  in_type = avals_ft.map2(
      lambda a, x: core.AvalQDD(a, cur_qdd(x)) if a.has_qdd else a,
      args_ft)

  # NOTE(review): avals_ft has already been used above, so this assert can
  # no longer catch anything useful; consider moving or dropping it.
  assert avals_ft is not None
  in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
      in_shardings_treedef, in_shardings_leaves,
      ji.in_layouts_treedef, ji.in_layouts_leaves,
      avals_ft, dbg, device_or_backend_set, has_kwargs)

  qdd_token = _qdd_cache_index(fun, in_type.vals)  # represents qdd state context

  # Log tracing time only when core.trace_state_clean() holds.
  elapsed_time_ctx = (
      dispatch.log_elapsed_time(
          "Finished tracing {fun_name} for jit in {elapsed_time:.9f} sec",
          fun_name(fun), event=dispatch.JAXPR_TRACE_EVENT)
      if core.trace_state_clean() else contextlib.nullcontext())
  with elapsed_time_ctx:
    if ji.use_resource_env:  # pjit
      with (_internal_use_concrete_mesh(ctx_mesh),
            mesh_lib.use_abstract_mesh(ctx_mesh.abstract_mesh)):
        jaxpr, out_avals = pe.trace_to_jaxpr(fun, in_type, dbg, qdd_token)
    else:
      jaxpr, out_avals = pe.trace_to_jaxpr(fun, in_type, dbg, qdd_token)

  if config.debug_key_reuse.value:
    # Import here to avoid circular imports
    from jax.experimental.key_reuse._core import check_key_reuse_jaxpr  # pytype: disable=import-error
    check_key_reuse_jaxpr(jaxpr.jaxpr)

  # Record output key paths as result names in the jaxpr's debug info.
  result_paths = tuple(f"result{lu._clean_keystr_arg_names(path)}"
                       for path in out_avals.paths)
  jaxpr.jaxpr._debug_info = jaxpr.debug_info._replace(result_paths=result_paths)

  # TODO(mattjj,yashkatariya): if we take the 'true' path then we *must* fall
  # off the C++ dispatch fast path for correctness. Ensure that happens.
  # Consts that are tracers or carry QDD are split out of the closed jaxpr;
  # they become leading explicit arguments (see the padding below).
  if any(isinstance(c, core.Tracer) or core.typeof(c).has_qdd for c in jaxpr.consts):
    jaxpr, consts = pe.separate_consts(jaxpr)
  else:
    consts = []

  if config.mutable_array_checks.value:
    _check_no_aliased_closed_over_refs(dbg, (*jaxpr.consts, *consts), args_ft.vals)
  _qdd_cache_update(fun, in_type.vals, qdd_token, consts,
                    jaxpr.in_aval_qdds[:len(consts)])

  out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
      out_shardings_treedef, out_shardings_leaves, ji.out_layouts_treedef,
      ji.out_layouts_leaves, out_avals.tree,
      tuple(out_avals), jaxpr.jaxpr._debug_info, device_or_backend_set)

  assert len(args_ft.vals) == len(in_shardings_flat) == len(in_layouts_flat)

  # Pad the per-input params for the separated consts, which come first.
  num_extra_args = len(consts)
  in_shardings_flat = (UNSPECIFIED,) * num_extra_args + in_shardings_flat
  in_layouts_flat = (None,) * num_extra_args + in_layouts_flat
  donated_invars = (False,) * num_extra_args + donated_invars
  assert (len(in_shardings_flat) == len(in_layouts_flat) ==
          len(donated_invars) == len(consts) + len(avals_ft))

  params = dict(
      jaxpr=jaxpr,
      in_shardings=in_shardings_flat,
      out_shardings=out_shardings_flat,
      in_layouts=in_layouts_flat,
      out_layouts=out_layouts_flat,
      donated_invars=donated_invars,
      ctx_mesh=ctx_mesh,
      name=fun_name(fun),
      keep_unused=ji.keep_unused,
      inline=ji.inline,
      compiler_options_kvs=ji.compiler_options_kvs,
  )
  return PjitParams(consts, params, avals_ft.vals, avals_ft.tree_without_statics,
                    out_avals.tree, dbg.safe_arg_names(len(avals_ft)))
@dataclass(slots=True)
class InferParamsCacheEntry:
  """Mutable memo slot returned by `_infer_params_cached`.

  `_infer_params` stores the traced PjitParams here after tracing, unless the
  traced jaxpr is `is_high`; None means "not traced yet".
  """
  pjit_params: PjitParams | None = None
@weakref_lru_cache
def _infer_params_cached(
    fun: Callable, jit_info: PjitInfo, signature: jax_jit.ArgumentSignature,
    in_avals: tuple[core.AbstractValue, ...], ctx_mesh: mesh_lib.Mesh
) -> InferParamsCacheEntry:
  """Returns the memo slot for one (fun, jit options, call signature) key.

  The body only allocates an empty entry. The memoization comes from the
  `weakref_lru_cache` decorator, which presumably returns the same entry
  object for repeated calls with equal arguments. `_infer_params` fills the
  entry in place, and `jit_evict_fn` clears this whole cache.
  """
  return InferParamsCacheEntry()
def get_ctx_mesh(use_resource_env):
  """Returns the mesh jit/pjit should resolve shardings against.

  pjit (`use_resource_env=True`) uses the physical mesh of
  `mesh_lib.thread_resources.env`. jit picks the first of these that applies:
  - the concrete mesh, if it is non-empty;
  - the abstract mesh, if it is non-empty and has an Explicit axis;
  - the (empty) concrete mesh.
  """
  if use_resource_env:
    return mesh_lib.thread_resources.env.physical_mesh
  concrete = mesh_lib.get_concrete_mesh()
  if concrete.empty:
    abstract = mesh_lib.get_abstract_mesh()
    # TODO(yashkatariya): Make top-level use_abstract_mesh work with Auto mode
    # too. But there are failures in user code so restricting it to Explicit
    # mode for now.
    if not abstract.empty and abstract._any_axis_explicit:
      return abstract
  return concrete
def _infer_params(
    fun: Callable, ji: PjitInfo, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[PjitParams, list[core.Value]]:
  """Returns the PjitParams for this call and the flat arguments to bind.

  The flat arguments are ``consts + dynargs``: the separated jaxpr constants
  followed by the flattened dynamic arguments. Traces are memoized through
  `_infer_params_cached`, except when the traced jaxpr is `is_high`.
  """
  ctx_mesh = get_ctx_mesh(ji.use_resource_env)

  def make_debug_info():
    return debug_info(
        'jit', fun, args, kwargs, static_argnums=ji.static_argnums,
        static_argnames=ji.static_argnames, sourceinfo=ji.fun_sourceinfo,
        signature=ji.fun_signature)

  arg_signature, dynargs = jax_jit.parse_arguments(
      args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
      ji.static_argnames, tree_util.tracing_registry)
  avals = _infer_input_type(fun, make_debug_info, dynargs)

  entry = _infer_params_cached(fun, ji, arg_signature, avals, ctx_mesh)
  params = entry.pjit_params
  if params is None:
    params = _trace_for_jit(
        fun, ji, ctx_mesh, make_debug_info(), avals, args, kwargs)
    # Results for jaxprs that are `is_high` are not memoized.
    if not params.params['jaxpr'].jaxpr.is_high:
      entry.pjit_params = params
  return params, params.consts + dynargs
def _infer_input_type(fun: Callable, dbg_fn: Callable[[], core.DebugInfo],
                      explicit_args) -> tuple[core.AbstractValue, ...]:
  """Abstracts each flattened dynamic argument into an aval.

  `dbg_fn` is only called when abstraction fails, to name the offending
  argument in the error message.

  Raises:
    OverflowError: if abstracting an argument overflows.
    TypeError: if an argument cannot be interpreted as an abstract array.
  """
  def failing_arg_name(idx):
    names = dbg_fn().arg_names
    return names[idx] if names is not None else 'unknown'

  avals = []
  idx, value = -1, None
  try:
    for idx, value in enumerate(explicit_args):
      avals.append(core.shaped_abstractify(value))
  except OverflowError:
    arg_path = f"argument path is {failing_arg_name(idx)}"
    raise OverflowError(
        "An overflow was encountered while parsing an argument to a jitted "
        f"computation, whose {arg_path}. Got {type(value)} with value {value}"
    ) from None
  except TypeError:
    arg_description = f"path {failing_arg_name(idx)}"
    raise TypeError(
        f"Error interpreting argument to {fun} as an abstract array."
        f" The problematic value is of type {type(value)} and was passed to"
        f" the function at {arg_description}.\n"
        "This typically means that a jit-wrapped function was called with a non-array"
        " argument, and this argument was not marked as static using the"
        " static_argnums or static_argnames parameters of jax.jit."
    ) from None

  if config.mutable_array_checks.value:
    check_no_aliased_ref_args(dbg_fn, avals, explicit_args)
  return tuple(avals)
class JitWrapped(stages.Wrapped):
  """Return type annotation of `pjit`.

  NOTE(review): both methods here are stubs that raise NotImplementedError.
  The object `pjit` actually returns comes from `make_jit` -> `_cpp_pjit`,
  whose type gets `trace`/`eval_shape` assigned there. This class presumably
  exists only for static typing; confirm before relying on it at runtime.
  """

  def eval_shape(self, *args, **kwargs):
    """See ``jax.eval_shape``."""
    raise NotImplementedError

  def trace(self, *args, **kwargs) -> stages.Traced:
    """Traces the wrapped function for the given arguments (stub here)."""
    raise NotImplementedError
# in_shardings and out_shardings can't be None as the default value
# because `None` means that the input is fully replicated.
@partial(api_boundary, repro_api_name="pjit.pjit")
def pjit(
    fun: Callable,
    in_shardings: Any = UNSPECIFIED,
    out_shardings: Any = UNSPECIFIED,
    static_argnums: int | Sequence[int] | None = None,
    static_argnames: str | Iterable[str] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: xc.Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    compiler_options: dict[str, Any] | None = None,
) -> JitWrapped:
  """`jax.experimental.pjit.pjit` has been deprecated. Please use `jax.jit`.

  This is `make_jit` with ``use_resource_env=True``. Shardings are resolved
  against the mesh from ``mesh_lib.thread_resources``, and None shardings
  become fully replicated over that mesh when it is non-empty.
  """
  options = dict(
      in_shardings=in_shardings, out_shardings=out_shardings,
      static_argnums=static_argnums, static_argnames=static_argnames,
      donate_argnums=donate_argnums, donate_argnames=donate_argnames,
      keep_unused=keep_unused, device=device, backend=backend, inline=inline,
      compiler_options=compiler_options)
  return make_jit(fun, use_resource_env=True, **options)
def hashable_pytree(pytree):
  """Wraps `pytree` in a HashableFunction that rebuilds it when called.

  The (treedef, leaves) pair is passed as the HashableFunction closure,
  presumably so that hashing/equality go through it; the leaves are then
  expected to be hashable.
  """
  leaves, treedef = tree_flatten(pytree)
  frozen_leaves = tuple(leaves)

  def rebuild():
    return tree_unflatten(treedef, frozen_leaves)

  return HashableFunction(rebuild, closure=(treedef, frozen_leaves))
def _create_sharding_for_array(mesh, x, name, api_name):
  """Resolves one user-provided sharding leaf into a sharding object.

  Args:
    mesh: the context mesh (may be empty).
    x: None, AUTO, UNSPECIFIED, a `Sharding`, or a `PartitionSpec`.
    name: parameter name (e.g. 'in_shardings'), for the error message.
    api_name: 'jit' or 'pjit'.

  Returns:
    - For None: UNSPECIFIED under jit or with an empty mesh; otherwise a
      `PartitionSpec()` NamedSharding on `mesh`.
    - For AUTO/UNSPECIFIED/`Sharding`: `x` unchanged.
    - Otherwise: a NamedSharding of `x` on `mesh`.

  Raises:
    RuntimeError: if a PartitionSpec is given while `mesh` is empty.
  """
  if x is None:
    leave_unspecified = api_name == 'jit' or mesh.empty
    return (UNSPECIFIED if leave_unspecified
            else sharding_impls.cached_named_sharding(mesh, PartitionSpec()))
  if isinstance(x, (AUTO, UnspecifiedValue, Sharding)):
    return x
  if not mesh.empty:
    assert isinstance(x, PartitionSpec), x
    return sharding_impls.cached_named_sharding(mesh, x)
  raise RuntimeError(
      f'{api_name} requires a non-empty mesh in context if you are passing'
      f' `PartitionSpec`s to {name}. You can define a context mesh via'
      ' `jax.set_mesh(mesh)`. Alternatively, provide `Sharding`s to'
      f' {name} and then the mesh context manager is not required.')
def _create_sharding_with_device_backend(device, backend):
  """Returns a single-device sharding for jit's deprecated device/backend.

  Exactly one of `device` / `backend` must be set. For `backend`, the
  backend's first local device is used. The result is tagged with
  ``_device_backend = True``.
  """
  if device is None and backend is None:
    raise AssertionError('Unreachable!')
  if device is not None:
    assert backend is None
    target_device = device
  else:
    assert device is None
    target_device = xb.get_backend(backend).local_devices()[0]
  sharding = make_single_device_sharding(target_device)
  sharding._device_backend = True
  return sharding
@util.cache(max_size=4096, trace_context_in_key=False)
def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
                               in_layouts_treedef, in_layouts_leaves,
                               in_avals, dbg: core.DebugInfo,
                               device_or_backend_set, kws):
  """Expands user in_shardings/in_layouts to one entry per flat input.

  Args:
    in_shardings_treedef, in_shardings_leaves: flattened in_shardings.
    in_layouts_treedef, in_layouts_leaves: flattened in_layouts.
    in_avals: FlatTree of input avals; its `tree_without_statics` is the
      structure the shardings are matched against.
    dbg: debug info, used for argument names in error messages.
    device_or_backend_set: not read in this body (presumably kept so it is
      part of the cache key).
    kws: whether the call had keyword arguments.

  Returns:
    ``(in_shardings_flat, in_layouts_flat)``, each with one entry per
    element of `in_avals`.

  Raises:
    ValueError: from the sharding / layout compatibility checks.
  """
  if kws:
    in_tree = in_avals.tree_without_statics
  else:
    # Match against the first child only (presumably the positional-args
    # subtree of an (args, kwargs) tree).
    in_tree, _ = treedef_children(in_avals.tree_without_statics)

  orig_in_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves)
  # Only do this if original in_shardings are unspecified. If it is AUTO, go
  # via flatten_axis_resources.
  if isinstance(orig_in_shardings, UnspecifiedValue):
    in_shardings_flat = (orig_in_shardings,) * len(in_avals)
  else:
    in_shardings_flat = flatten_axis_resources(
        "pjit in_shardings", in_tree, orig_in_shardings, tupled_args=True)

  # A top-level None layout applies to every input.
  in_layouts = tree_unflatten(in_layouts_treedef, in_layouts_leaves)
  if in_layouts is None:
    in_layouts_flat = (in_layouts,) * len(in_avals)
  else:
    in_layouts_flat = flatten_axis_resources(
        "pjit in_layouts", in_tree, in_layouts, tupled_args=True)

  pjit_check_aval_sharding(in_shardings_flat, in_avals,
                           dbg.safe_arg_names(len(in_avals)),
                           "pjit arguments", allow_uneven_sharding=False)
  check_aval_layout_compatibility(
      in_layouts_flat, in_avals,
      dbg.safe_arg_names(len(in_avals)), "jit arguments")
  return in_shardings_flat, in_layouts_flat
@util.cache(max_size=4096, trace_context_in_key=False)
def _check_and_canonicalize_out_shardings(
    out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
    out_layouts_leaves, out_tree, out_avals,
    debug_info: core.DebugInfo,
    device_or_backend_set):
  """Expands user out_shardings/out_layouts to one entry per flat output.

  A bare UNSPECIFIED or `Sharding` is repeated for every output directly.
  Any other value is matched against `out_tree` with
  `flatten_axis_resources`.

  Args:
    out_shardings_treedef, out_shardings_leaves: flattened out_shardings.
    out_layouts_treedef, out_layouts_leaves: flattened out_layouts.
    out_tree: pytree structure of the function's outputs.
    out_avals: tuple of flat output avals.
    debug_info: provides result paths for error messages.
    device_or_backend_set: not read in this body (presumably kept so it is
      part of the cache key).

  Returns:
    ``(out_shardings_flat, out_layouts_flat)``, one entry per output aval.

  Raises:
    ValueError: from the sharding / layout compatibility checks.
  """
  orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
  if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)):
    out_shardings_flat = (orig_out_shardings,) * len(out_avals)
  else:
    out_shardings_flat = flatten_axis_resources(
        "pjit out_shardings", out_tree, orig_out_shardings,
        tupled_args=False)

  # A top-level None layout applies to every output.
  out_layouts = tree_unflatten(out_layouts_treedef, out_layouts_leaves)
  if out_layouts is None:
    out_layouts_flat = (out_layouts,) * len(out_avals)
  else:
    out_layouts_flat = flatten_axis_resources(
        "pjit out_layouts", out_tree, out_layouts, tupled_args=False)

  pjit_check_aval_sharding(
      out_shardings_flat, out_avals,
      debug_info.safe_result_paths(len(out_avals)),
      "pjit outputs", allow_uneven_sharding=False)
  check_aval_layout_compatibility(
      out_layouts_flat, out_avals,
      debug_info.safe_result_paths(len(out_avals)),
      "jit outputs")
  return out_shardings_flat, out_layouts_flat
_seen_qdds = weakref.WeakKeyDictionary()
def _seen_qdds_get(fun, in_type) -> list:
cache = _seen_qdds.setdefault(fun, defaultdict(list))
assert cache is not None # pyrefly#2407
return cache[in_type]
def _qdd_cache_index(fun, in_type) -> int:
  """Returns the index of the recorded QDD case matching the current state.

  A case matches when none of its (obj, qdd) pairs has
  ``core.cur_qdd(obj) != qdd``. If no case matches, returns ``len(cases)``,
  the index at which `_qdd_cache_update` would append a new case.
  """
  cases = _seen_qdds_get(fun, in_type)
  for idx, records in enumerate(cases):
    if not any(core.cur_qdd(obj) != qdd for obj, qdd in records):
      return idx
  return len(cases)
def _qdd_cache_update(fun, in_type, i, consts, aval_qdds):
  """Records a new QDD case for (fun, in_type) when `i` is a fresh index.

  `i` comes from `_qdd_cache_index`. Only when it equals the number of
  recorded cases is a new case appended: the (const, qdd) pair for each
  const whose aval has a QDD. Otherwise nothing changes.
  """
  cases = _seen_qdds_get(fun, in_type)
  if i != len(cases):
    return
  cases.append([(const, entry.qdd)
                for const, entry in zip(consts, aval_qdds) if entry.has_qdd])
@dataclass(frozen=True)
class IgnoreKey:
  """Wrapper whose payload never participates in hashing or equality.

  Instances of the same class hash alike, and any IgnoreKey compares equal
  to any other IgnoreKey. Wrapping a value presumably lets it travel inside
  a cache key without affecting lookups.
  """
  val: Any

  def __hash__(self):
    return hash(type(self))

  def __eq__(self, other):
    # Deliberately ignores `val` on both sides.
    return isinstance(other, IgnoreKey)
def pjit_check_aval_sharding(
    shardings, flat_avals, names: Sequence[str],
    what_aval: str, allow_uneven_sharding: bool):
  """Checks that each sharding is compatible with its aval's shape.

  UNSPECIFIED and AUTO shardings are skipped. When `allow_uneven_sharding`
  is False, `shard_shape` is also called, which checks divisibility.

  Raises:
    ValueError: naming the offending value (by pytree key path when a name
      is available) if `check_compatible_aval` fails.
  """
  for aval, sharding, name in zip(flat_avals, shardings, names):
    if isinstance(sharding, (UnspecifiedValue, AUTO)):
      continue
    shape = aval.shape
    try:
      sharding.check_compatible_aval(shape)
    except ValueError as err:
      where = f' with pytree key path {name}' if name else ''
      raise ValueError(
          f'One of {what_aval}{where} is incompatible with its sharding '
          f'annotation {sharding}: {err}')
    if not allow_uneven_sharding:
      sharding.shard_shape(aval.shape)  # will check for divisibility
def check_aval_layout_compatibility(
    layouts, flat_avals, names: Sequence[str], what_aval: str):
  """Checks that each explicit layout is compatible with its aval's shape.

  None and AutoLayout entries are skipped.

  Raises:
    ValueError: naming the offending value if `check_compatible_aval` fails.
  """
  for aval, layout, name in zip(flat_avals, layouts, names):
    if layout is None or isinstance(layout, AutoLayout):
      continue
    try:
      layout.check_compatible_aval(aval.shape)
    except ValueError as err:
      where = f' with pytree key path {name}' if name else ''
      raise ValueError(
          f'One of {what_aval}{where} is incompatible with its layout '
          f'annotation {layout}: {err}')
# -------------------- pjit rules --------------------

# The `jit` primitive; it is bound with the params built in `_trace_for_jit`.
jit_p = core.Primitive("jit")
# A jit call is effectful exactly when its inner jaxpr has effects.
jit_p.is_effectful = lambda params: bool(params['jaxpr'].effects)
jit_p.multiple_results = True
jit_p.skip_canonicalization = True

def _is_high(*_, jaxpr, **__) -> bool:
  """Returns the inner jaxpr's `is_high` flag (such calls use `to_lojax`)."""
  return jaxpr.jaxpr.is_high
jit_p.is_high = _is_high
def _to_lojax(*hi_args, jaxpr, **params):
  """Rewrite a jit application on hi values into one on lo values.

  Steps:
    1. Closed-over hi mutables are turned into explicit leading arguments
       (``pe.convert_const_himutables``). Params are padded accordingly.
    2. Each hi argument is lowered to a list of lo values. Arguments whose
       aval has a QDD use ``read_loval_in``; all others use ``lower_val``.
    3. The jaxpr is lowered against the lo input types, params are expanded to
       per-lo-value form, and ``jit_p`` is bound on the flat lo arguments.
    4. The outputs are split into mutable-state outputs and pure outputs.
       State is written back into the hi arguments whose final aval has a
       QDD, and pure outputs are raised back to hi values.

  Returns:
    The list of hi output values.
  """
  # convert closed-over boxes to explicit args
  jaxpr, closed_over_himutables = pe.convert_const_himutables(jaxpr)
  hi_args = [*closed_over_himutables, *hi_args]
  params = _converted_mutables_add_params(len(closed_over_himutables), **params)
  # One list of lo values per hi argument ("list of lists").
  lo_args_lol = [aval.read_loval_in(x) if aval.has_qdd else aval.lower_val(x)
                 for aval, x in zip(jaxpr.in_aval_qdds, hi_args)]
  lo_args = [x for xs in lo_args_lol for x in xs]
  # Keep the per-hi grouping of lo types so params can be expanded per group.
  in_avals = FlatTree.flatten(([[typeof(x) for x in xs] for xs in lo_args_lol], {}))
  lo_jaxpr, out_avals = pe.lower_jaxpr(jaxpr, in_avals)
  params = _lojax_expand_params(in_avals, out_avals, **params)
  all_outs = jit_p.bind(*lo_args, jaxpr=lo_jaxpr, **params)
  # Regroup flat outputs into (mutable-state outputs, pure outputs).
  out_mut, lo_outs = out_avals.update(all_outs).unpack()
  for a, x, u in zip(jaxpr.final_aval_qdds, hi_args, out_mut.unpack()):
    if a.has_qdd:
      # Write the updated lo state back into the hi argument.
      a.aval.update_from_loval2(a.qdd, x, u)
  return [a.raise_val2(y) for a, y in zip(jaxpr.out_avals, lo_outs.unpack())]
jit_p.to_lojax = _to_lojax
def _converted_mutables_add_params(
    n, *, donated_invars, in_shardings, in_layouts, **params):
  """Prepend params for ``n`` new leading (converted-mutable) inputs.

  The new inputs are not donated, have ``UNSPECIFIED`` shardings and ``None``
  layouts. All other params are passed through unchanged.
  """
  new_donated = (False,) * n + donated_invars
  new_shardings = (UNSPECIFIED,) * n + in_shardings
  new_layouts = (None,) * n + in_layouts
  return {**params,
          'donated_invars': new_donated,
          'in_shardings': new_shardings,
          'in_layouts': new_layouts}
def _lojax_expand_params(
    in_avals_, out_avals, donated_invars, in_shardings, in_layouts,
    out_shardings, out_layouts, **params):
  """Expand per-hi-value jit params into per-lo-value params.

  ``donated_invars``, ``in_shardings``, ``in_layouts``, ``out_shardings`` and
  ``out_layouts`` have one entry per hi-jaxpr input or output. After lowering,
  each hi value may correspond to several lo values, so each entry is repeated
  once per lo value in its group. The lowered jaxpr also has extra leading
  outputs for mutable hi types; those get ``UNSPECIFIED`` shardings and
  ``None`` layouts.

  Args:
    in_avals_: FlatTree of lo input types that unpacks to ``(groups, ())``.
    out_avals: FlatTree of lo output types that unpacks to
      ``(mutable_outputs, pure_output_groups)``.

  Returns:
    A new params dict containing the five expanded entries.
  """
  flat_in, () = in_avals_.unpack()
  in_groups = flat_in.unpack()
  mut_out_groups, pure_out = out_avals.unpack()
  out_groups = pure_out.unpack()

  def repeat_per_lo(groups, per_hi):
    # Emit each per-hi entry once for every lo value in its group.
    expanded = []
    for group, entry in zip(groups, per_hi):
      expanded.extend(entry for _ in group)
    return tuple(expanded)

  new_donated = repeat_per_lo(in_groups, donated_invars)
  new_in_shardings = repeat_per_lo(in_groups, in_shardings)
  new_in_layouts = repeat_per_lo(in_groups, in_layouts)
  new_out_shardings = repeat_per_lo(out_groups, out_shardings)
  new_out_layouts = repeat_per_lo(out_groups, out_layouts)
  num_mut_outs = len(mut_out_groups)  # it's a flat tree
  return dict(
      params,
      donated_invars=new_donated,
      in_shardings=new_in_shardings,
      in_layouts=new_in_layouts,
      out_shardings=(UNSPECIFIED,) * num_mut_outs + new_out_shardings,
      out_layouts=(None,) * num_mut_outs + new_out_layouts)
def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings,
                        in_avals) -> Sequence[Layout | AutoLayout | None]:
  """Resolve the input layouts used to lower a jit call.

  Args:
    args: per-argument metadata (e.g. ``MetaTy``) exposing ``committed``,
      ``format`` and ``aval``.
    jit_in_layouts: layouts requested on ``jax.jit``; ``None`` means no
      layout was requested for that input.
    resolved_in_shardings: input shardings as resolved by
      ``_resolve_in_shardings``.
    in_avals: abstract values of the jaxpr inputs.

  Returns:
    A tuple with one entry per argument. Default layouts are replaced by
    ``None`` to simplify downstream dispatch and lowering.

  Raises:
    ValueError: if a committed argument with a specified sharding has a
      concrete layout that differs from the layout requested on ``jax.jit``.
  """
  # If device or backend is set, return the default layout. This is because you
  # can pass arrays on cpu (with untiled layouts) to jit with backend='tpu'
  # which causes error checks to fail. Returning the default layout allows
  # this to exist. It's the same for handling shardings.
  if pxla.check_device_backend_on_shardings(resolved_in_shardings):
    return (None,) * len(jit_in_layouts)
  resolved_in_layouts: list[Layout | AutoLayout | None] = []
  for arg, jit_in_l, rs, aval in safe_zip(
      args, jit_in_layouts, resolved_in_shardings, in_avals):
    committed = arg.committed
    # `arg_layout` is only used for checking purposes in the `else` branch
    # below. We cannot replace default layout with None to raise nicer errors.
    # `dispatch_arg_layout` replaces default layouts with `None` to simplify
    # dispatch and lowering logic downstream.
    if arg.format is not None:
      arg_layout = arg.format.layout
      dispatch_arg_layout = (None if pxla.is_default_layout(arg_layout, rs, aval)
                             else arg_layout)
    else:
      arg_layout, dispatch_arg_layout = None, None
    if jit_in_l is None:
      # No layout requested on jit: reuse the argument's (non-default) layout
      # only for a committed argument with a specified sharding.
      if committed and not isinstance(rs, UnspecifiedValue):
        resolved_in_layouts.append(dispatch_arg_layout)
      else:
        resolved_in_layouts.append(None)
    else:
      # arg_layout can be None because some backends don't implement the
      # required layout methods. Hence `arr.format` can return
      # `Format(None, sharding)`
      if (committed
          and not isinstance(rs, UnspecifiedValue)
          and arg_layout is not None
          and not pxla.is_user_xla_layout_equal(jit_in_l, arg_layout)):
        extra_msg = ''
        if isinstance(jit_in_l, AutoLayout):
          extra_msg = (
              ' The layout given to `jax.jit` is `Layout.AUTO` but'
              ' the corresponding argument passed is a `jax.Array` with a'
              ' concrete layout. Consider passing a `jax.ShapeDtypeStruct`'
              ' instead of `jax.Array` as an argument to the jitted function'
              ' when using `Layout.AUTO`.'
          )
        raise ValueError('Layout passed to jit does not match the layout '
                          'on the respective arg. '
                          f'Got jit layout: {jit_in_l},\n'
                          f'arg layout: {arg_layout} for arg type: {arg.aval}.'
                          f'{extra_msg}')
      # Normalize a default concrete layout to None, like the argument path.
      jit_in_l = (None if isinstance(jit_in_l, Layout) and
                  pxla.is_default_layout(jit_in_l, rs, aval) else jit_in_l)
      resolved_in_layouts.append(jit_in_l)
  return tuple(resolved_in_layouts)
def _resolve_out_layouts(out_layouts, out_shardings, out_avals):
  """Normalize output layouts, mapping default concrete layouts to ``None``.

  An entry is kept unless it is ``None``, or it is a ``Layout`` that
  ``pxla.is_default_layout`` reports as the default for the corresponding
  sharding and aval. In those cases ``None`` is used instead.
  """
  def normalize(layout, sharding, aval):
    if layout is None:
      return None
    is_default = (isinstance(layout, Layout) and
                  pxla.is_default_layout(layout, sharding, aval))
    return None if is_default else layout

  return tuple(normalize(l, s, a)
               for l, s, a in safe_zip(out_layouts, out_shardings, out_avals))
def finalize_arg_sharding(arg_s, committed):
  """Choose the sharding to use for an argument whose jit sharding is unset.

  Unspecified or committed argument shardings are returned as-is. An
  uncommitted sharding on a single device becomes ``UNSPECIFIED``.

  Raises:
    NotImplementedError: for an uncommitted sharding spanning several devices.
  """
  if isinstance(arg_s, UnspecifiedValue) or committed:
    return arg_s
  assert isinstance(arg_s, Sharding)
  if arg_s.num_devices != 1:
    raise NotImplementedError('Having uncommitted Array sharded on '
                              'multiple devices is not supported.')
  return UNSPECIFIED
def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
                          ) -> Sequence[PjitSharding]:
  """Resolve the input shardings used to lower a jit call.

  For each argument, an ``UnspecifiedValue`` jit sharding is replaced by the
  argument's own sharding (via ``finalize_arg_sharding``). A specified jit
  sharding is validated against the argument and kept.

  Args:
    args: per-argument metadata (e.g. ``MetaTy``) exposing ``sharding``,
      ``committed``, ``is_np_array``, ``ndim`` and ``aval``.
    pjit_in_shardings: the shardings requested on ``jax.jit``, one per arg.

  Returns:
    A tuple of resolved shardings. If a device or backend is set on the
    shardings, ``pjit_in_shardings`` is returned unchanged instead.

  Raises:
    ValueError: in any of these cases:
      - a numpy input has a non-fully-replicated jit sharding and
        ``xb.process_count() > 1``;
      - the jit sharding's memory kind differs from that of a concrete
        argument sharding;
      - a committed argument's HLO sharding differs from the jit sharding.
  """
  # If True, means that device or backend is set by the user on pjit and it
  # has the same semantics as device_put i.e. doesn't matter which device the
  # arg is on, reshard it to the device mentioned. So don't do any of the
  # checks and just return the pjit_in_shardings directly. `shard_args` will
  # handle the resharding.
  if pxla.check_device_backend_on_shardings(pjit_in_shardings):
    return pjit_in_shardings
  resolved_in_shardings: list[PjitSharding] = []
  for arg, pjit_in_s in zip(args, pjit_in_shardings):
    # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
    # not allow None as the sharding.
    arg_s, committed = ((arg.sharding, arg.committed) if arg.sharding is not None
                        else (UNSPECIFIED, False))
    # A NamedSharding over an empty mesh is treated like no sharding at all.
    if isinstance(arg_s, NamedSharding) and arg_s.mesh.empty:
      arg_s, committed = UNSPECIFIED, False
    if isinstance(pjit_in_s, UnspecifiedValue):
      resolved_in_shardings.append(finalize_arg_sharding(arg_s, committed))
    else:
      if (arg.is_np_array and not pjit_in_s.is_fully_replicated and # pyrefly: ignore[missing-attribute]
          xb.process_count() > 1):
        raise ValueError(
            'Passing non-trivial shardings for numpy '
            'inputs is not allowed. To fix this error, either specify a '
            'replicated sharding explicitly or use '
            '`jax.make_array_from_process_local_data(...)` '
            'to convert your host local numpy inputs to a jax.Array which you '
            'can pass to jit. '
            'If the numpy input is the same on each process, then you can use '
            '`jax.make_array_from_callback(...) to create a `jax.Array` which '
            f'you can pass to jit. Got arg type: {arg.aval}')
      if not isinstance(arg_s, UnspecifiedValue) and arg_s._is_concrete:
        # jax.jit does not allow resharding across different memory kinds even
        # if the argument is uncommitted. Use jax.device_put for those cases,
        # either outside or inside jax.jit.
        if pjit_in_s.memory_kind != arg_s.memory_kind: # pyrefly: ignore[missing-attribute]
          raise ValueError(
              'Memory kinds passed to jax.jit does not match memory kind on the'
              f' respective arg. Got jit memory kind: {pjit_in_s.memory_kind}, '
              f'arg memory kind: {arg_s.memory_kind} for arg type: {arg.aval}')
        # Committed args must already be laid out exactly as jit requests;
        # equality is checked on the HLO shardings for the arg's rank.
        if (committed and
            not op_shardings.are_hlo_shardings_equal(
                pjit_in_s._to_xla_hlo_sharding(arg.ndim), # pyrefly: ignore[missing-attribute]
                arg_s._to_xla_hlo_sharding(arg.ndim))):
          raise ValueError('Sharding passed to jit does not match the sharding '
                           'on the respective arg. '
                           f'Got jit sharding: {pjit_in_s},\n'
                           f'arg sharding: {arg_s} for arg type: {arg.aval}')
      resolved_in_shardings.append(pjit_in_s)
  return tuple(resolved_in_shardings)
def _resolve_and_lower(
    args, jaxpr: core.ClosedJaxpr, in_shardings, out_shardings, in_layouts,
    out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline,
    lowering_platforms, lowering_parameters, pgle_profiler,
    compiler_options_kvs) -> pxla.MeshComputation:
  """Resolve shardings and layouts against ``args``, then lower ``jaxpr``.

  Input shardings are resolved first because input-layout resolution uses
  them. Output layouts are normalized from the (unresolved) output shardings.
  Lowering goes through the cached ``_pjit_lower``.
  """
  resolved_in_shardings = _resolve_in_shardings(args, in_shardings)
  resolved_in_layouts = _resolve_in_layouts(
      args, in_layouts, resolved_in_shardings, jaxpr.in_avals)
  resolved_out_layouts = _resolve_out_layouts(
      out_layouts, out_shardings, jaxpr.out_avals)
  return _pjit_lower(
      jaxpr, resolved_in_shardings, out_shardings, resolved_in_layouts,
      resolved_out_layouts, donated_invars, ctx_mesh, name, keep_unused,
      inline, compiler_options_kvs,
      lowering_platforms=lowering_platforms,
      lowering_parameters=lowering_parameters,
      pgle_profiler=pgle_profiler)
# PGLE profilers keyed by the jaxpr being compiled (see
# `_pjit_call_impl_python`). Because the keys are weak, an entry is dropped
# once its jaxpr is no longer referenced elsewhere.
_pgle_profiler_dict = weakref.WeakKeyDictionary()
@dataclass(frozen=True)
class MetaTy:
  """Immutable summary of a jit argument used for sharding/layout resolution.

  Fields:
    aval: the argument's abstract value.
    sharding: the argument's sharding, or None if unknown.
    format: the argument's format, or None if unknown.
    committed: whether the argument is committed to its sharding.
    is_np_array: whether the argument is a ``np.ndarray``.
  """
  aval: Any
  sharding: Any
  format: Any
  committed: bool
  is_np_array: bool

  # ``dataclasses.replace`` exposed as a method: ``meta.replace(field=value)``.
  replace = replace

  shape = property(lambda self: self.aval.shape,
                   doc="Shape of the underlying abstract value.")
  ndim = property(lambda self: self.aval.ndim,
                  doc="Rank of the underlying abstract value.")
@util.cache(max_size=4096, trace_context_in_key=False)
def create_meta_ty(aval, arg_sharding, arg_format, arg_committed, is_np_array):
  """Build a ``MetaTy`` through ``util.cache`` (at most 4096 entries)."""
  return MetaTy(aval=aval, sharding=arg_sharding, format=arg_format,
                committed=arg_committed, is_np_array=is_np_array)
def convert_to_metaty(arg):
  """Summarize ``arg`` as a (cached) ``MetaTy``.

  A Tracer is described only by its aval: no sharding or format,
  ``committed=True`` and ``is_np_array=False``. Any other value contributes
  its ``sharding``, ``format`` and ``_committed`` attributes when present
  (defaulting to ``None``, ``None`` and ``True``), and whether it is a
  ``np.ndarray``.
  """
  # TODO(yashkatariya): Remove this Tracer special case after
  # getattr(Tracer, 'sharding') is fast.
  if isinstance(arg, core.Tracer):
    return create_meta_ty(arg.aval, None, None, True, False)
  return create_meta_ty(
      core.shaped_abstractify(arg),
      getattr(arg, 'sharding', None),
      getattr(arg, 'format', None),
      getattr(arg, '_committed', True),
      isinstance(arg, np.ndarray))
def _pjit_call_impl_python(
    *args,
    jaxpr: core.ClosedJaxpr,
    in_shardings, out_shardings, in_layouts, out_layouts,
    donated_invars, ctx_mesh, name, keep_unused, inline,
    compiler_options_kvs):
  """Lower, compile and run a jit call from Python (the C++ cache-miss path).

  Shardings and layouts are resolved against ``MetaTy`` summaries of
  ``args``, then ``jaxpr`` is lowered and compiled.

  PGLE: when ``config.enable_pgle`` is set and ``config.pgle_profiling_runs``
  is positive, a ``PGLEProfiler`` is fetched from (or created in)
  ``_pgle_profiler_dict``, keyed by ``jaxpr``. Any FDO profile it returns is
  added to ``compiler_options_kvs`` under ``'fdo_profile'``.

  Returns:
    ``(outputs, compiled, pgle_profiler, const_args)``, where ``outputs`` come
    from ``compiled.unsafe_call`` on the computation's const args followed by
    ``args``.
  """
  util.test_event("jit_cpp_cache_miss")
  pgle_compile_options, pgle_profiler = {}, None
  if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
    compilation_target_key = jaxpr
    pgle_profiler = _pgle_profiler_dict.get(compilation_target_key)
    if pgle_profiler is None:
      pgle_profiler = profiler.PGLEProfiler(
          config.pgle_profiling_runs.value,
          config.pgle_aggregation_percentile.value)
      _pgle_profiler_dict[compilation_target_key] = pgle_profiler
    # The method below will return FDO profile when module was profiled
    # config.jax_pgle_profiling_runs amount of times, otherwise the result will
    # be None.
    fdo_profile = pgle_profiler.consume_fdo_profile()
    if fdo_profile is not None:
      pgle_compile_options['fdo_profile'] = fdo_profile
  compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items())
  # Passing mutable PGLE profile here since it should be extracted by JAXPR to
  # initialize the fdo_profile compile option.
  arg_types = map(convert_to_metaty, args)
  computation = _resolve_and_lower(
      arg_types, jaxpr=jaxpr, in_shardings=in_shardings,
      out_shardings=out_shardings, in_layouts=in_layouts,
      out_layouts=out_layouts, donated_invars=donated_invars,
      ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
      inline=inline, lowering_platforms=None,
      lowering_parameters=mlir.LoweringParameters(),
      pgle_profiler=pgle_profiler,
      compiler_options_kvs=compiler_options_kvs,
  )
  compiled = computation.compile()
  # This check is expensive so only do it if enable_checks is on.
  if compiled._auto_spmd_lowering and config.enable_checks.value:
    pxla.check_array_xla_sharding_layout_match(
        args, compiled._in_shardings, compiled._in_layouts, # pyrefly: ignore[missing-attribute]
        jaxpr.jaxpr.debug_info.safe_arg_names(len(args)))
  if config.distributed_debug.value:
    # Defensively only perform fingerprint logic if debug logging is enabled
    # NOTE(skyewm): I didn't benchmark this
    fingerprint = None
    if hasattr(compiled.runtime_executable(), "fingerprint"):
      fingerprint = compiled.runtime_executable().fingerprint
    if fingerprint is not None:
      fingerprint = fingerprint.hex()
    distributed_debug_log(("Running pjit'd function", name),
                          ("in_shardings", in_shardings),
                          ("out_shardings", out_shardings),
                          ("in_layouts", in_layouts),
                          ("out_layouts", out_layouts),
                          ("abstract args", map(core.typeof, args)),
                          ("fingerprint", fingerprint))
  # Const args are passed ahead of the user args.
  return (compiled.unsafe_call(*computation.const_args, *args),
          compiled, pgle_profiler, computation.const_args)
@weakref_lru_cache
def _get_jaxpr_as_fun(jaxpr, in_shardings, out_shardings, in_layouts,
                      out_layouts, donated_invars, ctx_mesh, name,
                      keep_unused, inline, compiler_options_kvs):
  """Return a callable that evaluates ``jaxpr``, cached per jit params.

  Only ``jaxpr`` is used to build the callable; the other parameters just
  take part in the ``weakref_lru_cache`` key. The callable holds ``jaxpr``
  only through a ``weakref.ref`` and dereferences it on each call. Returning
  ``core.jaxpr_as_fun(jaxpr)`` directly would keep a strong reference to the
  cache key and defeat the weak-reference cache.
  """
  jaxpr_ref = weakref.ref(jaxpr)
  return lambda *args: core.jaxpr_as_fun(jaxpr_ref())(*args)
def _pjit_call_impl(*args, jaxpr: core.ClosedJaxpr,
                    in_shardings, out_shardings, in_layouts, out_layouts,
                    donated_invars, ctx_mesh, name, keep_unused, inline,
                    compiler_options_kvs):
  """Impl rule for ``jit_p``: dispatch through the C++ ``pjit`` object.

  A C++ ``pjit`` callable is built from three pieces:
    - ``f``, a Python function evaluating ``jaxpr``;
    - a cache-miss handler that runs ``_pjit_call_impl_python`` and returns
      the outputs, fastpath data, and a flag from ``_need_to_rebuild_with_fdo``;
    - a ``JitGlobalCppCacheKeys`` describing shardings, layouts and donated
      argument indices.
  The callable is then invoked on ``args``.
  """
  def call_impl_cache_miss(*args_, **kwargs_):
    # args_ do not include the const args
    # See https://docs.jax.dev/en/latest/internals/constants.html.
    # TODO(necula): remove num_const_args when fixing the C++ path
    # NOTE(review): the body below uses the enclosing `args`, not `args_`.
    # The C++ object is invoked once with `*args` right below, so presumably
    # the two coincide -- confirm.
    out_flat, compiled, pgle_profiler, const_args = _pjit_call_impl_python(
        *args, jaxpr=jaxpr, in_shardings=in_shardings,
        out_shardings=out_shardings, in_layouts=in_layouts,
        out_layouts=out_layouts, donated_invars=donated_invars,
        ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused,
        inline=inline, compiler_options_kvs=compiler_options_kvs)
    fastpath_data = _get_fastpath_data(
        compiled, tree_structure(out_flat), args, out_flat,
        jaxpr.effects, jaxpr.consts, pgle_profiler,
        const_args)
    return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
  f = _get_jaxpr_as_fun(
      jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
      donated_invars, ctx_mesh, name, keep_unused, inline,
      compiler_options_kvs)
  # The C++ cache key takes donation as argument indices, not a bool mask.
  donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
  cache_key = pxla.JitGlobalCppCacheKeys(
      donate_argnums=donated_argnums, donate_argnames=None,
      device=None, backend=None,
      in_shardings_treedef=None, in_shardings_leaves=in_shardings,
      out_shardings_treedef=None, out_shardings_leaves=out_shardings,
      in_layouts_treedef=None, in_layouts_leaves=in_layouts,
      out_layouts_treedef=None, out_layouts_leaves=out_layouts)
  return xc._xla.pjit(
      name, f, call_impl_cache_miss, [], [], cache_key,
      tree_util.dispatch_registry, pxla.cc_shard_arg,
      _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)
jit_p.def_impl(_pjit_call_impl)
# This cache is important for python dispatch performance.
@weakref_lru_cache
def _pjit_lower(
    jaxpr: core.ClosedJaxpr,
    in_shardings,
    out_shardings,
    in_layouts: pxla.MaybeLayout,
    out_layouts: pxla.MaybeLayout,
    donated_invars,
    ctx_mesh,
    name: str,
    keep_unused: bool,
    inline: bool,
    compiler_options_kvs: tuple[tuple[str, Any], ...],
    *,
    lowering_platforms: tuple[str, ...] | None,
    lowering_parameters: mlir.LoweringParameters,
    pgle_profiler: profiler.PGLEProfiler | None) -> pxla.MeshComputation:
  """Lower ``jaxpr`` to a ``pxla.MeshComputation`` under the ``'jit'`` API name.

  This is a cached wrapper around ``pxla.lower_sharding_computation``.
  ``donated_invars`` is converted to a tuple, and ``ctx_mesh`` is passed as
  ``context_mesh``. ``inline`` is accepted but not forwarded here.
  """
  return pxla.lower_sharding_computation(
      jaxpr, 'jit', name, in_shardings, out_shardings,
      in_layouts, out_layouts, tuple(donated_invars),
      keep_unused=keep_unused, context_mesh=ctx_mesh,
      compiler_options_kvs=compiler_options_kvs,
      lowering_platforms=lowering_platforms,
      lowering_parameters=lowering_parameters,
      pgle_profiler=pgle_profiler)
def pjit_staging_rule(trace, source_info, *args, **params):
  """Staging rule for ``jit_p`` when tracing into a jaxpr.

  - Raises ValueError if ``compiler_options_kvs`` is non-empty, because
    compiler options are only allowed on a top-level ``jax.jit``.
  - If ``inline`` is set and no shardings or layouts are specified, the inner
    jaxpr is inlined into ``trace``.
  - Otherwise a ``jit_p`` equation is staged. If any const is a ``core.Ref``,
    ``pxla._move_mutable_consts`` splits consts out first. They are passed as
    extra trailing operands with ``UNSPECIFIED`` shardings, ``None`` layouts
    and no donation.
  - Finally, for each jaxpr var with an ``initial_qdd``, the QDD of the
    corresponding value is asserted to match and is updated to the var's
    ``final_qdd``.
  """
  if params["compiler_options_kvs"]:
    raise ValueError(
        '`compiler_options` can only be passed to top-level `jax.jit`. Got'
        f' compiler_options={dict(params["compiler_options_kvs"])} specified on'
        f' a nested jit with name: {params["name"]} and source info:'
        f' {source_info_util.summarize(source_info)}')
  # If we're inlining, no need to compute forwarding information; the inlined
  # computation will in effect forward things.
  if (params["inline"] and
      all(isinstance(i, UnspecifiedValue) for i in params["in_shardings"]) and
      all(isinstance(o, UnspecifiedValue) for o in params["out_shardings"]) and
      all(i is None for i in params["in_layouts"]) and
      all(o is None for o in params["out_layouts"])):
    jaxpr = params["jaxpr"]
    out = pe.inline_jaxpr_into_trace(
        trace, source_info, jaxpr.jaxpr, jaxpr.consts, *args)
    return [trace.to_jaxpr_tracer(x, source_info) for x in out]
  jaxpr = params['jaxpr']
  if any(isinstance(c, core.Ref) for c in jaxpr.consts):
    # Pass mutable consts as explicit trailing operands of the staged eqn.
    jaxpr, consts = pxla._move_mutable_consts(jaxpr)
    consts = [trace.new_const(c, source_info) for c in consts]
    in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
    in_layouts = (*params['in_layouts'],) + (None,) * len(consts)
    donated_invars = (*params['donated_invars'],) + (False,) * len(consts)
    new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings,
                      in_layouts=in_layouts, donated_invars=donated_invars)
    out_tracers = trace.default_process_primitive(
        jit_p, (*args, *consts), new_params, source_info=source_info)
  else:
    out_tracers = trace.default_process_primitive(
        jit_p, args, params, source_info=source_info)
  # TODO(mattjj): handle qdd in the presence of refs
  for v, x in zip(it.chain(jaxpr.constvars, jaxpr.invars), it.chain(jaxpr.consts, args)):
    if v.initial_qdd:
      assert core.cur_qdd(x) == v.initial_qdd
      x.aval_mutable_qdd.mutable_qdd.update(v.final_qdd)
  return out_tracers
pe.custom_staging_rules[jit_p] = pjit_staging_rule
def pjit_forwarding_rule(eqn):
  """Forwarding rule for ``jit_p`` that never reports forwarded outputs.

  Returns:
    ``(fwd, eqn)``, where ``fwd`` is a list holding one ``None`` per output
    of ``eqn`` and ``eqn`` is returned unchanged.
  """
  no_forwarding = [None for _ in eqn.outvars]
  return no_forwarding, eqn
# TODO(mattjj): Remove pjit_forwarding_rule and also in staging rule.
# Registered rule: no jit output is reported as forwarding an input.
pe.forwarding_rules[jit_p] = pjit_forwarding_rule
def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):
  """Typecheck a ``jit_p`` equation as a call of its inner open jaxpr.

  The remaining params are passed through, with ``call_jaxpr`` set to
  ``jaxpr.jaxpr``.
  """
  call_params = {**params, 'call_jaxpr': jaxpr.jaxpr}
  return core._check_call(ctx_factory, jit_p, in_atoms, call_params)
core.custom_typechecks[jit_p] = _pjit_typecheck
def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_):
  """Abstract eval for ``jit_p``: the inner jaxpr's output avals and effects.

  Effects come from ``core.eqn_effects(jaxpr)`` when the jaxpr has
  constvars, and from ``jaxpr.effects`` otherwise. ``out_shardings`` is
  accepted but unused.
  """
  if jaxpr.constvars:
    effs = core.eqn_effects(jaxpr)
  else:
    effs = jaxpr.effects
  return jaxpr.out_avals, effs
jit_p.def_effectful_abstract_eval(_pjit_abstract_eval)
def _pjit_cached_lower_jaxpr_to_fun(
    ctx: mlir.LoweringRuleContext, name: str, jaxpr: core.ClosedJaxpr,
    num_const_args: int, in_avals, effects, in_shardings, out_shardings,
    in_layouts, out_layouts, api_name):
  """Lower ``jaxpr`` to an MLIR function, memoized per module context.

  ``in_avals``, ``in_shardings`` and ``in_layouts`` cover the
  ``num_const_args`` hoisted constants followed by the jaxpr inputs.

  The memo lives in ``mod_ctx.cached_primitive_lowerings``. Its key includes:
    - the name, jaxpr and effects;
    - the device count from the axis context;
    - the in/out shardings, compared via ``pxla.SemanticallyEqualShardings``;
    - the layouts and ``api_name``.

  Returns:
    The function produced by ``mlir.lower_jaxpr_to_fun`` (or the cached one).
  """
  assert len(in_avals) == num_const_args + len(jaxpr.in_avals)
  assert len(in_avals) == len(in_shardings)
  assert len(in_avals) == len(in_layouts)
  mod_ctx = ctx.module_context
  axis_ctx = ctx.module_context.axis_context
  # Device count (when known) is part of the key, so the same jaxpr lowered
  # under different device counts does not share a function.
  num_devices = None
  if isinstance(axis_ctx, sharding_impls.ShardingContext):
    num_devices = axis_ctx.num_devices
  elif isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
    num_devices = axis_ctx.mesh.size
  key = (jit_p, name, jaxpr, effects, num_devices,
         pxla.SemanticallyEqualShardings(in_shardings, in_avals),
         pxla.SemanticallyEqualShardings(out_shardings, jaxpr.out_avals),
         in_layouts, out_layouts, api_name)
  func = mod_ctx.cached_primitive_lowerings.get(key, None)
  if func is None:
    # MLIR lowering expects None rather than UnspecifiedValue.
    arg_shardings = [None if isinstance(i, UnspecifiedValue) else i
                     for i in in_shardings]
    result_shardings = [None if isinstance(o, UnspecifiedValue) else o
                        for o in out_shardings]
    # TODO(b/228598865): non-top-level functions cannot have shardings set
    # directly on the inputs or outputs because they are lost during MLIR->HLO
    # conversion. using_sharding_annotation=False means we add an identity
    # operation instead.
    func = mlir.lower_jaxpr_to_fun(
        mod_ctx, name, jaxpr, effects,
        num_const_args=num_const_args, in_avals=in_avals,
        arg_shardings=arg_shardings, result_shardings=result_shardings,
        use_sharding_annotations=False,
        arg_layouts=in_layouts, result_layouts=out_layouts)
    mod_ctx.cached_primitive_lowerings[key] = func
  return func
def _pjit_lowering(ctx: mlir.LoweringRuleContext, *args, name: str,
                   jaxpr: core.ClosedJaxpr, in_shardings,
                   out_shardings, in_layouts, out_layouts, donated_invars,
                   ctx_mesh, keep_unused, inline, compiler_options_kvs):
  """MLIR lowering rule for ``jit_p``: emit a ``func.call`` to the jaxpr.

  The callee comes from ``_pjit_cached_lower_jaxpr_to_fun``. Call operands
  are, in order:
    1. dimension variables;
    2. effect tokens;
    3. the jaxpr's hoisted constants (``core.jaxpr_const_args``), whose
       shardings and layouts are resolved like unspecified jit inputs;
    4. the regular arguments.
  The leading ``len(effects)`` results are the updated tokens; they are
  stored with ``ctx.set_tokens_out``.

  Returns:
    The remaining (non-token) result values.
  """
  effects = list(ctx.tokens_in.effects())
  output_types = map(mlir._aval_to_ir_types, ctx.avals_out)
  output_types = [mlir.token_type()] * len(effects) + output_types
  flat_output_types = mlir.flatten_ir_types(output_types)
  const_args_and_avals = core.jaxpr_const_args(jaxpr.jaxpr)
  const_args, const_arg_avals = util.unzip2(const_args_and_avals)
  in_avals = (*const_arg_avals, *jaxpr.in_avals)
  ca_shardings = const_args_shardings(const_args)
  in_shardings = ca_shardings + in_shardings
  ca_layouts = const_args_layouts(const_args, const_arg_avals, ca_shardings)
  in_layouts = ca_layouts + in_layouts
  func = _pjit_cached_lower_jaxpr_to_fun(
      ctx, name, jaxpr, len(const_args), in_avals, tuple(effects), in_shardings,
      out_shardings, in_layouts, out_layouts, api_name='jit')
  tokens_in = [ctx.tokens_in.get(eff) for eff in effects]
  hoisted_const_values = mlir.flatten_ir_values(
      mlir.ir_constants(c, const_lowering=ctx.const_lowering, aval=aval)
      for c, aval in const_args_and_avals
  )
  args = (*ctx.dim_var_values, *tokens_in, *hoisted_const_values, *args)
  with mlir.source_info_to_location(
      ctx.module_context, None,
      ctx.name_stack.extend(util.wrap_name('jit', name)), ctx.traceback):
    call = func_dialect.CallOp(
        flat_output_types, ir.FlatSymbolRefAttr.get(func.name.value),
        mlir.flatten_ir_values(args))
  mlir.wrap_compute_type_in_place(ctx, call) # pyrefly: ignore[bad-argument-type]
  out_nodes = mlir.unflatten_ir_values_like_types(call.results, output_types)
  tokens, out_nodes = split_list(out_nodes, [len(effects)])
  tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens)))
  ctx.set_tokens_out(tokens_out)
  return out_nodes
# TODO(phawkins): this is marked uncacheable because it has its own cache and
# because the cache breaks jaxpr metadata like source locations. We should fix
# the metadata problem and consolidate the caches.
mlir.register_lowering(jit_p, _pjit_lowering, cacheable=False)
def const_args_shardings(const_args: Sequence[Array | np.ndarray]) -> Sequence[PjitSharding]:
  """Resolve shardings for hoisted constant arguments.

  Constants have no jit-requested sharding, so each one is resolved exactly
  like an input whose jit sharding is ``UNSPECIFIED``.
  """
  metatys = map(convert_to_metaty, const_args)
  all_unspecified = (sharding_impls.UNSPECIFIED,) * len(const_args)
  return _resolve_in_shardings(metatys, all_unspecified)
def const_args_layouts(
    const_args: Sequence[ArrayLike],
    avals: Sequence[core.AbstractValue],
    shardings: Sequence[PjitSharding]
    ) -> Sequence[Layout | AutoLayout | None]:
  """Resolve layouts for hoisted constant arguments.

  Constants have no jit-requested layout, so each one is resolved as if
  ``None`` had been requested, using the given ``shardings`` and ``avals``.
  In ``_pjit_lowering`` the shardings come from ``const_args_shardings``.
  """
  metatys = map(convert_to_metaty, const_args)
  no_requested_layouts = (None,) * len(const_args)
  return _resolve_in_layouts(metatys, no_requested_layouts, shardings, avals)
def _pjit_batcher(axis_data, vals_in,
                  dims_in: tuple[int, ...],
                  jaxpr: core.ClosedJaxpr,
                  in_shardings, out_shardings, in_layouts, out_layouts,
                  donated_invars, ctx_mesh, name, keep_unused, inline,
                  compiler_options_kvs):
  """Batching (vmap) rule for ``jit_p``.

  Batches ``jaxpr`` with ``batching.batch_jaxpr2``. The sharding of every
  batched input and output is rewritten by ``_pjit_batcher_for_sharding`` to
  account for the new dimension, using ``axis_data.spmd_name`` and
  ``ctx_mesh``. Unbatched values keep their sharding.

  Raises:
    NotImplementedError: if any input or output layout is not ``None``.

  Returns:
    ``(vals_out, axes_out)``.
  """
  batched_jaxpr, axes_out = batching.batch_jaxpr2(jaxpr, axis_data, dims_in)

  def batch_shardings(shardings, batch_dims, avals):
    # Only values that actually carry a batch dimension get a new sharding.
    return tuple(
        s if d is None else _pjit_batcher_for_sharding(
            s, d, axis_data.spmd_name, ctx_mesh, a.ndim)
        for d, s, a in zip(batch_dims, shardings, avals))

  new_in_shardings = batch_shardings(
      in_shardings, dims_in, batched_jaxpr.in_avals)
  new_out_shardings = batch_shardings(
      out_shardings, axes_out, batched_jaxpr.out_avals)
  # TODO(yashkatariya): Figure out layouts should change under vmap.
  has_concrete_layout = (any(l is not None for l in in_layouts) or
                         any(l is not None for l in out_layouts))
  if has_concrete_layout:
    raise NotImplementedError(
        'Concrete layouts are not supported for vmap(jit).')
  vals_out = jit_p.bind(
      *vals_in,
      jaxpr=batched_jaxpr,
      in_shardings=new_in_shardings,
      out_shardings=new_out_shardings,
      in_layouts=in_layouts,
      out_layouts=out_layouts,
      donated_invars=donated_invars,
      ctx_mesh=ctx_mesh,
      name=name,
      keep_unused=keep_unused,
      inline=inline,
      compiler_options_kvs=compiler_options_kvs)
  return vals_out, axes_out
batching.fancy_primitive_batchers[jit_p] = _pjit_batcher
def _pjit_batcher_for_sharding(
    s, dim: int, spmd_axis_name: tuple[str, ...] | None,
    mesh, ndim: int):
  """Return the sharding for a value that gained a batch dimension at ``dim``.

  Args:
    s: the original sharding. ``UnspecifiedValue`` is returned unchanged.
    dim: position of the inserted batch dimension.
    spmd_axis_name: axis name(s) passed to ``pxla.batch_spec`` for the new
      dimension, or None when no SPMD axis was requested.
    mesh: mesh used when ``s`` is not a ``NamedSharding``. It is overwritten
      by ``s.mesh`` when ``s`` is one.
    ndim: rank of the batched value, used to build the HLO sharding.

  With ``spmd_axis_name=None``:
    - replicated shardings are returned as-is;
    - ``NamedSharding``s on an ``AbstractMesh`` get an ``UNCONSTRAINED``
      batch spec;
    - otherwise a size-1 tile dimension is inserted at ``dim``.

  Raises:
    ValueError: if ``spmd_axis_name`` is given and the resolved mesh is
      empty or scalar.
  """
  if isinstance(s, UnspecifiedValue):
    return s
  hlo_s = s._to_xla_hlo_sharding(ndim)
  if spmd_axis_name is None:
    if sharding_impls.is_hlo_sharding_replicated(hlo_s):
      return s
    if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
      return NamedSharding(
          s.mesh, pxla.batch_spec(s.spec, dim, PartitionSpec.UNCONSTRAINED))
    # Insert a size-1 tile dimension so the batch dim is not partitioned.
    new_op = hlo_s.to_proto().clone()
    tad = list(new_op.tile_assignment_dimensions)
    tad.insert(dim, 1)
    new_op.tile_assignment_dimensions = tad
    new_gs = GSPMDSharding(s._internal_device_list, new_op)
    return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0]
  else:
    if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
      return NamedSharding(
          s.mesh, pxla.batch_spec(s.spec, dim, spmd_axis_name))
    if isinstance(s, NamedSharding):
      mesh = s.mesh
    if mesh.empty or mesh.is_scalar:
      raise ValueError(
          'If you are using spmd_axis_name parameter of jax.vmap,'
          ' please make sure to run your jitted function inside the mesh'
          ' context manager. Only `jax.lax.with_sharding_constraint` with'
          ' `jax.sharding.NamedSharding` as an input can be transformed with'
          ' spmd_axis_name batching rules outside of an explicit mesh context'
          f' manager scope {s!r}')
    # Recover a PartitionSpec from the HLO sharding, then add the batch axis.
    spec = parse_flatten_op_sharding(hlo_s, mesh)[0]
    return NamedSharding(
        mesh, pxla.batch_spec(spec, dim, spmd_axis_name))
def _pjit_jvp(primals_in, tangents_in,
              jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
              donated_invars, ctx_mesh, name, keep_unused, inline,
              compiler_options_kvs):
  """JVP rule for ``jit_p``.

  Builds the JVP jaxpr with ``ad.jvp_jaxpr`` and binds it as one ``jit_p``
  call. Its operands are the primals followed by the non-``ad.Zero``
  tangents.

  Each nonzero tangent input or output reuses the sharding, layout and (for
  inputs) donation entry of its primal. Output tangents that are symbolically
  zero are returned as ``ad.Zero`` of the primal output aval.

  Returns:
    ``(primals_out, tangents_out)``.
  """
  is_nz_tangents_in = [type(t) is not ad.Zero for t in tangents_in]
  jaxpr_jvp, is_nz_tangents_out = ad.jvp_jaxpr(
      jaxpr, is_nz_tangents_in, instantiate=False)
  def _filter_zeros(is_nz_l, l):
    return (x for nz, x in zip(is_nz_l, l) if nz)
  _filter_zeros_in = partial(_filter_zeros, is_nz_tangents_in)
  _filter_zeros_out = partial(_filter_zeros, is_nz_tangents_out)
  outputs = jit_p.bind(
      *primals_in, *_filter_zeros_in(tangents_in),
      jaxpr=jaxpr_jvp,
      in_shardings=(*in_shardings, *_filter_zeros_in(in_shardings)),
      out_shardings=(*out_shardings, *_filter_zeros_out(out_shardings)),
      in_layouts=(*in_layouts, *_filter_zeros_in(in_layouts)),
      out_layouts=(*out_layouts, *_filter_zeros_out(out_layouts)),
      donated_invars=(*donated_invars, *_filter_zeros_in(donated_invars)),
      ctx_mesh=ctx_mesh,
      name=name,
      keep_unused=keep_unused,
      inline=inline,
      compiler_options_kvs=compiler_options_kvs)
  primals_out, tangents_out = split_list(outputs, [len(jaxpr.jaxpr.outvars)])
  assert len(primals_out) == len(jaxpr.jaxpr.outvars)
  # Re-insert symbolic zeros for outputs whose tangent was not computed.
  tangents_out_it = iter(tangents_out)
  return primals_out, [next(tangents_out_it) if nz else ad.Zero(aval)
                       for nz, aval in zip(is_nz_tangents_out, jaxpr.out_avals)]
ad.primitive_jvps[jit_p] = _pjit_jvp
def _pjit_linearize(is_vjp, nzs, *primals_in, jaxpr, in_shardings, out_shardings,
                    in_layouts, out_layouts, donated_invars, ctx_mesh, name,
                    keep_unused, inline, compiler_options_kvs):
  """Linearization rule for ``jit_p``.

  ``ad.linearize_jaxpr`` splits ``jaxpr`` in two:
    - a primal jaxpr that outputs the primal outputs followed by residuals;
    - a tangent jaxpr that takes residuals followed by the nonzero tangents.
  The primal jaxpr is bound now; the returned ``tangent_fun`` binds the
  tangent jaxpr later.

  To avoid emitting redundant outputs, two kinds are pruned before binding
  and rebuilt with ``subs_list`` afterwards:
    - outputs that just forward a primal input (input-to-output forwarding).
      Primal outputs qualify only when their sharding is unspecified and
      their layout is None.
    - residuals that duplicate a primal output (output-to-output forwarding).
  Residuals that ``linearize_jaxpr`` reports as forwarded consts or inputs
  (``in_fwd_res``) are substituted from ``[*jaxpr.consts, *primals_in]``.

  Returns:
    ``(primal_ans, nzs_out, residuals_ans, tangent_fun)``.
  """
  primal_jaxpr, num_residuals_out, nzs_out, in_fwd_res, tangent_jaxpr = \
      ad.linearize_jaxpr(jaxpr, nzs, is_vjp=is_vjp)
  num_residuals_in = len(in_fwd_res)
  num_primals_out = len(primal_jaxpr.out_avals) - num_residuals_out
  # Residual inputs/outputs get unspecified shardings, no layouts, no donation.
  res_shardings_in = (UNSPECIFIED,) * num_residuals_in
  res_layouts_in = (None,) * num_residuals_in
  res_donated = (False,) * num_residuals_in
  primal_out_shardings = tuple(out_shardings) + (UNSPECIFIED,) * num_residuals_out
  primal_out_layouts = tuple(out_layouts) + (None,) * num_residuals_out
  config.enable_checks.value and core.check_jaxpr(primal_jaxpr.jaxpr)
  config.enable_checks.value and core.check_jaxpr(tangent_jaxpr.jaxpr)
  def keep_where(l, should_keep):
    return tuple(x for x, keep in zip(l, should_keep) if keep)
  # Input-to-output forwarding.
  in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr)
  in_fwd_primal, in_fwd_res_ = split_list(in_fwd, [num_primals_out])
  assert all(f is None for f in in_fwd_res_)
  in_fwd = [
      fwd if isinstance(os, UnspecifiedValue) and ol is None else None
      for os, ol, fwd in zip(out_shardings, out_layouts, in_fwd_primal)
  ] + in_fwd_res_
  del in_fwd_res_, in_fwd_primal
  keep = [f is None for f in in_fwd]
  primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep)
  primal_out_shardings = keep_where(primal_out_shardings, keep)
  primal_out_layouts = keep_where(primal_out_layouts, keep)
  _, kept_res = split_list(keep, [num_primals_out])
  num_kept_residuals = sum(kept_res)
  del keep, kept_res, num_primals_out
  # Output-to-output forwarding.
  num_primals_out = len(primal_jaxpr.out_avals) - num_kept_residuals
  out_vars, res_vars = split_list(primal_jaxpr.jaxpr.outvars, [num_primals_out])
  idx_map = {id(v): i for i, v in enumerate(out_vars)}
  out_fwd = [None] * num_primals_out + [idx_map.get(id(v)) for v in res_vars]
  keep = [f is None for f in out_fwd]
  primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep)
  primal_out_shardings = keep_where(primal_out_shardings, keep)
  primal_out_layouts = keep_where(primal_out_layouts, keep)
  del keep
  tangent_avals_out = [a.to_tangent_aval() for a in jaxpr.out_avals]
  def tangent_fun(residuals, *tangents):
    # Bind the tangent jaxpr on residuals plus nonzero tangents, then
    # re-insert symbolic zeros for outputs whose tangent is known zero.
    tangents_nz = _filter_zeros(nzs, tangents)
    nz_tangents_out = jit_p.bind(
        *residuals, *tangents_nz, jaxpr=tangent_jaxpr,
        in_shardings=res_shardings_in + _filter_zeros(nzs, in_shardings),
        out_shardings=_filter_zeros(nzs_out, out_shardings),
        in_layouts=res_layouts_in + _filter_zeros(nzs, in_layouts),
        out_layouts=_filter_zeros(nzs_out, out_layouts),
        donated_invars=res_donated + _filter_zeros(nzs, donated_invars),
        ctx_mesh=ctx_mesh,
        name=name,
        keep_unused=keep_unused,
        inline=inline,
        compiler_options_kvs=compiler_options_kvs)
    nz_tangents_out_ = iter(nz_tangents_out)
    tangents_out = [next(nz_tangents_out_) if nz else ad.Zero(aval)
                    for (aval, nz) in zip(tangent_avals_out, nzs_out)]
    return tangents_out
  def _filter_zeros(is_nz_l, l):
    return tuple(x for nz, x in zip(is_nz_l, l) if nz)
  assert len(in_shardings) == len(primal_jaxpr.in_avals)
  ans = jit_p.bind(*primals_in, jaxpr=primal_jaxpr,
                   in_shardings=in_shardings,
                   out_shardings=primal_out_shardings,
                   in_layouts=in_layouts,
                   out_layouts=primal_out_layouts,
                   donated_invars=donated_invars,
                   ctx_mesh=ctx_mesh,
                   name=name,
                   keep_unused=keep_unused,
                   inline=inline,
                   compiler_options_kvs=compiler_options_kvs)
  # Undo the pruning: first output-to-output, then input-to-output forwards.
  ans = subs_list(out_fwd, ans, ans)
  ans = subs_list(in_fwd, primals_in, ans)
  primal_ans, residuals_ans = split_list(ans, [len(ans) - num_residuals_out])
  residuals_ans = subs_list(in_fwd_res, [*jaxpr.consts, *primals_in], residuals_ans)
  return primal_ans, nzs_out, residuals_ans, tangent_fun
ad.primitive_linearizations[jit_p] = _pjit_linearize
def _pjit_remat(policy, *args, jaxpr, **params):
  """Rematerialization rule for ``jit_p``.

  ``remat.remat_jaxpr`` splits ``jaxpr`` in two:
    - a forward jaxpr that also emits ``num_res`` residuals after the original
      outputs;
    - a remat jaxpr that takes those residuals first.
  The forward jaxpr is bound now. The returned callable binds the remat jaxpr
  on the residuals followed by whatever further arguments it receives.

  Returns:
    ``(primals_out, remat_fn)``.
  """
  fwd_jaxpr, rem_jaxpr, num_res = remat.remat_jaxpr(jaxpr, policy)
  fwd_params, rem_params = _add_res_to_params(num_res, **params)
  outs_and_res = jit_p.bind(*args, jaxpr=fwd_jaxpr, **fwd_params)
  primals_out, residuals = split_list(outs_and_res, [len(jaxpr.outvars)])
  remat_fn = partial(jit_p.bind, *residuals, jaxpr=rem_jaxpr, **rem_params)
  return primals_out, remat_fn
remat.rules[jit_p] = _pjit_remat
def _add_res_to_params(num_res, in_shardings, out_shardings, in_layouts,
                       out_layouts, donated_invars, **params):
  """Build jit params for the forward and remat halves of ``_pjit_remat``.

  - The forward call keeps the original inputs and appends ``num_res``
    residual outputs.
  - The remat call prepends ``num_res`` residual inputs and keeps the
    original outputs.
  Residuals get ``UNSPECIFIED`` shardings and ``None`` layouts, and are never
  donated.

  Returns:
    ``(params_fwd, params_rem)``.
  """
  res_shardings = (UNSPECIFIED,) * num_res
  res_layouts = (None,) * num_res
  params_fwd = {
      **params,
      'in_shardings': in_shardings,
      'out_shardings': out_shardings + res_shardings,
      'in_layouts': in_layouts,
      'out_layouts': out_layouts + res_layouts,
      'donated_invars': donated_invars,
  }
  params_rem = {
      **params,
      'in_shardings': res_shardings + in_shardings,
      'out_shardings': out_shardings,
      'in_layouts': res_layouts + in_layouts,
      'out_layouts': out_layouts,
      'donated_invars': (False,) * num_res + donated_invars,
  }
  return params_fwd, params_rem
def _pjit_partial_eval(trace: pe.JaxprTrace,
                       *in_tracers,
                       jaxpr: core.ClosedJaxpr, in_shardings, out_shardings,
                       in_layouts, out_layouts, donated_invars, ctx_mesh,
                       name, keep_unused, inline, compiler_options_kvs):
  """Partial-evaluation rule for ``jit_p`` under a ``pe.JaxprTrace``.

  ``jaxpr`` is split into two parts:

  * a "known" jaxpr, bound immediately on the known inputs;
  * an "unknown" jaxpr, staged out as a new ``jit_p`` equation recipe. Its
    inputs are the unknown tracers followed by the residuals.

  Some outputs of the known call are pruned before binding and reinserted
  afterwards, so the known call does not return them:

  * primal outputs that merely forward an input. This applies only when the
    output has an unspecified sharding and no layout.
  * residual outputs that duplicate a primal output.

  Returns:
    One entry per output of ``jaxpr``: a concrete value for known outputs and
    a ``pe.JaxprTracer`` for unknown ones.
  """
  in_pvals = [t.pval for t in in_tracers]
  known_ins = tuple(pv.is_known() for pv in in_pvals)
  unknown_ins = tuple(not k for k in known_ins)
  known_jaxpr, unknown_jaxpr, unknown_outs, res_out_avals, in_fwd_res = \
      pe.partial_eval_jaxpr_nounits_fwd(jaxpr, unknown_ins, instantiate=False)
  unknown_outs = tuple(unknown_outs)
  known_outs = tuple(not uk for uk in unknown_outs)
  # out_shardings and out_layouts for residual values output by known_jaxpr
  def keep_where(l, should_keep):
    return tuple(x for x, keep in zip(l, should_keep) if keep)
  known_out_shardings = (keep_where(out_shardings, known_outs)
                         + (UNSPECIFIED,) * len(res_out_avals))
  known_out_layouts = (keep_where(out_layouts, known_outs)
                       + (None,) * len(res_out_avals))
  # Input-to-output forwarding: compute which outputs are just forwarded inputs.
  # Forwarding is only used for primal outputs with no sharding/layout request.
  num_out_primals = len(known_jaxpr.out_avals) - len(res_out_avals)
  in_fwd: list[int | None] = pe._jaxpr_forwarding(known_jaxpr.jaxpr)
  in_fwd_primal, in_fwd_res_ = split_list(in_fwd, [num_out_primals])
  assert all(f is None for f in in_fwd_res_)
  in_fwd = [
      fwd if isinstance(os, UnspecifiedValue) and ol is None else None
      for os, ol, fwd in zip(
          keep_where(out_shardings, known_outs),
          keep_where(out_layouts, known_outs), in_fwd_primal)
  ] + in_fwd_res_
  del in_fwd_primal, in_fwd_res_
  # Prune jaxpr outputs and out_shardings by removing the input-forwards.
  keep = [f is None for f in in_fwd]
  known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
  known_out_shardings = keep_where(known_out_shardings, keep)
  known_out_layouts = keep_where(known_out_layouts, keep)
  # Update num_out_primals to reflect pruning.
  kept_primals, kept_res = split_list(keep, [num_out_primals])
  num_out_primals = sum(kept_primals)
  del keep, kept_primals, kept_res
  # Output-to-output forwarding: compute which residuals are just primal outputs
  out_vars, res_vars = split_list(known_jaxpr.jaxpr.outvars, [num_out_primals])
  idx_map = {id(v): i for i, v in enumerate(out_vars)}
  out_fwd = [None] * num_out_primals + [idx_map.get(id(v)) for v in res_vars]
  # Prune jaxpr outputs and out_shardings by removing forwarded residuals.
  keep = [f is None for f in out_fwd]
  known_jaxpr = pe.prune_closed_jaxpr_outputs(known_jaxpr, keep)
  known_out_shardings = keep_where(known_out_shardings, keep)
  known_out_layouts = keep_where(known_out_layouts, keep)
  del keep
  known_params = dict(
      jaxpr=known_jaxpr, in_shardings=keep_where(in_shardings, known_ins),
      out_shardings=known_out_shardings,
      in_layouts=keep_where(in_layouts, known_ins),
      out_layouts=known_out_layouts,
      donated_invars=keep_where(donated_invars, known_ins),
      ctx_mesh=ctx_mesh,
      name=name, keep_unused=keep_unused, inline=inline,
      compiler_options_kvs=compiler_options_kvs)
  assert len(known_params['out_shardings']) == len(known_params['jaxpr'].out_avals)
  assert len(known_params['out_layouts']) == len(known_params['jaxpr'].out_avals)
  # Bind known things to jit_p.
  known_inputs = [pv.get_known() for pv in in_pvals if pv.is_known()]
  all_known_outs = jit_p.bind(*known_inputs, **known_params)
  # Add back in the output fwds.
  all_known_outs = subs_list(out_fwd, all_known_outs, all_known_outs)
  # Add back in the input fwds.
  all_known_outs = subs_list(in_fwd, known_inputs, all_known_outs)
  known_out_vals, residual_vals = \
      split_list(all_known_outs, [len(all_known_outs) - len(res_out_avals)])
  # Residuals that partial_eval_jaxpr_nounits_fwd marked as forwarded (f is not
  # None) index into consts followed by known inputs. The others come, in
  # order, from the known call's outputs. Build the index source once rather
  # than once per forwarded residual.
  residual_vals_ = iter(residual_vals)
  fwd_sources = [*jaxpr.consts, *known_inputs]
  residual_vals = [next(residual_vals_) if f is None else fwd_sources[f]
                   for f in in_fwd_res]
  assert next(residual_vals_, None) is None
  residual_tracers = map(trace.new_instantiated_const, residual_vals)
  # The convention of partial_eval_jaxpr_nounits is to place residual binders at
  # the front of the jaxpr produced, so we move them to the back since both the
  # jaxpr equation built below and the jit transpose rule assume a
  # residual-inputs-last convention.
  unknown_jaxpr = pe.move_binders_to_back(
      unknown_jaxpr, [True] * len(residual_vals) + [False] * sum(unknown_ins))
  # Set up staged-out 'unknown' eqn
  unknown_in_shardings = (keep_where(in_shardings, unknown_ins)
                          + (UNSPECIFIED,) * len(residual_tracers))
  unknown_in_layouts = (keep_where(in_layouts, unknown_ins)
                        + (None,) * len(residual_tracers))
  unknown_donated_invars = (keep_where(donated_invars, unknown_ins)
                            + (False,) * len(residual_tracers))
  unknown_params = dict(
      jaxpr=unknown_jaxpr,
      in_shardings=unknown_in_shardings,
      in_layouts=unknown_in_layouts,
      out_shardings=keep_where(out_shardings, unknown_outs),
      out_layouts=keep_where(out_layouts, unknown_outs),
      donated_invars=unknown_donated_invars,
      ctx_mesh=ctx_mesh,
      name=name,
      keep_unused=keep_unused,
      inline=inline,
      compiler_options_kvs=compiler_options_kvs)
  unknown_tracers_in = [t for t in in_tracers if not t.pval.is_known()]
  unknown_out_avals = unknown_jaxpr.out_avals
  unknown_tracers_out = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
      for aval in unknown_out_avals
  ]
  unknown_tracers_in = [*unknown_tracers_in, *residual_tracers]
  eqn = pe.new_eqn_recipe(trace, unknown_tracers_in,
                          unknown_tracers_out,
                          jit_p,
                          unknown_params,
                          unknown_jaxpr.effects,
                          source_info_util.current())
  for t in unknown_tracers_out: t.recipe = eqn
  # Keep effectful staged eqns alive even if their outputs are unused.
  if effects.partial_eval_kept_effects.filter_in(unknown_jaxpr.effects):
    trace.effect_handles.append(pe.EffectHandle(unknown_tracers_in, eqn))
  return merge_lists(unknown_outs, known_out_vals, unknown_tracers_out)
pe.custom_partial_eval_rules[jit_p] = _pjit_partial_eval
def _pjit_partial_eval_custom_params_updater(
    unks_in: Sequence[bool], inst_in: Sequence[bool],
    kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
    num_res_out: int, num_res_in: int, params_known: dict, params_staged: dict
  ) -> tuple[dict, dict]:
  """Rewrite ``jit_p`` params after ``pe.closed_call_partial_eval_custom_rule``.

  Known call:

  * inputs are filtered by ``unks_in`` (first partition of
    ``pe.partition_list``);
  * outputs are filtered by ``kept_outs_known`` (second partition);
  * ``num_res_out`` residual outputs are appended.

  Staged call:

  * inputs are filtered by ``inst_in`` and outputs by ``kept_outs_staged``
    (second partition in both cases);
  * ``num_res_in`` residual inputs are placed in front.

  Residual slots get ``UNSPECIFIED`` shardings and ``None`` layouts, and are
  never donated. Each result is checked against the arity of its params'
  ``jaxpr``.
  """
  def first_part(flags, seq):
    # Items of `seq` in the first partition of pe.partition_list(flags, seq).
    return pe.partition_list(flags, seq)[0]
  def second_part(flags, seq):
    # Items of `seq` in the second partition of pe.partition_list(flags, seq).
    return pe.partition_list(flags, seq)[1]
  def check_arity(new_params, old_params):
    assert len(new_params['in_shardings']) == len(old_params['jaxpr'].in_avals)
    assert len(new_params['out_shardings']) == len(old_params['jaxpr'].out_avals)
    assert len(new_params['in_layouts']) == len(old_params['jaxpr'].in_avals)
    assert len(new_params['out_layouts']) == len(old_params['jaxpr'].out_avals)
  res_out_shardings = (UNSPECIFIED,) * num_res_out
  res_out_layouts = (None,) * num_res_out
  new_params_known = dict(
      params_known,
      in_shardings=tuple(first_part(unks_in, params_known['in_shardings'])),
      out_shardings=(*second_part(kept_outs_known, params_known['out_shardings']),
                     *res_out_shardings),
      in_layouts=tuple(first_part(unks_in, params_known['in_layouts'])),
      out_layouts=(*second_part(kept_outs_known, params_known['out_layouts']),
                   *res_out_layouts),
      donated_invars=tuple(first_part(unks_in, params_known['donated_invars'])))
  check_arity(new_params_known, params_known)
  res_in_shardings = (UNSPECIFIED,) * num_res_in
  res_in_layouts = (None,) * num_res_in
  res_in_donated = (False,) * num_res_in
  new_params_staged = dict(
      params_staged,
      in_shardings=(*res_in_shardings,
                    *second_part(inst_in, params_staged['in_shardings'])),
      out_shardings=tuple(
          second_part(kept_outs_staged, params_staged['out_shardings'])),
      in_layouts=(*res_in_layouts,
                  *second_part(inst_in, params_staged['in_layouts'])),
      out_layouts=tuple(
          second_part(kept_outs_staged, params_staged['out_layouts'])),
      donated_invars=(*res_in_donated,
                      *second_part(inst_in, params_staged['donated_invars'])))
  check_arity(new_params_staged, params_staged)
  return new_params_known, new_params_staged
pe.partial_eval_jaxpr_custom_rules[jit_p] = \
    partial(pe.closed_call_partial_eval_custom_rule, 'jaxpr',
            _pjit_partial_eval_custom_params_updater)
def _pjit_transpose_fancy(
    cts_in, *args, jaxpr, in_shardings, out_shardings, in_layouts,
    out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline,
    compiler_options_kvs):
  """Transpose rule for ``jit_p`` (registered in ``ad.fancy_transposes``).

  ``_transpose_jaxpr_fancy`` traces a jaxpr that runs ``ad.backward_pass3``
  on ``jaxpr``; that jaxpr is bound as a new ``jit_p`` call. The resulting
  cotangents are accumulated into the ``ad.ValAccum`` entries of ``args``.
  Nothing is returned.

  If the transposed call raises ``api_util.InternalFloatingPointError``, the
  backward pass is re-run un-jitted:

  * if that re-run raises ``FloatingPointError`` or ``ZeroDivisionError``,
    the error is re-raised;
  * otherwise ``api_util._raise_no_nan_in_deoptimized(e)`` is called.
  """
  # `ad.project_accums` yields stand-ins for `args` plus `specs`, which
  # `ad.unproject_accums` uses inside `_transpose_jaxpr_fancy` to rebuild them.
  primals_ctrefs, specs = ad.project_accums(args)
  in_flat, in_tree = tree_flatten((primals_ctrefs, cts_in))
  # Attach the current QDD to the avals of inputs whose type has one.
  in_avals = [core.AvalQDD(a, cur_qdd(x)) if (a := typeof(x)).has_qdd
              else a for x in in_flat]
  trans_jaxpr, out_tree = _transpose_jaxpr_fancy(jaxpr, in_tree, (*in_avals,), specs)
  # Shardings/layouts of the transposed call's inputs: those of args that are
  # not accumulators, followed by those of the non-Zero cotangents.
  # NOTE(review): assumes this matches the order of `in_flat` — confirm.
  trans_in_shardings = (
      [s for x, s in zip(args, in_shardings)
       if not isinstance(x, (ad.ValAccum, ad.NullAccum))] +
      [s for x, s in zip(cts_in, out_shardings) if not isinstance(x, ad.Zero)])
  trans_in_layouts = (
      [l for x, l in zip(args, in_layouts)
       if not isinstance(x, (ad.ValAccum, ad.NullAccum))] +
      [l for x, l in zip(cts_in, out_layouts) if not isinstance(x, ad.Zero)])
  # One output cotangent per arg that produces one (an aval leaf rather than
  # None); each reuses the corresponding primal input's sharding/layout.
  cts_out_ = tree_unflatten(out_tree, trans_jaxpr.out_avals)
  trans_out_shardings = tuple(s for x, s in zip(cts_out_, in_shardings)
                              if isinstance(x, core.AbstractValue))
  trans_out_layouts = tuple(l for x, l in zip(cts_out_, in_layouts)
                            if isinstance(x, core.AbstractValue))
  try:
    cts_out = jit_p.bind(
        *in_flat, jaxpr=trans_jaxpr, in_shardings=tuple(trans_in_shardings),
        in_layouts=tuple(trans_in_layouts), out_shardings=trans_out_shardings,
        out_layouts=trans_out_layouts, donated_invars=(False,) * len(in_flat),
        ctx_mesh=ctx_mesh, name=name, keep_unused=keep_unused, inline=inline,
        compiler_options_kvs=compiler_options_kvs)
  except api_util.InternalFloatingPointError as e:
    print("Invalid nan value encountered in the backward pass of a jax.jit "
          "function. Calling the de-optimized backward pass.")
    try:
      ad.backward_pass3(jaxpr.jaxpr, False, jaxpr.consts, args, cts_in)
    except (FloatingPointError, ZeroDivisionError) as e2:
      raise e2 from None  # great
    else:
      # If control reaches this line, the jitted transpose produced a NaN but
      # the de-optimized backward pass on the same inputs did not. Let's tell
      # the user. NOTE(review): presumably this call always raises; otherwise
      # `cts_out` below would be unbound.
      api_util._raise_no_nan_in_deoptimized(e)
  # pyrefly: ignore[unbound-name] # pyrefly#2219
  for x, ct in zip(args, tree_unflatten(out_tree, cts_out)):
    if isinstance(x, ad.ValAccum): x.accum(ct)
@weakref_lru_cache
def _transpose_jaxpr_fancy(jaxpr, in_tree, in_avals, specs):
  """Trace the transpose of ``jaxpr`` as a standalone jaxpr. Results are cached.

  Args:
    jaxpr: the ``ClosedJaxpr`` of the ``jit_p`` call being transposed.
    in_tree: treedef of ``(primals_ctrefs, cts_in)`` for the flat inputs.
    in_avals: tuple of avals for the flat inputs.
    specs: accumulator specs passed to ``ad.unproject_accums``.

  Returns:
    ``(ClosedJaxpr, out_tree)``. ``out_tree`` is the treedef of the per-arg
    cotangent list, which holds ``None`` for args that are not ``ad.ValAccum``.
  """
  # Holds the output treedef, which is only known once `transposed` is traced.
  captured = {}
  def transposed(*in_flat):
    primals_ctrefs, cts_in = tree_unflatten(in_tree, in_flat)
    args = ad.unproject_accums(specs, primals_ctrefs)
    ad.backward_pass3(jaxpr.jaxpr, False, jaxpr.consts, args, cts_in)
    per_arg_cts = [a.freeze() if isinstance(a, ad.ValAccum) else None
                   for a in args]
    flat_cts, captured['out_tree'] = tree_flatten(per_arg_cts)
    return flat_cts
  dbg = jaxpr.jaxpr.debug_info.with_unknown_names()
  trans_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(transposed, debug_info=dbg), in_avals)
  return core.ClosedJaxpr(trans_jaxpr, consts), captured['out_tree']
ad.fancy_transposes[jit_p] = _pjit_transpose_fancy
@weakref_lru_cache
def _dce_jaxpr_pjit(
    jaxpr: core.ClosedJaxpr, used_outputs: tuple[bool, ...]
) -> tuple[core.ClosedJaxpr, list[bool]]:
  """Cached ``pe.dce_jaxpr`` for a closed jaxpr.

  Returns the pruned ``ClosedJaxpr`` (with the same consts) and the per-input
  usage mask.
  """
  pruned, used_inputs = pe.dce_jaxpr(jaxpr.jaxpr, used_outputs)
  return core.ClosedJaxpr(pruned, jaxpr.consts), used_inputs
def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
                        ) -> tuple[list[bool], core.JaxprEqn | None]:
  """Dead-code-elimination rule for ``jit_p`` equations.

  Returns:
    The per-input usage mask, and either a new equation restricted to the
    used inputs/outputs or ``None`` when the equation can be dropped.
  """
  # No used outputs and no effects: drop the eqn and use none of its inputs.
  if not (any(used_outputs) or pe.has_effects(eqn)):
    return [False] * len(eqn.invars), None
  dced_jaxpr, used_inputs = _dce_jaxpr_pjit(
      eqn.params['jaxpr'], tuple(used_outputs))
  def select(seq, mask):
    return tuple(item for item, keep in zip(seq, mask) if keep)
  old_params = eqn.params
  new_params = dict(
      old_params,
      jaxpr=dced_jaxpr,
      in_shardings=select(old_params["in_shardings"], used_inputs),
      out_shardings=select(old_params["out_shardings"], used_outputs),
      in_layouts=select(old_params["in_layouts"], used_inputs),
      out_layouts=select(old_params["out_layouts"], used_outputs),
      donated_invars=select(old_params["donated_invars"], used_inputs),
  )
  if not (any(used_inputs) or any(used_outputs) or dced_jaxpr.effects):
    return used_inputs, None
  kept_invars = [v for v, used in zip(eqn.invars, used_inputs) if used]
  kept_outvars = [v for v, used in zip(eqn.outvars, used_outputs) if used]
  new_eqn = core.new_jaxpr_eqn(
      kept_invars, kept_outvars, eqn.primitive, new_params,
      core.eqn_effects(dced_jaxpr), eqn.source_info, eqn.ctx)
  return used_inputs, new_eqn
pe.dce_rules[jit_p] = dce_jaxpr_pjit_rule
def _pjit_pp_rule(eqn: core.JaxprEqn,
                  context: core.JaxprPpContext,
                  settings: core.JaxprPpSettings) -> core.pp.Doc:
  """Pretty-printing rule for ``jit_p`` equations.

  ``inline`` is always hidden. Other params are hidden when they hold their
  uninteresting default. ``name`` is printed first and the rest sorted.
  """
  params = dict(eqn.params)
  del params['inline']
  # (param name, predicate saying the value is a default that can be hidden),
  # checked in this order.
  omit_if_default = (
      ('donated_invars', lambda v: not any(v)),
      ('in_shardings',
       lambda v: all(isinstance(s, UnspecifiedValue) for s in v)),
      ('out_shardings',
       lambda v: all(isinstance(s, UnspecifiedValue) for s in v)),
      ('in_layouts', lambda v: all(l is None for l in v)),
      ('out_layouts', lambda v: all(l is None for l in v)),
      ('keep_unused', lambda v: not v),
      ('ctx_mesh', lambda v: v.empty),
      ('compiler_options_kvs', lambda v: not v),
  )
  for key, is_default in omit_if_default:
    if is_default(params[key]):
      del params[key]
  inner = params['jaxpr'].jaxpr
  if inner not in context.shared_jaxprs:
    context.suggest_same_var_names(inner.invars, eqn.invars)
    context.suggest_same_var_names(inner.outvars, eqn.outvars)
  # Print name= first so the equation is easy to scan.
  del params["name"]
  return core._pp_eqn(eqn, context, settings, params=["name", *sorted(params)])
core.pp_eqn_rules[jit_p] = _pjit_pp_rule
# -------------------- with_sharding_constraint --------------------
def check_shardings_are_auto(s: Sharding) -> None:
  """Reject a ``NamedSharding`` whose spec names a non-``Auto`` mesh axis.

  Non-``NamedSharding`` inputs are accepted unchecked. ``None`` and
  ``PartitionSpec.UNCONSTRAINED`` spec entries are skipped.

  Raises:
    ValueError: if any axis named in ``s.spec`` is not ``AxisType.Auto``.
  """
  if not isinstance(s, NamedSharding):
    return
  mesh = s.mesh.abstract_mesh
  for entry in s.spec:
    if entry is None or entry is PartitionSpec.UNCONSTRAINED:
      continue
    names = entry if isinstance(entry, tuple) else (entry,)
    if any(mesh._name_to_type[n] != mesh_lib.AxisType.Auto for n in names):
      raise ValueError(
          'The spec of NamedSharding passed to with_sharding_constraint can'
          f' only refer to Auto axes of the mesh. Got spec={s.spec} and'
          f' mesh={mesh}. You probably meant to use `reshard` API?')
def assert_shardings_equal(x_aval, user_sharding: NamedSharding):
  """Check that ``x_aval``'s sharding already matches ``user_sharding``.

  The user spec is normalized to ``x_aval.ndim``. When
  ``config.remove_size_one_mesh_axis_from_type`` is set, size-one mesh axes
  are also stripped from it. Dimensions where the user spec is
  ``PartitionSpec.UNCONSTRAINED`` are not compared.

  Raises:
    AssertionError: if any compared dimension differs.
  """
  x_spec = x_aval.sharding.spec
  user_spec = user_sharding.spec._normalized_spec_for_aval(x_aval.ndim)
  if config.remove_size_one_mesh_axis_from_type.value:
    user_spec = core.remove_size_one_mesh_axis(user_spec, user_sharding.mesh)
  mismatched = any(
      wanted is not PartitionSpec.UNCONSTRAINED and actual != wanted
      for actual, wanted in zip(x_spec, user_spec))
  if mismatched:
    raise AssertionError(
        '`with_sharding_constraint` acts as an assert when all axes of'
        f' mesh are of type `Explicit`. The array sharding: {x_spec} did'
        f' not match the sharding provided: {user_spec}. Please use'
        ' `jax.sharding.reshard` to shard your input to the sharding you'
        ' want.')
def with_sharding_constraint(x, shardings):
  """Mechanism to constrain the sharding of an Array inside a jitted computation.

  This is a strict constraint for the GSPMD partitioner and not a hint. For
  examples of how to use this function, see
  `Distributed arrays and automatic parallelization`_.

  Inside of a jitted computation, with_sharding_constraint makes it possible to
  constrain intermediate values to an uneven sharding. However, if such an
  unevenly sharded value is output by the jitted computation, it will come out
  as fully replicated, no matter the sharding annotation given.

  Some leaves are not constrained at all. This happens when every axis of the
  current abstract mesh is ``Explicit``, no layout is requested for the leaf,
  and its sharding is a ``NamedSharding``. For such a leaf the current
  sharding is only checked against the requested one and the leaf is returned
  unchanged. In all other cases, a ``NamedSharding`` may only name ``Auto``
  mesh axes.

  Args:
    x: PyTree of jax.Arrays which will have their shardings constrained
    shardings: PyTree of sharding specifications. Valid values are the same as
      for the ``in_shardings`` argument of :func:`jax.experimental.pjit`.
      Entries may also carry a layout, which ``_split_layout_and_sharding``
      separates from the sharding.

  Returns:
    x_with_shardings: PyTree of jax.Arrays with specified sharding constraints.

  Raises:
    ValueError: if a resolved sharding is unspecified or ``AUTO``, or if a
      ``NamedSharding`` refers to non-``Auto`` axes where that is not allowed.
    AssertionError: in the all-``Explicit`` case above, when the shardings
      differ.

  .. _Distributed arrays and automatic parallelization: https://docs.jax.dev/en/latest/parallel.html
  """
  x_flat, tree = tree_flatten(x)
  x_avals_flat = [core.shaped_abstractify(x) for x in x_flat]
  # Entries of `shardings` may carry layouts; split them apart.
  layouts, shardings = _split_layout_and_sharding(shardings)
  user_shardings = prepare_axis_resources(
      shardings, "shardings", allow_unconstrained_dims=True)
  del shardings
  user_shardings_flat = tuple(
      flatten_axes("with_sharding_constraint shardings", tree, user_shardings))
  del user_shardings
  user_layouts_flat = tuple(
      flatten_axes("with_sharding_constraint layouts", tree, layouts))
  del layouts
  # Pick the mesh used to resolve the shardings. Both of the first two branches
  # select the *abstract* mesh; the concrete mesh only decides which branch
  # runs. Otherwise fall back to thread_resources' physical mesh.
  # NOTE(review): presumably the mesh from the older `with mesh:` API — confirm.
  if not mesh_lib.get_concrete_mesh().empty:
    context_mesh = mesh_lib.get_abstract_mesh()
  elif not mesh_lib.get_abstract_mesh().empty:
    context_mesh = mesh_lib.get_abstract_mesh()
  else:
    context_mesh = mesh_lib.thread_resources.env.physical_mesh
  shardings_flat = [_create_sharding_for_array(context_mesh, a, 'shardings',
                                               'with_sharding_constraint')
                    for a in user_shardings_flat]
  # An unspecified or AUTO sharding is not a usable constraint; reject it.
  for s, u in zip(shardings_flat, user_shardings_flat):
    if isinstance(s, (UnspecifiedValue, AUTO)):
      raise ValueError(
          f'One of with_sharding_constraint arguments got sharding {u} which is'
          ' not allowed. Please only pass `jax.sharding.Sharding` instances.')
  del user_shardings_flat
  # TODO(bartchr): remove `unconstrained_dims` after migrating to Shardy. It's
  # already part of the shardings.
  unconstrained_dims = [get_unconstrained_dims(s)
                        if isinstance(s, NamedSharding) else frozenset()
                        for s in shardings_flat]
  pjit_check_aval_sharding(
      shardings_flat, x_avals_flat, ("",) * len(shardings_flat),
      "with_sharding_constraint arguments",
      allow_uneven_sharding=True)
  check_aval_layout_compatibility(user_layouts_flat, x_avals_flat,
                                  ("",) * len(user_layouts_flat),
                                  "with_sharding_constraint arguments")
  outs = []
  for xf, x_aval, s, l, ud in zip(x_flat, x_avals_flat, shardings_flat,
                                  user_layouts_flat, unconstrained_dims):
    # All-Explicit mesh, no layout, NamedSharding: only check that the shardings
    # agree, and pass `xf` through unchanged.
    if (mesh_lib.get_abstract_mesh().are_all_axes_explicit and l is None and
        isinstance(s, NamedSharding)):
      assert_shardings_equal(x_aval, s)
      outs.append(xf)
    else:
      check_shardings_are_auto(s)
      outs.append(sharding_constraint_p.bind(
          xf, sharding=s, layout=l, context_mesh=context_mesh,
          unconstrained_dims=ud))
  return tree_unflatten(tree, outs)
def _identity_fn(x): return x
def _sharding_constraint_impl(x, sharding, layout, context_mesh,
                              unconstrained_dims):
  """Eager (non-traced) implementation of ``sharding_constraint_p``.

  A ``NamedSharding`` over an ``AbstractMesh`` is first bound to a concrete
  mesh:

  * the current concrete mesh, when ``context_mesh`` is a non-empty
    ``AbstractMesh`` and ``x`` has no ``sharding`` attribute;
  * otherwise ``x``'s own mesh. ``x`` must then carry a ``NamedSharding``
    with a non-scalar mesh of the same shape.

  The result comes from jitting an identity function with the target sharding
  (wrapped in ``Format`` together with the layout, if one is given) as its
  output sharding.
  """
  abstract_target = (isinstance(sharding, NamedSharding) and
                     isinstance(sharding.mesh, AbstractMesh))
  if abstract_target:
    use_concrete_context = (not context_mesh.empty and
                            isinstance(context_mesh, AbstractMesh) and
                            not hasattr(x, 'sharding'))
    if use_concrete_context:
      concrete_mesh = mesh_lib.get_concrete_mesh()
      assert not concrete_mesh.empty
      sharding = NamedSharding(concrete_mesh, sharding.spec)
    else:
      aval = core.shaped_abstractify(x)
      if not hasattr(x, 'sharding'):
        raise ValueError(
            'Target sharding contains a `jax.sharding.AbstractMesh` which'
            ' requires the input passed should be a `jax.Array`. Got'
            f' {type(x)} with shape {aval.str_short()}')
      x_sharding = x.sharding
      if not isinstance(x_sharding, NamedSharding) or x_sharding.mesh.is_scalar: # pyrefly: ignore[missing-attribute]
        raise TypeError(
            'The sharding on the input must be a `NamedSharding` since the'
            ' target sharding has an `AbstractMesh` in it. Got sharding type'
            f' {type(x_sharding)} for shape {aval.str_short()}')
      if x_sharding.mesh.shape_tuple != sharding.mesh.shape_tuple:
        raise ValueError(
            f'Mesh shape of the input {x_sharding.mesh.shape_tuple} does not'
            ' match the mesh shape of the target sharding'
            f' {sharding.mesh.shape_tuple} for shape {aval.str_short()}')
      sharding = NamedSharding(x_sharding.mesh, sharding.spec)
  target = sharding if layout is None else Format(layout, sharding)
  # Run a jit here to raise good errors when device assignment don't match.
  return api.jit(_identity_fn, out_shardings=target)(x)
sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)
# Linear primitive: the transpose applies the same constraint to the cotangent.
ad.deflinear2(sharding_constraint_p,
              lambda ct, _, **params: (sharding_constraint_p.bind(ct, **params),))
def _sharding_constraint_abstract_eval(
    x_aval, *, sharding, layout, context_mesh, unconstrained_dims):
  """Abstract evaluation for ``sharding_constraint_p``.

  For a ``NamedSharding`` target, only the mesh of ``x_aval``'s sharding is
  replaced, with the target's abstract mesh. For any other sharding type, the
  aval's sharding is reset to ``None``.
  """
  if not isinstance(sharding, NamedSharding):
    return x_aval.update(sharding=None)
  target_mesh = sharding.mesh.abstract_mesh
  return x_aval.update(sharding=x_aval.sharding.update(mesh=target_mesh))
sharding_constraint_p.def_abstract_eval(_sharding_constraint_abstract_eval)
def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout,
                                      context_mesh, unconstrained_dims):
  """Lower ``sharding_constraint_p`` to a sharding op, plus an optional layout op.

  Steps:

  1. If the output aval's type already shards some dimensions (non-``None``
     spec entries), the user's ``NamedSharding`` is widened. Each such
     dimension becomes sharded over the type's axes followed by the user's
     axes for that dimension.
  2. Manual axes of an enclosing ``SPMDAxisContext`` are added.
  3. The sharding is emitted in Shardy or HLO-proto form, depending on
     ``config.use_shardy_partitioner``.

  Raises:
    NotImplementedError: if the user spec marks a dimension
      ``PartitionSpec.UNCONSTRAINED`` while the type already shards it.
  """
  in_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  axis_ctx = ctx.module_context.axis_context
  if (isinstance(sharding, NamedSharding) and
      any(o is not None for o in out_aval.sharding.spec)):
    spec = sharding.spec._normalized_spec_for_aval(in_aval.ndim)
    new_spec = []
    for dim, (user_spec, aval_spec) in enumerate(
        zip(spec, out_aval.sharding.spec)):
      if aval_spec is None:
        new_spec.append(user_spec)
      else:
        aval_spec = aval_spec if isinstance(aval_spec, tuple) else (aval_spec,)
        if user_spec is PartitionSpec.UNCONSTRAINED:
          raise NotImplementedError(
              '`with_sharding_constraint` does not support'
              f' `PartitionSpec.UNCONSTRAINED` on dimension {dim}, which is'
              f' already sharded over mesh axes {aval_spec} in the type of'
              f' the value. Got spec {spec}.')
        if user_spec is None:
          new_spec.append(aval_spec)
        elif isinstance(user_spec, tuple):
          new_spec.append(aval_spec + user_spec)
        else:
          new_spec.append(aval_spec + (user_spec,))
    sharding = sharding.update(spec=new_spec)
  # Extended dtypes lower through their physical representation.
  if dtypes.issubdtype(in_aval.dtype, dtypes.extended):
    in_aval = core.physical_aval(in_aval)
  if (isinstance(axis_ctx, sharding_impls.SPMDAxisContext) and
      axis_ctx.manual_axes):
    sharding = mlir.add_manual_axes(axis_ctx, sharding, in_aval.ndim)
  if config.use_shardy_partitioner.value:
    sharding = sharding._to_sdy_sharding(in_aval.ndim)
  else:
    sharding = sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
  out = mlir.wrap_with_sharding_op(
      ctx, x_node, out_aval, sharding, unspecified_dims=unconstrained_dims)
  if layout is not None:
    out = mlir.wrap_with_layout_op(ctx, out, out_aval, layout, in_aval)
  return [out]
mlir.register_lowering(sharding_constraint_p,
                       _sharding_constraint_hlo_lowering)
def _sharding_constraint_batcher(
    axis_data, vals_in, dims_in, sharding, layout, context_mesh,
    unconstrained_dims):
  """vmap rule for ``sharding_constraint_p``.

  The batch dimension ``d`` is inserted into three things before the
  primitive is re-bound:

  * the sharding, via ``_pjit_batcher_for_sharding``;
  * the layout, via ``get_layout_for_vmap``;
  * the set of unconstrained dimensions.

  Without an ``spmd_axis_name``, ``d`` itself is marked unconstrained. With
  one, those axis names must not already appear in the user's spec.

  Returns:
    ``(out, out_dim)``. ``out_dim`` is ``None`` for an unbatched operand and
    ``d`` otherwise.

  Raises:
    ValueError: if ``spmd_axis_name`` overlaps the axes used in the spec.
  """
  x, = vals_in
  d, = dims_in
  # Unbatched operand: re-bind unchanged.
  if d is None:
    out = sharding_constraint_p.bind(
        x, sharding=sharding, layout=layout, context_mesh=context_mesh,
        unconstrained_dims=unconstrained_dims)
    return out, None
  # The spmd_axis_name axes are handed to _pjit_batcher_for_sharding for the
  # new dimension, so the user's spec must not already use them.
  if axis_data.spmd_name is not None and isinstance(sharding, NamedSharding):
    used = {n for ns in sharding.spec
            for n in (ns if isinstance(ns, tuple) else (ns,))}
    if set(axis_data.spmd_name) & used:
      raise ValueError(f"vmap spmd_axis_name {axis_data.spmd_name} cannot appear in "
                       "with_sharding_constraint spec, but got spec "
                       f"{sharding.spec}")
  # Shift unconstrained dims at or after `d` by one to make room for the batch
  # dimension.
  unconstrained_dims = {ud + (d <= ud) for ud in unconstrained_dims}
  if axis_data.spmd_name is None:
    unconstrained_dims.add(d)
  vmapped_sharding = _pjit_batcher_for_sharding(
      sharding, d, axis_data.spmd_name, context_mesh, x.ndim)
  # Mirror the unconstrained dims in the NamedSharding's spec. First pad the
  # spec with None up to x.ndim.
  if unconstrained_dims and isinstance(vmapped_sharding, NamedSharding):
    new_spec = list(vmapped_sharding.spec) + [None] * (x.ndim - len(vmapped_sharding.spec))
    for u in unconstrained_dims:
      new_spec[u] = PartitionSpec.UNCONSTRAINED
    vmapped_sharding = NamedSharding(
        vmapped_sharding.mesh, PartitionSpec(*new_spec))
  vmapped_layout = (get_layout_for_vmap(d, layout) if layout is not None else
                    layout)
  y = sharding_constraint_p.bind(
      x,
      sharding=vmapped_sharding,
      layout=vmapped_layout,
      context_mesh=context_mesh,
      unconstrained_dims=frozenset(unconstrained_dims))
  return y, d
batching.fancy_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher
# -------------------- reshard ------------------------------------
def reshard(xs, out_shardings):
  """Reshard every leaf of the pytree ``xs`` to a target sharding.

  Args:
    xs: pytree of arrays.
    out_shardings: sharding specification(s). ``flatten_axis_resources``
      expands them against the structure of ``xs``, and each entry is then
      canonicalized with ``canonicalize_sharding``.

  Returns:
    A pytree with the structure of ``xs`` whose leaves are the outputs of
    ``reshard_p``.

  Raises:
    ValueError: if an entry canonicalizes to ``None``.
  """
  leaves, treedef = tree_flatten(xs)
  targets = flatten_axis_resources(
      "reshard out_shardings", treedef, out_shardings, tupled_args=True)
  avals = [core.shaped_abstractify(leaf) for leaf in leaves]
  results = []
  for leaf, aval, target in safe_zip(leaves, avals, targets):
    dst = canonicalize_sharding(target, 'reshard', check_mesh_consistency=False)
    if dst is None:
      raise ValueError(
          'Reshard should only be used with out_shardings which are non-None '
          f'and have a non-empty mesh. Got sharding {target}.'
      )
    dst = dst.update(spec=dst.spec._normalized_spec_for_aval(aval.ndim))
    # Record the caller's concrete mesh, if any, for the eager implementation.
    has_concrete_mesh = (isinstance(target, NamedSharding) and
                         isinstance(target.mesh, mesh_lib.Mesh))
    results.append(reshard_p.bind(
        leaf, dst_sharding=dst,
        concrete_mesh=target.mesh if has_concrete_mesh else None))
  return tree_unflatten(treedef, results)
reshard_p = core.Primitive('reshard')
reshard_p.skip_canonicalization = True
def _reshard_abstract_eval(aval, *, dst_sharding, concrete_mesh):
  """Output aval is ``aval`` with sharding ``dst_sharding``.

  The input aval object itself is returned when its sharding already matches.
  """
  assert isinstance(aval, core.ShapedArray)
  unchanged = aval.sharding == dst_sharding
  return aval if unchanged else aval.update(sharding=dst_sharding)
reshard_p.def_abstract_eval(_reshard_abstract_eval)
def _reshard_impl(x, *, dst_sharding, concrete_mesh):
  """Eager implementation of ``reshard_p`` via ``dispatch.apply_primitive``.

  When a concrete mesh was recorded at bind time, it is installed with
  ``sharding_impls.set_mesh`` for the duration of the dispatch.
  """
  mesh_ctx = (contextlib.nullcontext() if concrete_mesh is None
              else sharding_impls.set_mesh(concrete_mesh))
  with mesh_ctx:
    return dispatch.apply_primitive(
        reshard_p, x, dst_sharding=dst_sharding, concrete_mesh=concrete_mesh)
reshard_p.def_impl(_reshard_impl)
def _reshard_jvp_rule(primals, tangents, *, dst_sharding, concrete_mesh):
  """JVP rule for ``reshard_p``.

  The primal and the tangent are resharded to the same target. A symbolic
  zero tangent stays a zero, built with ``ad.p2tz`` from the primal output.
  """
  (p,), (t,) = primals, tangents
  rebind = partial(reshard_p.bind, dst_sharding=dst_sharding,
                   concrete_mesh=concrete_mesh)
  primal_out = rebind(p)
  tangent_out = ad.p2tz(primal_out) if type(t) is ad.Zero else rebind(t)
  return primal_out, tangent_out
ad.primitive_jvps[reshard_p] = _reshard_jvp_rule
def _reshard_transpose_fancy(ct, x, *, dst_sharding, concrete_mesh):
  """Transpose rule for ``reshard_p``.

  The cotangent is resharded back to the sharding of ``x``'s cotangent aval
  and accumulated into ``x``. Symbolic-zero cotangents and ``ad.NullAccum``
  targets are skipped.
  """
  assert isinstance(x, ad.GradAccum)
  nothing_to_do = type(ct) is ad.Zero or isinstance(x, ad.NullAccum)
  if not nothing_to_do:
    target = x.aval.to_ct_aval().sharding  # pyrefly: ignore[missing-attribute]
    with mesh_lib.use_abstract_mesh(target.mesh):
      resharded_ct = reshard_p.bind(ct, dst_sharding=target,
                                    concrete_mesh=concrete_mesh)
    x.accum(resharded_ct)
ad.fancy_transposes[reshard_p] = _reshard_transpose_fancy
def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding, concrete_mesh):
  """Lower ``reshard_p`` with ``mlir.lower_with_sharding_in_types``.

  Extended dtypes use their physical aval's rank. The sharding is emitted in
  Shardy form or as an HLO sharding proto, depending on
  ``config.use_shardy_partitioner``.
  """
  aval_in, = ctx.avals_in
  aval_out, = ctx.avals_out
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    aval_in = core.physical_aval(aval_in)
  rank = aval_in.ndim
  if config.use_shardy_partitioner.value:
    proto = dst_sharding._to_sdy_sharding(rank)
  else:
    proto = dst_sharding._to_xla_hlo_sharding(rank).to_proto()
  return [mlir.lower_with_sharding_in_types(ctx, x_node, aval_out, proto)]
mlir.register_lowering(reshard_p, _reshard_hlo_lowering)
def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding, concrete_mesh):
  """vmap rule for ``reshard_p``.

  For a batched operand, the target sharding is extended for the batch
  dimension with ``batching.get_sharding_for_vmap``. An unbatched operand is
  re-bound unchanged.
  """
  (x,), (d,) = vals_in, dims_in
  if d is None:
    return reshard_p.bind(x, dst_sharding=dst_sharding,
                          concrete_mesh=concrete_mesh), None
  batched_sharding = batching.get_sharding_for_vmap(axis_data, dst_sharding, d)
  return reshard_p.bind(x, dst_sharding=batched_sharding,
                        concrete_mesh=concrete_mesh), d
batching.fancy_primitive_batchers[reshard_p] = _reshard_batcher
def _pp_reshard(eqn, ctx, settings):
  """Pretty-print ``reshard`` equations with all params hidden."""
  without_params = eqn.replace(params={})
  return core._pp_eqn(without_params, ctx, settings)
core.pp_eqn_rules[reshard_p] = _pp_reshard
# -------------------- Auto and Explicit mode -------------------------
@dataclass(frozen=True, kw_only=True)
class MeshInfo:
  """Mesh before and after ``_get_new_mesh`` retypes some of its axes."""
  prev: AbstractMesh  # mesh whose axis types are being changed
  new: AbstractMesh  # `prev` with each axis in `axes` given the requested AxisType
  axes: Any  # the mesh axis names that were retyped
def _get_new_mesh(axes: str | tuple[str, ...] | None,
                  axis_type: mesh_lib.AxisType, name: str, shardings=None
                  ) -> MeshInfo | None:
  """Compute the mesh obtained by giving ``axes`` the type ``axis_type``.

  The base mesh is the current abstract mesh. If that is empty, it is the
  common mesh of the ``NamedSharding`` leaves of ``shardings``. When both are
  present they must be equal.

  Args:
    axes: axis name or tuple of names to retype; ``None`` means all axes.
    axis_type: the ``AxisType`` to give those axes.
    name: API name used in error messages.
    shardings: optional pytree whose ``NamedSharding`` leaves supply a mesh.

  Returns:
    A ``MeshInfo``, or ``None`` if neither the context nor ``shardings``
    provides a mesh.

  Raises:
    ValueError: if the sharding meshes disagree with each other or with the
      context mesh.
    NotImplementedError: when asked to turn a ``Manual`` axis into ``Auto`` or
      ``Explicit``.
  """
  context_mesh = mesh_lib.get_abstract_mesh()
  leaves, _ = tree_flatten(shardings)
  sharding_mesh = mesh_lib.empty_abstract_mesh
  for leaf in leaves:
    if not isinstance(leaf, NamedSharding):
      continue
    leaf_mesh = leaf.mesh.abstract_mesh
    if not sharding_mesh.empty and sharding_mesh != leaf_mesh:
      raise ValueError(
          f'Shardings passed to {name} should have the same mesh. Got one'
          f' mesh {sharding_mesh} and another {leaf.mesh}')
    sharding_mesh = leaf_mesh
  if sharding_mesh.empty:
    if context_mesh.empty:
      return None
    base_mesh = context_mesh
  elif context_mesh.empty:
    base_mesh = sharding_mesh
  elif sharding_mesh != context_mesh:
    raise ValueError(
        f'Context mesh {context_mesh} must match the mesh passed to shardings'
        f' {sharding_mesh}. Recommended approach is to use'
        ' `jax.set_mesh` context manager.')
  else:
    base_mesh = context_mesh
  axes = base_mesh.axis_names if axes is None else axes
  axes = axes if isinstance(axes, tuple) else (axes,)
  leaving_manual = {mesh_lib.AxisType.Auto, mesh_lib.AxisType.Explicit}
  for axis_name in axes:
    if (base_mesh._name_to_type[axis_name] == mesh_lib.AxisType.Manual and
        axis_type in leaving_manual):
      raise NotImplementedError(
          'Going from `Manual` AxisType to `Auto` or `Explicit` AxisType is not'
          ' allowed. Please file a bug at https://github.com/jax-ml/jax/issues'
          ' with your use case')
  new_mesh = base_mesh.update_axis_types(dict.fromkeys(axes, axis_type))
  return MeshInfo(prev=base_mesh, new=new_mesh, axes=axes)
def auto_axes(f=None, /, *, axes: str | tuple[str, ...] | None = None,
              out_sharding=None):
  """Run ``f`` with mesh ``axes`` switched to ``AxisType.Auto``.

  Works either as ``auto_axes(f, ...)`` or as a decorator factory
  (``@auto_axes(axes=...)``). Unless there is no mesh, or the requested axes
  are already exactly the Auto axes, the following happens:

  * the inputs are resharded onto the retyped mesh;
  * ``f`` is called under that mesh;
  * the outputs are resharded to ``out_sharding``.

  If ``out_sharding`` is not given here, each call must pass it as an
  ``out_sharding=`` keyword. That keyword is consumed and not forwarded to
  ``f``.
  """
  def wrap(g):
    return _auto_axes(g, axes_=axes, out_sharding=out_sharding)
  return wrap if f is None else wrap(f)
def _auto_axes(fun, *, axes_, out_sharding):
  @wraps(fun)
  def decorator(*args, **kwargs):
    if out_sharding is not None:
      dst_sharding = out_sharding
    elif "out_sharding" in kwargs:
      dst_sharding = kwargs.pop("out_sharding")
    else:
      raise TypeError("Missing required keyword argument: 'out_sharding'")
    mesh_info = _get_new_mesh(
        axes_, mesh_lib.AxisType.Auto, 'auto_axes', shardings=dst_sharding)
    # Nothing to retype: there is no mesh, or the requested axes are already
    # exactly the Auto axes.
    if (mesh_info is None or
        set(mesh_info.prev.auto_axes) == set(mesh_info.axes)):
      return fun(*args, **kwargs)
    with mesh_lib.use_abstract_mesh(mesh_info.new):
      def spec_on_new_mesh(a):
        return core.modify_spec_for_auto_manual(
            core.typeof(a).sharding.spec, mesh_info.new)
      args = reshard(args, tree_map(spec_on_new_mesh, args))
      out = fun(*args, **kwargs)
    return reshard(out, dst_sharding)
  return decorator
def explicit_axes(f=None, /, *, axes: str | tuple[str, ...] | None = None,
                  in_sharding=None):
  """Run ``f`` with mesh ``axes`` switched to ``AxisType.Explicit``.

  Works either as ``explicit_axes(f, ...)`` or as a decorator factory. The
  steps are:

  * the inputs are resharded to ``in_sharding`` under the retyped mesh;
  * ``f`` is called there;
  * the outputs' specs are mapped back onto the caller's mesh.

  If ``in_sharding`` is not given here, each call must pass it as an
  ``in_sharding=`` keyword. That keyword is consumed and not forwarded to
  ``f``. A non-empty context mesh is required.
  """
  def wrap(g):
    return _explicit_axes(g, axes=axes, in_sharding=in_sharding)
  return wrap if f is None else wrap(f)
def _explicit_axes(fun, *, axes, in_sharding):
  @wraps(fun)
  def decorator(*args, **kwargs):
    if in_sharding is not None:
      src_sharding = in_sharding
    elif "in_sharding" in kwargs:
      src_sharding = kwargs.pop("in_sharding")
    else:
      raise TypeError("Missing required keyword argument: 'in_sharding'")
    mesh_info = _get_new_mesh(axes, mesh_lib.AxisType.Explicit, 'explicit_axes')
    if mesh_info is None:
      raise ValueError(
          'Context mesh cannot be empty. Please use `jax.set_mesh` API to enter'
          ' into a mesh context when using `explicit_axes` API.')
    with mesh_lib.use_abstract_mesh(mesh_info.new):
      args = reshard(args, src_sharding)
      out = fun(*args, **kwargs)
    # Back under the caller's mesh: express the outputs' specs on it.
    def spec_on_outer_mesh(o):
      return core.modify_spec_for_auto_manual(
          core.typeof(o).sharding.spec, mesh_lib.get_abstract_mesh())
    return reshard(out, tree_map(spec_on_outer_mesh, out))
  return decorator
# -------------------- with_layout_constraint --------------------
def with_layout_constraint(x, layouts):
  """Constrain the layout of each leaf of the pytree ``x``.

  Args:
    x: pytree of arrays.
    layouts: ``Layout`` or pytree of ``Layout``s, expanded against ``x`` by
      ``flatten_axes``.

  Returns:
    A pytree with the structure of ``x`` whose leaves are the outputs of
    ``layout_constraint_p``.

  Raises:
    ValueError: if any expanded entry is not a ``Layout``.
  """
  leaves, tree = tree_flatten(x)
  avals = [core.shaped_abstractify(leaf) for leaf in leaves]
  layouts_flat = tuple(flatten_axes("with_layout_constraint layouts", tree,
                                    layouts))
  if not all(isinstance(l, Layout) for l in layouts_flat):
    raise ValueError(
        'layouts passed to `with_layout_constraint` must be of type'
        f' `Layout`. Got {[type(l) for l in layouts_flat]}')
  check_aval_layout_compatibility(
      layouts_flat, avals, ("",) * len(layouts_flat),
      "with_layout_constraint arguments")
  constrained = [layout_constraint_p.bind(leaf, layout=l)
                 for leaf, l in zip(leaves, layouts_flat)]
  return tree_unflatten(tree, constrained)
layout_constraint_p = core.Primitive('layout_constraint')
# Abstractly the constraint is the identity: the output aval is the input aval.
layout_constraint_p.def_abstract_eval(lambda x, **_: x)
# Linear primitive: the transpose applies the same layout constraint to the
# cotangent.
ad.deflinear2(layout_constraint_p,
              lambda ct, _, **params: (layout_constraint_p.bind(ct, **params),))
def _layout_constraint_impl(x, *, layout):
  """Eager implementation of ``layout_constraint_p``.

  ``x`` is returned as-is when it already has ``layout``. Otherwise it is
  passed through a jitted identity whose output format is ``layout`` with
  ``x``'s sharding.

  Raises:
    ValueError: if ``x`` is not an ``xc.ArrayImpl``.
  """
  if isinstance(x, xc.ArrayImpl):
    if x.format.layout == layout:
      return x
    return api.jit(_identity_fn, out_shardings=Format(layout, x.sharding))(x)
  raise ValueError(
      'with_layout_constraint in eager mode can only be applied to'
      f' jax.Arrays. Got {type(x)}')
layout_constraint_p.def_impl(_layout_constraint_impl)
def _layout_constraint_hlo_lowering(ctx, x_node, *, layout):
  """Lower ``layout_constraint_p`` with ``mlir.wrap_with_layout_op``."""
  (in_aval,), (out_aval,) = ctx.avals_in, ctx.avals_out
  return [mlir.wrap_with_layout_op(ctx, x_node, out_aval, layout, in_aval)]
mlir.register_lowering(layout_constraint_p,
                       _layout_constraint_hlo_lowering)
def _layout_constraint_batcher(axis_data, vals_in, dims_in, layout):
  """vmap rule for ``layout_constraint_p``.

  For a batched operand, the layout is extended for the batch dimension with
  ``get_layout_for_vmap``. The output batch dim equals the input batch dim.
  """
  (x,), (d,) = vals_in, dims_in
  new_layout = layout if d is None else get_layout_for_vmap(d, layout)
  return layout_constraint_p.bind(x, layout=new_layout), d
batching.fancy_primitive_batchers[layout_constraint_p] = _layout_constraint_batcher
# -------------------- helpers --------------------
def get_unconstrained_dims(sharding: NamedSharding):
  """Return the indices of ``sharding.spec`` entries equal to ``UNCONSTRAINED``.

  The result is a ``frozenset`` of the dimension indices whose spec entry is
  ``PartitionSpec.UNCONSTRAINED``.
  """
  spec = sharding.spec
  assert spec is not None
  unconstrained = PartitionSpec.UNCONSTRAINED
  return frozenset(dim for dim, entry in enumerate(spec)
                   if entry is unconstrained)