# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This module contains utility functions split out of jax._src.lax.lax to
# avoid cyclic dependencies. Definitions that are used at import time by
# multiple modules can go here.

from functools import partial
from typing import cast

import numpy as np

from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import mesh as mesh_lib
from jax._src import state
from jax._src.named_sharding import DuplicateSpecError, NamedSharding
from jax._src.partition_spec import PartitionSpec as P
from jax._src.util import safe_zip
from jax._src.typing import DimSize, DType, Shape

# Shadow the builtin zip with a length-checking variant; the builtin remains
# available as unsafe_zip.
zip, unsafe_zip = safe_zip, zip


def input_dtype(x, *_, out_dtype=None, **__):
  if out_dtype is not None:
    return dtypes.canonicalize_dtype(out_dtype)
  return x.dtype


def _argnum_weak_type(*argnums):
  return lambda *args, **_: all(args[i].weak_type for i in argnums)


def standard_primitive(shape_rule, dtype_rule, name, weak_type_rule=None,
                       sharding_rule=None, vma_rule=None, ur_rule=None,
                       memory_space_rule=None):
  weak_type_rule = weak_type_rule or _standard_weak_type_rule
  prim = core.Primitive(name)
  prim.def_impl(partial(dispatch.apply_primitive, prim))
  prim.def_abstract_eval(
      partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
              weak_type_rule, sharding_rule, vma_rule, ur_rule,
              memory_space_rule))
  return prim
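

# Illustrative sketch (hypothetical names, not part of this module): a
# minimal unary elementwise primitive could be registered through
# `standard_primitive` roughly as follows; real callers live in
# jax._src.lax.lax.
#
#   def _toy_shape_rule(x, **kwargs):
#     return x.shape            # elementwise: output shape matches input
#
#   def _toy_dtype_rule(x, **kwargs):
#     return x.dtype            # elementwise: output dtype matches input
#
#   toy_p = standard_primitive(_toy_shape_rule, _toy_dtype_rule, 'toy')
#
# `toy_p.bind(x)` then evaluates through `standard_abstract_eval` below,
# which composes the shape, dtype, sharding, and memory-space rules.

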
def _get_array_abstraction_level(a):
  return a.array_abstraction_level


def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
  m = None
  for a in in_avals:
    if a is core.abstract_token:
      continue
    if a.sharding.mesh.empty:
      continue
    if m is not None and m != a.sharding.mesh:
      if m.are_all_axes_auto and a.sharding.mesh.are_all_axes_auto:
        return mesh_lib.empty_abstract_mesh
      raise ValueError(
          f'Mesh for all inputs should be equal. Got one mesh: {m} and'
          f' another mesh: {a.sharding.mesh}')
    m = a.sharding.mesh
  return mesh_lib.empty_abstract_mesh if m is None else m


def call_ur_rule(prim, ur_rule, out_s, num_out, *avals, **kwargs):
  if ur_rule is not None:
    return ur_rule(*avals, **kwargs)
  if any(a.sharding.spec.unreduced or a.sharding.spec.reduced for a in avals):
    raise NotImplementedError(
        f'unreduced/reduced rule for {prim.name} is not implemented. Please'
        ' file an issue at https://github.com/jax-ml/jax/issues')
  # Only handles explicit mode. No need to handle manual mode here.
  if any(s.spec.unreduced or s.spec.reduced
         for s in ([out_s] if num_out is None else out_s) if s is not None):
    raise NotImplementedError(
        f'unreduced/reduced rule for {prim.name} is not implemented. Please'
        ' file an issue at https://github.com/jax-ml/jax/issues')
  return ((frozenset(), frozenset()) if num_out is None else
          ([frozenset()] * num_out, [frozenset()] * num_out))


def call_sharding_rule(prim, sh_rule, ur_rule, num_out, *avals, **kwargs):
  cur_mesh = mesh_lib.get_abstract_mesh()
  aval_mesh = _get_abstract_mesh_from_avals(avals)
  if ((cur_mesh.empty or cur_mesh._are_all_axes_auto_or_manual) and
      (aval_mesh.empty or aval_mesh._are_all_axes_auto_or_manual)):
    aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh
    out_s = NamedSharding(aval_mesh, P())
    return out_s if num_out is None else [out_s] * num_out
  if sh_rule is None:
    raise core.ShardingTypeError(
        f'sharding rule for {prim.name} is not implemented. Please file an'
        ' issue at https://github.com/jax-ml/jax/issues. You can work around'
        ' this error by dropping that operation into full auto sharding'
        ' mode via: `jax.sharding.auto_axes(fun, out_shardings=...)`')
  out_s = sh_rule(*avals, **kwargs)
  unreduced, reduced = call_ur_rule(prim, ur_rule, out_s, num_out, *avals,
                                    **kwargs)
  up = lambda sh, u, r: sh.update(spec=sh.spec.update(unreduced=u, reduced=r))
  return (up(out_s, unreduced, reduced) if num_out is None else
          [up(s, u, r) for s, u, r in zip(out_s, unreduced, reduced)])


def call_shape_dtype_sharding_rule(
    prim, shape_rule, dtype_rule, sharding_rule, ur_rule, multi_out,
    *avals, **kwargs):
  out_shapes = shape_rule(*avals, **kwargs)
  out_dtypes = dtype_rule(*avals, **kwargs)
  num_out = len(out_shapes) if multi_out else None
  try:
    out_shardings = call_sharding_rule(prim, sharding_rule, ur_rule, num_out,
                                       *avals, **kwargs)
  except DuplicateSpecError as e:
    if multi_out:
      raise
    avals_str = ', '.join(i.str_short(short_dtypes=True) for i in avals)
    mesh = mesh_lib.empty_abstract_mesh if e.mesh is None else e.mesh
    out_aval_str = core.str_short_aval(
        out_shapes, out_dtypes, mesh, e.pspec, core.empty_mat,
        core.MemorySpace.Device, short_dtypes=True)
    raise core.ShardingTypeError(
        f'{prim} operation with inputs: {avals_str} produces an illegally'
        f' sharded result: {out_aval_str}') from e
  return out_shapes, out_dtypes, out_shardings


def _default_memory_space_rule(prim, *avals, **kwargs):
  if all(a.memory_space == core.MemorySpace.Any for a in avals):
    return core.MemorySpace.Any
  prev_aval = None
  for a in avals:
    if not a.ndim:
      continue
    if prev_aval is not None and prev_aval.memory_space != a.memory_space:
      raise ValueError(
          f'memory_space of all inputs passed to `{prim.name}` must be the'
          f' same. Got one operand with type: {prev_aval.str_short()} and'
          f' another operand with type: {a.str_short()}')
    prev_aval = a
  if prev_aval is None:
    return core.MemorySpace.Device
  return prev_aval.memory_space


def multi_mem_space_rule(prim, num_out, *avals, **kwargs):
  out_mem_space = _default_memory_space_rule(prim, *avals, **kwargs)
  return [out_mem_space] * num_out
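

# Illustrative sketch (hypothetical rule, not defined here): for a unary
# elementwise primitive, a sharding rule can simply forward the operand's
# sharding, since the output layout matches the input layout:
#
#   def _toy_sharding_rule(x, **kwargs):
#     return x.sharding
#
# `call_sharding_rule` then layers the unreduced/reduced handling from
# `call_ur_rule` on top of whatever spec the rule returns.

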
def manual_rule(prim, vma_rule, ur_rule, multi_out, *avals, **kwargs):
  out_vma = vma_rule(*avals, **kwargs)
  num_out = len(out_vma) if multi_out else None
  if mesh_lib.get_abstract_mesh().are_all_axes_manual:
    out_s = None if num_out is None else [None] * num_out
    out_unreduced, out_reduced = call_ur_rule(
        prim, ur_rule, out_s, num_out, *avals, **kwargs)
  else:
    # TODO(yashkatariya): Handle partial manual unreduced/reduced.
    out_unreduced, out_reduced = (
        (frozenset(), frozenset()) if num_out is None else
        ([frozenset()] * num_out, [frozenset()] * num_out))
  if num_out is None:
    return core.ManualAxisType(varying=out_vma, unreduced=out_unreduced,
                               reduced=out_reduced)
  else:
    return [core.ManualAxisType(varying=v, unreduced=u, reduced=r)
            for v, u, r in zip(out_vma, out_unreduced, out_reduced)]


def standard_abstract_eval(
    prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, vma_rule,
    ur_rule, memory_space_rule, *avals, **kwargs):
  assert not prim.multiple_results
  for a in avals:
    if isinstance(a, state.AbstractRef):
      raise ValueError(f'Attempting to pass a Ref {a} to a primitive: '
                       f'{prim} -- did you forget to unpack ([...]) the ref?')
    if not isinstance(a, core.ShapedArray):
      raise ValueError(f'Attempting to pass an unexpected type {a} to a '
                       f'primitive: {prim}')
  weak_type = weak_type_rule(*avals, **kwargs)
  least_specialized = type(max(avals, key=_get_array_abstraction_level))
  if least_specialized is core.ShapedArray:
    core.check_avals_context_mesh(avals, prim.name)
    out_shape, out_dtype, out_sharding = call_shape_dtype_sharding_rule(
        prim, shape_rule, dtype_rule, sharding_rule, ur_rule, False,
        *avals, **kwargs)
    out_mat = manual_rule(prim, vma_rule, ur_rule, False, *avals, **kwargs)
    out_mem_space = (_default_memory_space_rule(prim, *avals, **kwargs)
                     if memory_space_rule is None else
                     memory_space_rule(*avals, **kwargs))
    out_aval = core.ShapedArray(
        out_shape, out_dtype, weak_type=weak_type, sharding=out_sharding,
        manual_axis_type=out_mat, memory_space=out_mem_space)
    core.check_avals_context_mesh([out_aval], prim.name)
    return out_aval
  else:
    raise TypeError(avals, least_specialized)


def standard_multi_result_abstract_eval(
    prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, vma_rule,
    ur_rule, *avals, **kwargs):
  assert prim.multiple_results
  assert all(isinstance(aval, core.ShapedArray) for aval in avals), avals
  least_specialized = max(map(type, avals), key=_get_array_abstraction_level)
  weak_types = weak_type_rule(*avals, **kwargs)
  if least_specialized is core.ShapedArray:
    core.check_avals_context_mesh(avals, prim.name)
    out_shapes, out_dtypes, out_shardings = call_shape_dtype_sharding_rule(
        prim, shape_rule, dtype_rule, sharding_rule, ur_rule, True,
        *avals, **kwargs)
    out_mats = manual_rule(prim, vma_rule, ur_rule, True, *avals, **kwargs)
    out_mem_spaces = multi_mem_space_rule(prim, len(out_shapes), *avals,
                                          **kwargs)
    if isinstance(weak_types, bool):
      weak_types = (weak_types,) * len(out_shapes)
    out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh,
                                  manual_axis_type=mat, memory_space=ms)
                 for s, d, weak_type, sh, mat, ms in zip(
                     out_shapes, out_dtypes, weak_types, out_shardings,
                     out_mats, out_mem_spaces)]
    core.check_avals_context_mesh(out_avals, prim.name)
    return out_avals
  else:
    raise TypeError(avals, least_specialized)


def _standard_weak_type_rule(*avals, **kwargs):
  return all(aval.weak_type for aval in avals)


def dtype_to_string(dtype):
  try:
    return str(np.dtype(dtype).name)
  except TypeError:
    pass
  try:
    return dtype.name
  except AttributeError:
    pass
  return str(dtype)


_int32_max = np.iinfo(np.int32).max
_uint32_max = np.iinfo(np.uint32).max


def int_dtype_for_dim(d: DimSize, *, signed: bool) -> DType:
  """Returns an integer dtype large enough to contain indices in dimension d."""
  if signed:
    if not core.is_constant_dim(d):
      return dtypes.default_int_dtype()
    return np.dtype(np.int64) if d > _int32_max else np.dtype(np.int32)
  else:
    if not core.is_constant_dim(d):
      return dtypes.default_uint_dtype()
    return np.dtype(np.uint64) if d > _uint32_max else np.dtype(np.uint32)
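

# Illustrative behavior (the dimensions below are constant, so the result
# depends only on their magnitude):
#
#   int_dtype_for_dim(7, signed=True)        # -> dtype('int32')
#   int_dtype_for_dim(2**40, signed=True)    # -> dtype('int64')
#   int_dtype_for_dim(2**40, signed=False)   # -> dtype('uint64')
#
# Symbolic (non-constant) dimensions fall back to the configured default
# integer dtype instead.

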
def int_dtype_for_shape(shape: Shape, *, signed: bool) -> DType:
  """Returns an integer dtype large enough to contain indices in `shape`."""
  if signed:
    for d in shape:
      if core.is_constant_dim(d):
        if d > _int32_max:
          return np.dtype(np.int64)
      else:
        return dtypes.default_int_dtype()
    return np.dtype(np.int32)
  else:
    for d in shape:
      if core.is_constant_dim(d):
        if d > _uint32_max:
          return np.dtype(np.uint64)
      else:
        return dtypes.default_uint_dtype()
    return np.dtype(np.uint32)


def ensure_shaped(*avals: core.AbstractValue
                  ) -> tuple[core.ShapedArray | state.AbstractRef, ...]:
  """Cast all inputs to ShapedArray with a runtime instance check."""
  if any(not isinstance(aval, (core.ShapedArray, state.AbstractRef))
         for aval in avals):
    raise ValueError(
        f"Expected ShapedArray; got {[type(aval) for aval in avals]}")
  return tuple(cast(core.ShapedArray | state.AbstractRef, aval)
               for aval in avals)
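

# Illustrative behavior:
#
#   int_dtype_for_shape((8, 2**40), signed=True)   # -> dtype('int64')
#   int_dtype_for_shape((8, 16), signed=False)     # -> dtype('uint32')
#
# `ensure_shaped` is a typed pass-through: it returns its inputs unchanged
# but raises if any of them is not a ShapedArray or AbstractRef.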