This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,907 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Serialization and deserialization of _export.Exported
from __future__ import annotations
from collections.abc import Callable, Iterable
import dataclasses
import itertools
from functools import partial
import types
from typing import cast, Any, TypeVar
try:
import flatbuffers
except ImportError as e:
raise ImportError(
"Please install 'flatbuffers' in order to use Exported serialization"
) from e
from jax._src import core
from jax._src import dtypes
from jax._src import effects
from jax._src.export import serialization_generated as ser_flatbuf
from jax._src.export import _export
from jax._src.export import shape_poly
from jax._src.lib import xla_client
from jax._src import mesh
from jax._src import named_sharding
from jax._src import partition_spec
from jax._src import tree_util
import numpy as np
T = TypeVar("T")
SerT = TypeVar("SerT")
# The _SERIALIZATION_VERSION changes when we change the serialization schema
# even if the change is backwards compatible.
# Version 1, Nov 2023, first version.
# Version 2, Dec 16th, 2023, adds the f0 dtype.
# Version 3, October 16th, 2024, adds serialization for namedtuple and custom types
#   This version is backwards compatible with Version 2.
# Version 4, April 7th, 2025, adds serialization for PRNGs key types.
#   This version is backwards compatible with Version 2 and 3.
# Version 5, November 23rd, 2025, adds serialization for aval memory_space,
#   upgrade num_devices to a 32 bit value.
#   This version is backwards compatible with Version 2 to 4.
# Version 6, January 15th, 2026, adds serialization for sharding as
#   NamedSharding, including the abstract mesh, and the partition spec.
#   Contains also HloSharding serialization, for forward compatibility.
#   This version is backwards compatible with Version 2 to 5.
# Version 7 was briefly live but pulled back due to breaking compatibility.
# Version 8, March 12th, 2026, add serialization for AbstractMesh.abstract_device.
#   This version is backwards compatible with Version 2 to 7.
# Version 9, March 17th, 2026, removes HloSharding serialization.
#   This is another attempt at what Version 7 was supposed to be.
#   This version is backwards compatible with Version 2 to 8.
# Version 10, April 4th, 2026, optimizes serialization of duplicate shardings,
#   abstract meshes and avals.
_SERIALIZATION_VERSION = 10
@dataclasses.dataclass
class _SerializedUniques:
# Map unique objects to their index in the serialized data.
unique_avals: list[core.AbstractValue]
avals_map: dict[core.AbstractValue, int]
unique_abstract_meshes: list[mesh.AbstractMesh]
abstract_meshes_map: dict[mesh.AbstractMesh, int]
unique_named_shardings: list[named_sharding.NamedSharding]
named_shardings_map: dict[named_sharding.NamedSharding, int]
@staticmethod
def create_from_exported(exp: _export.Exported):
uniques = _SerializedUniques([], {}, [], {}, [], {})
for aval in itertools.chain(exp.in_avals, exp.out_avals):
uniques.add_aval(aval)
for sharding in itertools.chain(exp._in_named_shardings,
exp._out_named_shardings):
uniques.add_named_sharding(sharding)
return uniques
@staticmethod
def create_from_uniques(unique_avals: list[core.AbstractValue],
unique_abstract_meshes: list[mesh.AbstractMesh],
unique_named_shardings: list[named_sharding.NamedSharding]):
uniques = _SerializedUniques([], {}, [], {}, [], {})
uniques.unique_avals = unique_avals
uniques.avals_map = {a: i for i, a in enumerate(unique_avals)}
uniques.unique_abstract_meshes = unique_abstract_meshes
uniques.abstract_meshes_map = {m: i for i, m in enumerate(unique_abstract_meshes)}
uniques.unique_named_shardings = unique_named_shardings
uniques.named_shardings_map = {s: i for i, s in enumerate(unique_named_shardings)}
return uniques
def add_aval(self, aval: core.AbstractValue):
if aval not in self.avals_map:
self.avals_map[aval] = len(self.unique_avals)
self.unique_avals.append(aval)
def add_named_sharding(self, sharding: named_sharding.NamedSharding | None):
if sharding is None:
return
amesh = sharding.mesh.abstract_mesh
if amesh is not None and amesh not in self.abstract_meshes_map:
self.abstract_meshes_map[amesh] = len(self.unique_abstract_meshes)
self.unique_abstract_meshes.append(amesh)
if sharding not in self.named_shardings_map:
self.named_shardings_map[sharding] = len(self.unique_named_shardings)
self.unique_named_shardings.append(sharding)
def serialize(exp: _export.Exported, vjp_order: int = 0) -> bytearray:
  """Serializes an Exported.

  Args:
    exp: the Exported to serialize.
    vjp_order: The maximum vjp order to include. E.g., the value 2 means that we
      serialize the primal functions and two orders of the `vjp` function. This
      should allow 2nd order reverse mode differentiation of the deserialized
      function. i.e., `jax.grad(jax.grad(f)).`

  Returns:
    the serialized bytes.
  """
  fb_builder = flatbuffers.Builder(65536)
  root = _serialize_exported(fb_builder, exp, vjp_order)
  fb_builder.Finish(root)
  return fb_builder.Output()
def deserialize(ser: bytearray) -> _export.Exported:
  """Deserializes an Exported from bytes produced by `serialize`."""
  root = ser_flatbuf.Exported.GetRootAsExported(ser)
  return _deserialize_exported(root)
def _serialize_exported(
    builder: flatbuffers.Builder, exp: _export.Exported, vjp_order: int
) -> int:
  """Serializes `exp` into `builder`; returns the `Exported` table offset.

  Flatbuffers requires nested objects (strings, vectors, sub-tables) to be
  written *before* the enclosing table is started, so this function works
  bottom-up: all components first, then `ExportedStart`...`ExportedEnd`.

  Raises:
    ValueError: if `exp` lacks named shardings, or if `vjp_order > 0` but
      `exp` does not carry vjps of sufficient order.
  """
  # Deduplicate avals/abstract meshes/shardings so each is serialized once
  # and referenced by index (serialization version 10 optimization).
  uniques = _SerializedUniques.create_from_exported(exp)
  if not exp._has_named_shardings:
    raise ValueError(
        "Exported being serialized must have named shardings after 3/17/2026.")
  # Serialize bottom-up
  fun_name = builder.CreateString(exp.fun_name)
  in_tree = _serialize_pytreedef(builder, exp.in_tree)
  # TODO(necula): stop serializing in_avals 1 month after 4/4/26.
  in_avals = _serialize_array(builder, _serialize_aval, exp.in_avals)
  out_tree = _serialize_pytreedef(builder, exp.out_tree)
  # TODO(necula): stop serializing out_avals 1 month after 4/4/26
  out_avals = _serialize_array(builder, _serialize_aval, exp.out_avals)
  # TODO(necula): stop serializing in_shardings 1 month after 4/4/26
  in_shardings = _serialize_array(
      builder, partial(_serialize_sharding, uniques=uniques),
      exp._in_named_shardings)
  # TODO(necula): stop serializing out_shardings 1 month after 4/4/26
  out_shardings = _serialize_array(
      builder, partial(_serialize_sharding, uniques=uniques),
      exp._out_named_shardings)
  ordered_effects = _serialize_array(
      builder, _serialize_effect, exp.ordered_effects
  )
  unordered_effects = _serialize_array(
      builder, _serialize_effect, exp.unordered_effects
  )
  disabled_safety_checks = _serialize_array(
      builder, _serialize_disabled_safety_check, exp.disabled_safety_checks
  )
  platforms = _serialize_array(
      builder, lambda b, p: b.CreateString(p), exp.platforms
  )
  mlir_module_serialized = builder.CreateByteVector(exp.mlir_module_serialized)
  module_kept_var_idx = builder.CreateNumpyVector(
      np.array(exp.module_kept_var_idx, dtype=np.uint16)
  )
  vjp = None
  if vjp_order > 0:
    if not exp.has_vjp():
      # TODO: add test
      raise ValueError(
          "serialization of an Exported that does not have vjps of high-enough "
          "order"
      )
    # Recursively serialize the vjp Exported, one order lower.
    vjp = _serialize_exported(builder, exp.vjp(), vjp_order - 1)
  unique_avals_offset = _serialize_array(
      builder, _serialize_aval, uniques.unique_avals)
  unique_abstract_meshes_offset = _serialize_array(
      builder, _serialize_abstract_mesh, uniques.unique_abstract_meshes)
  unique_named_shardings_offset = _serialize_array(
      builder, partial(_serialize_named_sharding, uniques=uniques),
      uniques.unique_named_shardings)
  in_aval_idxs = builder.CreateNumpyVector(
      np.array([uniques.avals_map[a] for a in exp.in_avals], dtype=np.uint32))
  out_aval_idxs = builder.CreateNumpyVector(
      np.array([uniques.avals_map[a] for a in exp.out_avals], dtype=np.uint32))
  # Sharding index 0 encodes "no sharding"; unique sharding i is stored as i + 1.
  in_shardings_idxs = builder.CreateNumpyVector(
      np.array([0 if s is None else 1 + uniques.named_shardings_map[s]
                for s in exp._in_named_shardings], dtype=np.uint32))
  out_shardings_idxs = builder.CreateNumpyVector(
      np.array([0 if s is None else 1 + uniques.named_shardings_map[s]
                for s in exp._out_named_shardings], dtype=np.uint32))
  ser_flatbuf.ExportedStart(builder)
  # TODO(necula): we cannot really store the actual serialization_version
  # in the flatbuffer because prior to 11/25/2025 deserializers checked
  # if the version is 2 or 3. I have now removed that check, but for the
  # sake of old deserializers we can only store version 3. Starting
  # on January 2026 we can store the actual version.
  ser_flatbuf.ExportedAddSerializationVersion(builder, 3)
  ser_flatbuf.ExportedAddFunctionName(builder, fun_name)
  ser_flatbuf.ExportedAddInTree(builder, in_tree)
  ser_flatbuf.ExportedAddInAvals(builder, in_avals)
  ser_flatbuf.ExportedAddOutTree(builder, out_tree)
  ser_flatbuf.ExportedAddOutAvals(builder, out_avals)
  ser_flatbuf.ExportedAddNrDevices(builder, exp.nr_devices)
  ser_flatbuf.ExportedAddInShardings(builder, in_shardings)
  ser_flatbuf.ExportedAddOutShardings(builder, out_shardings)
  ser_flatbuf.ExportedAddPlatforms(builder, platforms)
  ser_flatbuf.ExportedAddOrderedEffects(builder, ordered_effects)
  ser_flatbuf.ExportedAddUnorderedEffects(builder, unordered_effects)
  ser_flatbuf.ExportedAddDisabledChecks(builder, disabled_safety_checks)
  ser_flatbuf.ExportedAddMlirModuleSerialized(builder, mlir_module_serialized)
  ser_flatbuf.ExportedAddCallingConventionVersion(
      builder, exp.calling_convention_version
  )
  ser_flatbuf.ExportedAddModuleKeptVarIdx(builder, module_kept_var_idx)
  ser_flatbuf.ExportedAddUsesGlobalConstants(
      builder, exp.uses_global_constants
  )
  if vjp is not None:
    ser_flatbuf.ExportedAddVjp(builder, vjp)
  ser_flatbuf.ExportedAddUniqueAvals(builder, unique_avals_offset)
  ser_flatbuf.ExportedAddUniqueAbstractMeshes(builder,
                                              unique_abstract_meshes_offset)
  ser_flatbuf.ExportedAddUniqueNamedShardings(builder,
                                              unique_named_shardings_offset)
  ser_flatbuf.ExportedAddInAvalsIdxs(builder, in_aval_idxs)
  ser_flatbuf.ExportedAddOutAvalsIdxs(builder, out_aval_idxs)
  ser_flatbuf.ExportedAddInShardingsIdxs(builder, in_shardings_idxs)
  ser_flatbuf.ExportedAddOutShardingsIdxs(builder, out_shardings_idxs)
  return ser_flatbuf.ExportedEnd(builder)
def _serialize_array(
    builder: flatbuffers.Builder,
    serialize_one: Callable[[flatbuffers.Builder, T], int],
    elements: Iterable[T],
) -> int:
  """Serializes each element and packs the offsets into a flatbuffers vector.

  Returns the offset of the resulting vector of offsets.
  """
  offsets = [serialize_one(builder, elem) for elem in elements]
  del elements
  # Any table's Start...Vector helper works for a vector of offsets; we reuse
  # PyTreeDef's one, as the original code did.
  ser_flatbuf.PyTreeDefStartChildrenVector(builder, len(offsets))
  # Flatbuffers vectors are built back-to-front.
  for off in reversed(offsets):
    builder.PrependUOffsetTRelative(off)
  return builder.EndVector()
def _deserialize_exported(exp: ser_flatbuf.Exported) -> _export.Exported:
  """Deserializes an `Exported` flatbuffer table into an `_export.Exported`.

  Supports both the current format (unique avals/meshes/shardings plus index
  vectors, version 10) and older formats, per the dated TODOs below.
  """
  scope = shape_poly.SymbolicScope(())  # TODO(necula): serialize the constraints
  unique_avals = [
      _deserialize_aval(exp.UniqueAvals(i), scope=scope, sharding=None)
      for i in range(exp.UniqueAvalsLength())]
  unique_abstract_meshes = [
      _deserialize_abstract_mesh(exp.UniqueAbstractMeshes(i))
      for i in range(exp.UniqueAbstractMeshesLength())
  ]
  # Named shardings reference abstract meshes by index, so build a partial
  # _SerializedUniques first, then rebuild it once the shardings are known.
  uniques = _SerializedUniques.create_from_uniques(unique_avals,  # pyrefly: ignore[bad-argument-type]
                                                   unique_abstract_meshes,
                                                   [])
  unique_named_shardings = [
      _deserialize_named_sharding(exp.UniqueNamedShardings(i), uniques=uniques)
      for i in range(exp.UniqueNamedShardingsLength())
  ]
  uniques = _SerializedUniques.create_from_uniques(unique_avals,  # pyrefly: ignore[bad-argument-type]
                                                   unique_abstract_meshes,
                                                   unique_named_shardings)
  fun_name = exp.FunctionName().decode("utf-8")
  # Reconstruct the tree defs by flattening pytrees of dummy leaves.
  _, in_tree = tree_util.tracing_registry.flatten(
      _deserialize_pytreedef_to_pytree(exp.InTree())
  )
  _, out_tree = tree_util.tracing_registry.flatten(
      _deserialize_pytreedef_to_pytree(exp.OutTree())
  )
  # TODO(necula): remove the fallback to NrDevicesShort and mark
  # the field "deprecated" once we abandon the old
  # serialization format (6 months after 11/24/2025).
  nr_devices = exp.NrDevices() or exp.NrDevicesShort()

  def sharding_by_idx(idx):
    # Index 0 encodes "no sharding"; index i refers to unique sharding i - 1.
    if idx == 0:
      return None
    return uniques.unique_named_shardings[idx - 1]

  if exp.InShardingsIdxsLength() > 0:
    in_shardings = tuple(
        sharding_by_idx(exp.InShardingsIdxs(i))
        for i in range(exp.InShardingsIdxsLength())
    )
  elif exp.InShardingsLength() > 0:
    # TODO(necula): remove 6 months after 4/4/26
    in_shardings = tuple(
        _deserialize_sharding(exp.InShardings(i), uniques=uniques)
        for i in range(exp.InShardingsLength())
    )
  else:
    in_shardings = ()
  if exp.OutShardingsIdxsLength() > 0:
    out_shardings = tuple(
        sharding_by_idx(exp.OutShardingsIdxs(i))
        for i in range(exp.OutShardingsIdxsLength())
    )
  elif exp.OutShardingsLength() > 0:
    # TODO(necula): remove 6 months after 4/4/26
    out_shardings = tuple(
        _deserialize_sharding(exp.OutShardings(i), uniques=uniques)
        for i in range(exp.OutShardingsLength())
    )
  else:
    out_shardings = ()
  # has_named_sharding will be True for all exports created after 1/15/2026
  # TODO(b/489569164): remove has_named_sharding 6 months after 1/15/2026
  has_named_shardings = not any(isinstance(s, _export.HloSharding)
                                for s in itertools.chain(in_shardings, out_shardings))
  if has_named_shardings:
    def get_aval_by_idx(idx, sharding: _export.NamedSharding | None):
      # Unique avals were deserialized without shardings; attach them here.
      base_aval = uniques.unique_avals[idx]
      if sharding is None:
        return base_aval
      return core.update_aval_with_sharding(base_aval, sharding)
    if exp.InAvalsIdxsLength() > 0:
      in_avals = tuple(
          get_aval_by_idx(exp.InAvalsIdxs(i), in_shardings[i])  # pyrefly: ignore[bad-argument-type]
          for i in range(exp.InAvalsIdxsLength()))
    elif exp.InAvalsLength() > 0:
      # TODO(necula): remove 6 months after 4/4/26
      in_avals = tuple(
          _deserialize_aval(exp.InAvals(i), scope=scope, sharding=in_shardings[i])  # pyrefly: ignore[bad-argument-type]
          for i in range(exp.InAvalsLength()))
    else:
      in_avals = ()
    if exp.OutAvalsIdxsLength() > 0:
      out_avals = tuple(
          get_aval_by_idx(exp.OutAvalsIdxs(i), out_shardings[i])  # pyrefly: ignore[bad-argument-type]
          for i in range(exp.OutAvalsIdxsLength()))
    elif exp.OutAvalsLength() > 0:
      # TODO(necula): remove 6 months after 4/4/26
      out_avals = tuple(
          _deserialize_aval(exp.OutAvals(i), scope=scope, sharding=out_shardings[i])  # pyrefly: ignore[bad-argument-type]
          for i in range(exp.OutAvalsLength())
      )
    else:
      out_avals = ()
    # Derive the HLO shardings from the named shardings.
    in_shardings_hlo = tuple(_export.named_to_hlo_sharding(s, aval)  # pyrefly: ignore[bad-argument-type]
                             for s, aval in zip(in_shardings, in_avals))
    out_shardings_hlo = tuple(_export.named_to_hlo_sharding(s, aval)  # pyrefly: ignore[bad-argument-type]
                              for s, aval in zip(out_shardings, out_avals))
  else:
    # Export from before 1/15/26
    in_avals = tuple(
        _deserialize_aval(exp.InAvals(i), scope=scope, sharding=None)
        for i in range(exp.InAvalsLength())
    )
    out_avals = tuple(
        _deserialize_aval(exp.OutAvals(i), scope=scope, sharding=None)
        for i in range(exp.OutAvalsLength())
    )
    # Old exports carried HloShardings; keep those and drop named shardings.
    in_shardings_hlo = cast(tuple[_export.HloSharding | None, ...], in_shardings)
    in_shardings = (None,) * len(in_shardings)
    out_shardings_hlo = cast(tuple[_export.HloSharding | None, ...], out_shardings)
    out_shardings = (None,) * len(out_shardings)
  platforms = tuple(
      exp.Platforms(i).decode("utf-8")
      for i in range(exp.PlatformsLength())
  )
  ordered_effects = tuple(
      _deserialize_effect(exp.OrderedEffects(i))
      for i in range(exp.OrderedEffectsLength())
  )
  unordered_effects = tuple(
      _deserialize_effect(exp.UnorderedEffects(i))
      for i in range(exp.UnorderedEffectsLength())
  )
  disabled_safety_checks = tuple(
      _deserialize_disabled_safety_check(exp.DisabledChecks(i))
      for i in range(exp.DisabledChecksLength())
  )
  mlir_module_serialized = exp.MlirModuleSerializedAsNumpy().tobytes()
  calling_convention_version = exp.CallingConventionVersion()
  module_kept_var_idx = tuple(exp.ModuleKeptVarIdxAsNumpy().tolist())
  uses_global_constants = exp.UsesGlobalConstants()
  _get_vjp = None
  if vjp := exp.Vjp():
    # Deserialize the (recursively serialized) vjp only when requested.
    _get_vjp = lambda _: _deserialize_exported(vjp)
  return _export.Exported(
      fun_name=fun_name,
      in_tree=in_tree,
      in_avals=in_avals,
      out_tree=out_tree,
      out_avals=out_avals,
      nr_devices=nr_devices,
      in_shardings_hlo=in_shardings_hlo,
      out_shardings_hlo=out_shardings_hlo,
      _has_named_shardings=has_named_shardings,
      _in_named_shardings=in_shardings,  # pyrefly: ignore[bad-argument-type]
      _out_named_shardings=out_shardings,  # pyrefly: ignore[bad-argument-type]
      platforms=platforms,
      ordered_effects=ordered_effects,
      unordered_effects=unordered_effects,
      disabled_safety_checks=disabled_safety_checks,
      mlir_module_serialized=mlir_module_serialized,
      calling_convention_version=calling_convention_version,
      module_kept_var_idx=module_kept_var_idx,
      uses_global_constants=uses_global_constants,
      _get_vjp=_get_vjp,
  )
def _serialize_pytreedef(
    builder: flatbuffers.Builder, p: tree_util.PyTreeDef
) -> int:
  """Serializes a PyTreeDef, recursively serializing its children first.

  Raises:
    TypeError: for dict nodes with non-string keys.
    ValueError: for node types not registered for serialization, or custom
      serializers that do not return bytes.
  """
  node_data = p.node_data()
  children = p.children()
  children_vector_offset = None
  children_names_vector_offset = None
  if children:
    children_vector_offset = _serialize_array(
        builder, _serialize_pytreedef, children
    )
  custom_name = None
  custom_auxdata = None
  # node_data is None for leaves, otherwise a (node_type, auxdata) pair.
  node_type = node_data and node_data[0]
  if node_data is None:  # leaf
    kind = ser_flatbuf.PyTreeDefKind.leaf
  elif node_type is types.NoneType:
    kind = ser_flatbuf.PyTreeDefKind.none
  elif node_type is tuple:
    kind = ser_flatbuf.PyTreeDefKind.tuple
  elif node_type is list:
    kind = ser_flatbuf.PyTreeDefKind.list
  elif node_type is dict:
    kind = ser_flatbuf.PyTreeDefKind.dict
    assert len(node_data[1]) == len(children)
    def serialize_key(builder, k):
      if not isinstance(k, str):
        raise TypeError(
            "Serialization is supported only for dictionaries with string keys."
            f" Found key {k} of type {type(k)}.")
      return builder.CreateString(k)
    # Dict keys are stored as the children-names vector.
    children_names_vector_offset = _serialize_array(
        builder, serialize_key, node_data[1]
    )
  elif node_type in _export.serialization_registry:
    assert node_type is not None
    kind = ser_flatbuf.PyTreeDefKind.custom
    serialized_name, serialize_auxdata = _export.serialization_registry[node_type]
    custom_name = builder.CreateString(serialized_name)
    serialized_auxdata = serialize_auxdata(node_data[1])
    if not isinstance(serialized_auxdata, (bytes, bytearray)):
      raise ValueError(
          "The custom serialization function for `node_type` must "
          f"return a `bytes` object. It returned a {type(serialized_auxdata)}.")
    custom_auxdata = builder.CreateByteVector(serialized_auxdata)
  else:
    raise ValueError(
        "Cannot serialize PyTreeDef containing an "
        f"unregistered type `{node_type}`. "
        "Use `export.register_pytree_node_serialization` or "
        "`export.register_namedtuple_serialization`.")
  ser_flatbuf.PyTreeDefStart(builder)
  ser_flatbuf.PyTreeDefAddKind(builder, kind)
  if children_vector_offset:
    ser_flatbuf.PyTreeDefAddChildren(builder, children_vector_offset)
  if children_names_vector_offset:
    ser_flatbuf.PyTreeDefAddChildrenNames(builder, children_names_vector_offset)
  if custom_name is not None:
    ser_flatbuf.PyTreeDefAddCustomName(builder, custom_name)
  if custom_auxdata is not None:
    ser_flatbuf.PyTreeDefAddCustomAuxdata(builder, custom_auxdata)
  return ser_flatbuf.PyTreeDefEnd(builder)
def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
  """Reconstructs a pytree (with dummy float leaves) matching `p`.

  The caller flattens the returned pytree to recover the PyTreeDef.
  TODO: is there a more direct way to construct a PyTreeDef?
  """
  kind = p.Kind()
  nr_children = p.ChildrenLength()
  subtrees = [_deserialize_pytreedef_to_pytree(p.Children(i))
              for i in range(nr_children)]
  if kind == ser_flatbuf.PyTreeDefKind.leaf:
    return 0.0
  if kind == ser_flatbuf.PyTreeDefKind.none:
    return None
  if kind == ser_flatbuf.PyTreeDefKind.tuple:
    return tuple(subtrees)
  if kind == ser_flatbuf.PyTreeDefKind.list:
    return subtrees
  if kind == ser_flatbuf.PyTreeDefKind.dict:
    assert p.ChildrenNamesLength() == nr_children
    names = [p.ChildrenNames(i).decode("utf-8") for i in range(nr_children)]
    return dict(zip(names, subtrees))
  if kind == ser_flatbuf.PyTreeDefKind.custom:
    serialized_name = p.CustomName().decode("utf-8")
    if serialized_name not in _export.deserialization_registry:
      raise ValueError(
          "Cannot deserialize a PyTreeDef containing an "
          f"unregistered type `{serialized_name}`. "
          "Use `export.register_pytree_node_serialization` or "
          "`export.register_namedtuple_serialization`.")
    nodetype, deserialize_auxdata, from_iter = _export.deserialization_registry[serialized_name]
    auxdata = deserialize_auxdata(p.CustomAuxdataAsNumpy().tobytes())
    return from_iter(auxdata, subtrees)
  raise ValueError(
      f"Cannot deserialize PyTreeDef with unknown kind: {kind}")
# Maps the supported dtypes to the serialized DType enum values.
_dtype_to_dtype_kind = {
    np.dtype("bool"): ser_flatbuf.DType.bool,
    np.dtype("int8"): ser_flatbuf.DType.i8,
    np.dtype("int16"): ser_flatbuf.DType.i16,
    np.dtype("int32"): ser_flatbuf.DType.i32,
    np.dtype("int64"): ser_flatbuf.DType.i64,
    np.dtype("uint8"): ser_flatbuf.DType.ui8,
    np.dtype("uint16"): ser_flatbuf.DType.ui16,
    np.dtype("uint32"): ser_flatbuf.DType.ui32,
    np.dtype("uint64"): ser_flatbuf.DType.ui64,
    dtypes.float0: ser_flatbuf.DType.f0,
    np.dtype("float16"): ser_flatbuf.DType.f16,
    np.dtype("float32"): ser_flatbuf.DType.f32,
    np.dtype("float64"): ser_flatbuf.DType.f64,
    np.dtype("complex64"): ser_flatbuf.DType.c64,
    np.dtype("complex128"): ser_flatbuf.DType.c128,
    dtypes._bfloat16_dtype: ser_flatbuf.DType.bf16,
    dtypes._int4_dtype: ser_flatbuf.DType.i4,
    dtypes._uint4_dtype: ser_flatbuf.DType.ui4,
    dtypes._float8_e4m3b11fnuz_dtype: ser_flatbuf.DType.f8_e4m3b11fnuz,
    dtypes._float8_e4m3fn_dtype: ser_flatbuf.DType.f8_e4m3fn,
    dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz,
    dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2,
    dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
    dtypes._float8_e3m4_dtype: ser_flatbuf.DType.f8_e3m4,
    dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3,
    dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu,
    dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn,
}
# Inverse mapping, used during deserialization.
_dtype_kind_to_dtype = {
    kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
}


def register_dtype_kind(dtype: Any, kind: int):
  """Registers an extra dtype <-> serialized DType pair (in both directions)."""
  _dtype_to_dtype_kind[dtype] = kind
  _dtype_kind_to_dtype[kind] = dtype


# Maps core.MemorySpace to the serialized MemorySpace enum (and back).
_memory_space_to_enum = {
    core.MemorySpace.Device: ser_flatbuf.MemorySpace.Device,
    core.MemorySpace.Host: ser_flatbuf.MemorySpace.Host,
    core.MemorySpace.Any: ser_flatbuf.MemorySpace.Any,
}
_memory_space_from_enum = {v: k for k, v in _memory_space_to_enum.items()}

# Maps core.AxisType to the serialized AxisType enum (and back).
_axis_type_to_enum = {
    core.AxisType.Auto: ser_flatbuf.AxisType.Auto,
    core.AxisType.Explicit: ser_flatbuf.AxisType.Explicit,
    core.AxisType.Manual: ser_flatbuf.AxisType.Manual,
}
_axis_type_from_enum = {v: k for k, v in _axis_type_to_enum.items()}
def _serialize_abstract_device(builder: flatbuffers.Builder,
device: mesh.AbstractDevice | None) -> int:
if device is None:
return 0
device_kind = builder.CreateString(device.device_kind)
ser_flatbuf.AbstractDeviceStart(builder)
ser_flatbuf.AbstractDeviceAddDeviceKind(builder, device_kind)
if device.num_cores is not None:
ser_flatbuf.AbstractDeviceAddNumCores(builder, device.num_cores)
return ser_flatbuf.AbstractDeviceEnd(builder)
def _deserialize_abstract_device(
ser_abs_device: ser_flatbuf.AbstractDevice | None
) -> mesh.AbstractDevice | None:
if ser_abs_device is None:
return None
device_kind = ser_abs_device.DeviceKind().decode("utf-8")
num_cores: int | None = ser_abs_device.NumCores()
return mesh.AbstractDevice(device_kind, num_cores)
def _serialize_abstract_mesh(builder: flatbuffers.Builder,
                             mesh: mesh.AbstractMesh) -> int:
  """Serializes an AbstractMesh; returns the table offset.

  NOTE(review): the parameter `mesh` shadows the `mesh` module; the
  annotation still resolves only because this file uses
  `from __future__ import annotations` (annotations are never evaluated).
  """
  # All vectors/strings must be created before AbstractMeshStart.
  ser_flatbuf.AbstractMeshStartAxisSizesVector(builder, len(mesh.axis_sizes))
  for axis_size in reversed(mesh.axis_sizes):
    builder.PrependUint32(axis_size)
  axis_sizes = builder.EndVector()
  axis_names = _serialize_array(builder,
                                lambda builder, an: builder.CreateString(an),
                                mesh.axis_names)
  assert mesh.axis_types is not None, mesh
  ser_flatbuf.AbstractMeshStartAxisTypesVector(builder, len(mesh.axis_types))
  for axis_type in reversed(mesh.axis_types):
    builder.PrependByte(_axis_type_to_enum[axis_type])
  axis_types = builder.EndVector()
  abstract_device = _serialize_abstract_device(builder, mesh.abstract_device)
  ser_flatbuf.AbstractMeshStart(builder)
  ser_flatbuf.AbstractMeshAddAxisSizes(builder, axis_sizes)
  ser_flatbuf.AbstractMeshAddAxisNames(builder, axis_names)
  ser_flatbuf.AbstractMeshAddAxisTypes(builder, axis_types)
  if mesh.abstract_device is not None:
    ser_flatbuf.AbstractMeshAddAbstractDevice(builder, abstract_device)
  return ser_flatbuf.AbstractMeshEnd(builder)
def _deserialize_abstract_mesh(
    ser_mesh: ser_flatbuf.AbstractMesh) -> mesh.AbstractMesh:
  """Inverse of `_serialize_abstract_mesh`."""
  sizes = tuple(ser_mesh.AxisSizes(i)
                for i in range(ser_mesh.AxisSizesLength()))
  names = tuple(ser_mesh.AxisNames(i).decode("utf-8")
                for i in range(ser_mesh.AxisNamesLength()))
  types = tuple(_axis_type_from_enum[ser_mesh.AxisTypes(i)]
                for i in range(ser_mesh.AxisTypesLength()))
  device = _deserialize_abstract_device(ser_mesh.AbstractDevice())
  return mesh.AbstractMesh(sizes, names, types, abstract_device=device)
def _serialize_partition_spec_one_axis(builder: flatbuffers.Builder,
                                       spec: str | tuple[str, ...] | None) -> int:
  """Serializes one PartitionSpec entry: None, one axis name, or a tuple."""
  if spec is None:
    axis_names = ()
  elif isinstance(spec, str):
    axis_names = (spec,)
  else:
    axis_names = spec
  axes_offset = _serialize_array(
      builder, lambda b, name: b.CreateString(name), axis_names)
  ser_flatbuf.PartitionSpecOneAxisStart(builder)
  ser_flatbuf.PartitionSpecOneAxisAddAxes(builder, axes_offset)
  return ser_flatbuf.PartitionSpecOneAxisEnd(builder)
def _deserialize_partition_spec_one_axis(
    spec: ser_flatbuf.PartitionSpecOneAxis) -> str | tuple[str, ...] | None:
  """Inverse of `_serialize_partition_spec_one_axis`."""
  names = tuple(spec.Axes(i).decode("utf-8") for i in range(spec.AxesLength()))
  if not names:
    return None
  if len(names) == 1:
    return names[0]
  return names
def _serialize_partition_spec(builder: flatbuffers.Builder,
                              spec: partition_spec.PartitionSpec) -> int:
  """Serializes a PartitionSpec (partitions, reduced and unreduced axes)."""
  partitions = _serialize_array(builder, _serialize_partition_spec_one_axis,
                                spec._partitions)  # pyrefly: ignore[bad-argument-type]
  reduced = _serialize_array(builder,
                             lambda builder, ps: builder.CreateString(ps),
                             spec.reduced)
  unreduced = _serialize_array(builder,
                               lambda builder, ps: builder.CreateString(ps),
                               spec.unreduced)
  ser_flatbuf.PartitionSpecStart(builder)
  ser_flatbuf.PartitionSpecAddPartitions(builder, partitions)
  ser_flatbuf.PartitionSpecAddReduced(builder, reduced)
  ser_flatbuf.PartitionSpecAddUnreduced(builder, unreduced)
  return ser_flatbuf.PartitionSpecEnd(builder)
def _deserialize_partition_spec(spec: ser_flatbuf.PartitionSpec
                                ) -> partition_spec.PartitionSpec:
  """Inverse of `_serialize_partition_spec`."""
  parts = tuple(_deserialize_partition_spec_one_axis(spec.Partitions(i))
                for i in range(spec.PartitionsLength()))
  reduced_axes = frozenset(spec.Reduced(i).decode("utf-8")
                           for i in range(spec.ReducedLength()))
  unreduced_axes = frozenset(spec.Unreduced(i).decode("utf-8")
                             for i in range(spec.UnreducedLength()))
  return partition_spec.PartitionSpec(
      *parts, reduced=reduced_axes, unreduced=unreduced_axes)
def _serialize_named_sharding(
    builder: flatbuffers.Builder, sharding: named_sharding.NamedSharding, *,
    uniques: _SerializedUniques
) -> int:
  """Serializes a NamedSharding; the abstract mesh is also stored by index."""
  abstract_mesh_idx = uniques.abstract_meshes_map[sharding.mesh.abstract_mesh]
  # TODO(necula): 1 month after 4/4/26 we can stop serializing the full
  # abstract_mesh and only serialize the index.
  mesh_offset = _serialize_abstract_mesh(builder, sharding.mesh.abstract_mesh)
  spec_offset = _serialize_partition_spec(builder, sharding.spec)
  # Strings must be created before NamedShardingStart; offset 0 encodes
  # "no memory kind".
  memory_kind = builder.CreateString(sharding.memory_kind) if sharding.memory_kind is not None else 0
  ser_flatbuf.NamedShardingStart(builder)
  ser_flatbuf.NamedShardingAddMesh(builder, mesh_offset)
  ser_flatbuf.NamedShardingAddSpec(builder, spec_offset)
  if memory_kind != 0:
    ser_flatbuf.NamedShardingAddMemoryKind(builder, memory_kind)
  ser_flatbuf.NamedShardingAddAbstractMeshIdx(builder, abstract_mesh_idx)
  return ser_flatbuf.NamedShardingEnd(builder)
def _deserialize_named_sharding(
    s: ser_flatbuf.NamedSharding, *, uniques: _SerializedUniques
) -> named_sharding.NamedSharding:
  """Inverse of `_serialize_named_sharding`."""
  if uniques.unique_abstract_meshes:
    # Current format: the abstract mesh is referenced by index.
    amesh = uniques.unique_abstract_meshes[s.AbstractMeshIdx()]
  else:
    # TODO(necula): 6 months after 4/4/26 we can stop deserializing the full
    # abstract_mesh.
    amesh = _deserialize_abstract_mesh(s.Mesh())
  pspec = _deserialize_partition_spec(s.Spec())
  raw_memory_kind = s.MemoryKind()
  memory_kind = (None if raw_memory_kind is None
                 else raw_memory_kind.decode("utf-8"))
  return named_sharding.NamedSharding(amesh, pspec, memory_kind=memory_kind)
def _serialize_aval(
    builder: flatbuffers.Builder, aval: core.ShapedArray
) -> int:
  """Serializes a ShapedArray.

  Dimensions are stored as strings so that symbolic (shape-polymorphic)
  dimensions round-trip through `shape_poly.symbolic_shape`.
  """
  aval_kind = ser_flatbuf.AbstractValueKind.shapedArray
  shape_offsets = [builder.CreateString(str(d)) for d in aval.shape]
  ser_flatbuf.AbstractValueStartShapeVector(builder, len(aval.shape))
  # Flatbuffers vectors are built back-to-front.
  for d in reversed(shape_offsets):
    builder.PrependUOffsetTRelative(d)
  shape_vector_offset = builder.EndVector()
  ser_flatbuf.AbstractValueStart(builder)
  ser_flatbuf.AbstractValueAddKind(builder, aval_kind)
  ser_flatbuf.AbstractValueAddShape(builder, shape_vector_offset)
  ser_flatbuf.AbstractValueAddDtype(builder, _dtype_to_dtype_kind[aval.dtype])
  ser_flatbuf.AbstractValueAddMemorySpace(builder, _memory_space_to_enum[aval.memory_space])
  return ser_flatbuf.AbstractValueEnd(builder)
def _deserialize_aval(aval: ser_flatbuf.AbstractValue, *,
                      scope: shape_poly.SymbolicScope,
                      sharding: named_sharding.NamedSharding | None,
                      ) -> core.ShapedArray:
  """Inverse of `_serialize_aval`; dimensions may be symbolic in `scope`."""
  dtype = _dtype_kind_to_dtype[aval.Dtype()]
  dim_strs = [aval.Shape(i).decode("utf-8") for i in range(aval.ShapeLength())]
  shape = shape_poly.symbolic_shape(",".join(dim_strs), scope=scope)
  ser_mem_space = aval.MemorySpace()
  if ser_mem_space:
    mem_space = _memory_space_from_enum[ser_mem_space]
  else:
    # A falsy (absent/zero) serialized memory space defaults to Device.
    mem_space = core.MemorySpace.Device
  return core.update_aval_with_sharding(
      core.ShapedArray(shape, dtype, memory_space=mem_space), sharding
  )
def _serialize_sharding(
    builder: flatbuffers.Builder, s: _export.NamedSharding | None, *,
    uniques: _SerializedUniques) -> int:
  """Serializes an optional NamedSharding wrapped in a Sharding table."""
  ns_offset = (None if s is None
               else _serialize_named_sharding(builder, s, uniques=uniques))
  ser_flatbuf.ShardingStart(builder)
  if ns_offset is not None:
    ser_flatbuf.ShardingAddNamedSharding(builder, ns_offset)
  return ser_flatbuf.ShardingEnd(builder)
def _deserialize_sharding(s: ser_flatbuf.Sharding, *,
                          uniques: _SerializedUniques) -> _export.HloSharding | named_sharding.NamedSharding | None:
  """Deserializes a Sharding table to a NamedSharding, HloSharding, or None."""
  named = s.NamedSharding()
  if named is not None:
    # After 1/15/26 all exports will have named shardings (or None)
    # TODO(necula): We must keep reading the NamedSharding for 6 months after 4/4/26
    return _deserialize_named_sharding(named, uniques=uniques)
  # TODO(b/489569164): We must keep reading the HloSharding for 6 months after 1/15/2026.
  if not s.HloShardingProtoIsNone():
    proto = xla_client.OpSharding()
    proto.ParseFromString(s.HloShardingProtoAsNumpy().tobytes())
    return xla_client.HloSharding.from_proto(proto)
  return None  # Unspecified sharding
def _serialize_effect(builder: flatbuffers.Builder, eff: core.Effect) -> int:
try:
eff_replica = eff.__class__()
except Exception:
raise NotImplementedError(
f"Effect {eff} must have a nullary constructor to be serializable"
)
try:
hash_eff = hash(eff)
hash_eff_replica = hash(eff_replica)
except Exception:
raise NotImplementedError(
f"Effect {eff} must be hashable to be serializable"
)
if eff != eff_replica or hash_eff != hash_eff_replica:
raise NotImplementedError(
f"Effect {eff} must have a nullary class constructor that produces an "
"equal effect object."
)
effect_type_name = str(eff.__class__)
effect_type_name_offset = builder.CreateString(effect_type_name)
ser_flatbuf.EffectStart(builder)
ser_flatbuf.EffectAddTypeName(builder, effect_type_name_offset)
return ser_flatbuf.ExportedEnd(builder)
def _deserialize_effect(eff: ser_flatbuf.Effect) -> core.Effect:
effect_type_name = eff.TypeName().decode("utf-8")
for existing_effect_type in effects.lowerable_effects._effect_types:
if str(existing_effect_type) == effect_type_name:
try:
return existing_effect_type()
except:
# TODO: add test
raise NotImplementedError(
f"deserializing effect {effect_type_name} that does not have a "
"nullary class constructor"
)
raise NotImplementedError(
f"cannot deserialize effect type {effect_type_name}"
)
def _serialize_disabled_safety_check(
    builder: flatbuffers.Builder, check: _export.DisabledSafetyCheck
) -> int:
  """Serializes a DisabledSafetyCheck (custom_call or platform kinds only).

  Returns the offset of the finished table in `builder`.

  Raises:
    NotImplementedError: for check kinds this serializer does not know.
  """
  custom_call_target = None
  target_str = check.is_custom_call()
  if target_str is not None:
    kind = ser_flatbuf.DisabledSafetyCheckKind.custom_call
    custom_call_target = builder.CreateString(target_str)
  elif check == _export.DisabledSafetyCheck.platform():
    kind = ser_flatbuf.DisabledSafetyCheckKind.platform
  else:
    raise NotImplementedError(f"serializing DisabledSafetyCheck: {check}")
  ser_flatbuf.DisabledSafetyCheckStart(builder)
  ser_flatbuf.DisabledSafetyCheckAddKind(builder, kind)
  if custom_call_target is not None:
    ser_flatbuf.DisabledSafetyCheckAddCustomCallTarget(
        builder, custom_call_target
    )
  return ser_flatbuf.DisabledSafetyCheckEnd(builder)
def _deserialize_disabled_safety_check(
    sc: ser_flatbuf.DisabledSafetyCheck,
) -> _export.DisabledSafetyCheck:
  """Reconstructs a DisabledSafetyCheck from its serialized kind.

  Raises:
    ValueError: for kinds this deserializer does not know.
  """
  kind = sc.Kind()
  if kind == ser_flatbuf.DisabledSafetyCheckKind.platform:
    return _export.DisabledSafetyCheck.platform()
  if kind == ser_flatbuf.DisabledSafetyCheckKind.custom_call:
    target = sc.CustomCallTarget().decode("utf-8")
    return _export.DisabledSafetyCheck.custom_call(target)
  raise ValueError(f"Cannot deserialize DisabledSafetyCheck with unknown kind: {kind}")
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,470 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shape polymorphism support for deciding inequalities of symbolic dimensions.
"""
from __future__ import annotations
from collections.abc import Sequence
import itertools
import math
import numpy as np
from jax._src import core
from jax._src.export import shape_poly
from jax._src.export.shape_poly import (
_DimExpr, _DimTerm, _DimFactor,
SymbolicScope,
DimSize,
InconclusiveDimensionOperation,
Comparator,
BoundsPrecision,
)
def sgn(x): return 1 if x >= 0 else -1
def bounds_decision(e: DimSize,
                    prec: BoundsPrecision) -> tuple[float, float]:
  """Computes (lower, upper) bounds for `e` with the requested precision.

  A constant dimension bounds itself exactly; a symbolic expression is bounded
  by a decision procedure built for its scope, after adding the implicit
  constraints for the terms of `e`.
  """
  if isinstance(e, _DimExpr):
    procedure = _DecisionByElimination.build(e.scope)
    return procedure.bounds(e, prec, add_implicit_constraints=True)
  const = int(e)
  return (const, const)
# Register this as the bounds decision procedure used by shape_poly.
shape_poly._bounds_decision = bounds_decision
class _DecisionByElimination:
"""A decision procedure based on elimination of terms.
Given an expression `e = t*t_k + rest_e` for which we want to compute bounds,
and a constraint `c = t*t_c_k + rest_c >= 0`,
Let `e0 = e*abs(t_c_k) - c*sgn(t_c_k)*t_k`. (Note that we eliminated `t` from
`e0`, since `abs(t_c_k)*t_k = sgn(t_c_k)*t_k*t_c_k`.)
Since `c >= 0`,
if `sgn(t_c_k)*t_k > 0`:
then `abs(t_c_k)*e >= e0`, hence, `LB(e) >= ceil(LB(e0) / abs(t_c_k))`,
if `sgn(t_c_k)*t_k < 0`
then `abs(t_c_k)*e <= e0`, hence, `UB(e) <= floor(UB(e0) / abs(t_c_k))`,
See the implementation in self.combine_term_with_existing.
Do not use the constructor directly, use the `build` static method.
"""
def __init__(self, scope: SymbolicScope):
self.scope = scope
self._processed_for_internal_constraints: set[_DimTerm] = set()
# The other fields are for keeping an efficient representation of
# the explicit constraints.
self._term_bounds: dict[_DimTerm, tuple[float, float]] = {}
# The _expr_constraints represents a set of constraints that are not
# just simple terms. The set is represented as a mapping from a
# term "t" to tuples (cmp, k, c) where "c >= 0" (if cmp is GEQ else "c == 0")
# represents a constraint that has "t" as the leading term with coefficient "k".
self._expr_constraints: dict[_DimTerm, set[tuple[Comparator, int, _DimExpr]]] = {}
def initialize(self) -> _DecisionByElimination:
# Process the explicit constraints in the order in which the user specifies
# them. This is because the heuristics depend on the order in which the
# constraints are processed, and this way we give the user a way to control
# the result (albeit, for now, without a good feedback loop to understand
# how the order matters for inequalities).
for constr in self.scope._explicit_constraints:
if not core.is_constant_dim(constr.diff):
self.add_implicit_constraints_expr(constr.diff) # pyrefly: ignore[bad-argument-type]
self.combine_and_add_constraint(constr.cmp, constr.diff, 0,
constr.debug_str)
# Clear the cache, since we have added constraints.
self.scope._bounds_cache.clear()
return self
@staticmethod
def build(scope: SymbolicScope) -> _DecisionByElimination:
"""Builds an initialized DecisionByElimination for a scope.
Caches the initial state of the decision procedure in the scope.
"""
if not scope._initialized or not scope._explicit_constraints:
# We do not cache until the scope is fully initialized.
return _DecisionByElimination(scope).initialize()
if not scope._decision_initial_state:
scope._decision_initial_state = _DecisionByElimination(scope).initialize()
d = scope._decision_initial_state
# Return a copy, because the decision procedure state is mutable
c = _DecisionByElimination(scope)
c._processed_for_internal_constraints = d._processed_for_internal_constraints.copy()
c._term_bounds = d._term_bounds.copy()
c._expr_constraints = {
lead_t: lead_t_constraints.copy()
for lead_t, lead_t_constraints in d._expr_constraints.items()}
return c
  def combine_and_add_constraint(self,
                                 cmp: Comparator,
                                 e1: _DimExpr | int | float,
                                 e2: _DimExpr | int | float,
                                 debug_str: str | None = None):
    """Adds a constraint "e1 >= e2" (or "e1 == e2" when cmp is EQ).

    Args:
      cmp: GEQ for "e1 >= e2", EQ for "e1 == e2".
      e1, e2: the two sides. Float inputs must be whole numbers, except that
        infinities that make a GEQ constraint trivially true are dropped.
      debug_str: user-facing form of the constraint, used in error messages.

    Raises:
      ValueError: if the difference is a constant that violates the constraint.
    """
    # "e1 >= e2" with e1 == +inf (or e2 == -inf) is vacuously true for GEQ.
    if isinstance(e1, float):
      if np.isinf(e1) and e1 >= 0 and cmp == Comparator.GEQ: return
      assert e1 == np.floor(e1)
      e1 = int(e1)
    if isinstance(e2, float):
      if np.isinf(e2) and e2 <= 0 and cmp == Comparator.GEQ: return
      assert e2 == np.floor(e2)
      e2 = int(e2)
    # Normalize to "e [cmp] 0" with e = e1 - e2.
    e = e1 - e2
    if (const := _DimExpr._to_constant(e)) is not None:
      # A constant difference can be decided immediately.
      if const < 0:
        raise ValueError(f"Unsatisfiable constraint: {debug_str or str(e1) + ' >= ' + str(e2)}")
      return
    assert isinstance(e, _DimExpr)
    self.add_to_state(cmp, e, debug_str)
    # Also record the consequences of combining "e" with the constraints
    # already in the state (note: the loop variable reuses the name `cmp`;
    # the parameter is not needed after this point).
    geq_combinations = self.combine_constraint_with_existing(cmp, e, debug_str)
    for cmp, a in geq_combinations:
      self.add_to_state(cmp, a, None)
  def add_to_state(self,
                   cmp: Comparator,
                   e: _DimExpr,
                   debug_str: str | None):
    """Updates the internal state to reflect "e >= 0" (or "e == 0" for EQ).

    Raises:
      ValueError: if the new constraint makes the state unsatisfiable.
    """
    assert _DimExpr._to_constant(e) is None
    # If "e" has the shape "n + t*t_k" for a single term "t", fold the
    # constraint directly into the per-term (lower, upper) bounds.
    if (term_factors := e._to_single_term()) is not None:
      n, t_k, t = term_factors  # n + t * t_k [== | >=] 0
      lb, ub = self._term_bounds.get(t, (- np.inf, np.inf))
      if cmp == Comparator.EQ:
        # n + t_k * t == 0 -> t == - n // t_k
        if n % t_k:
          raise ValueError(f"Unsatisfiable constraint: {debug_str}")
        t_val = - (n // t_k)
        lb = max(lb, t_val)
        ub = min(ub, t_val)
      else: # GEQ
        # n + t_k * t >= 0 tightens the lower bound when t_k > 0,
        # and the upper bound when t_k < 0.
        if t_k > 0:
          lb = max(lb, int(np.ceil(- n / t_k)))
        else:
          ub = min(ub, int(np.floor(- n / t_k)))
      if lb > ub:
        raise ValueError(f"Unsatisfiable constraint: {debug_str}")
      self._term_bounds[t] = (lb, ub)
      return
    # Otherwise index the constraint by its leading term for later elimination.
    lead_t, lead_t_k = e._leading_term
    lead_t_constraints = self._expr_constraints.get(lead_t)
    if lead_t_constraints is None:
      lead_t_constraints = set()
      self._expr_constraints[lead_t] = lead_t_constraints
    lead_t_constraints.add((cmp, lead_t_k, e))
  def combine_term_with_existing(self, t: _DimTerm, t_k: int, *,
                                 scope: shape_poly.SymbolicScope,
                                 only_smaller_than_t=True,
                                 ) -> Sequence[tuple[Comparator,
                                                     _DimExpr,
                                                     int,
                                                     int]]:
    """
    Combine a term with existing constraints.

    For input (t, t_k) the tuple (c_eq, c, c_s, t_s) is among the returned
    tuples if there exists a constraint `c =[c_eq] 0` that can be combined
    with `t*t_k` to eliminate `t`, and:
      * `c =[c_eq] 0`
      * The term `comb = t*t_k*t_s + c*c_s` does not contain `t`, and if
        `only_smaller_than_t` then `comb` contains only terms structurally
        smaller than `t`.
      * `c_s > 0`
    """
    # TODO: maybe a generator is useful here instead of materializing the list
    acc: list[tuple[Comparator, _DimExpr, int, int]] = []
    # First combine with the existing term bounds
    t_lb, t_ub = self._term_bounds.get(t, (-np.inf, np.inf))
    if t_lb == t_ub:
      # An exact bound yields the equality constraint "t - t_lb == 0".
      acc.append((Comparator.EQ, _DimExpr(((t, 1),), scope) - int(t_lb),
                  abs(t_k), - sgn(t_k)))
    else:
      if t_lb > -np.inf:
        # "t - t_lb >= 0"
        acc.append((Comparator.GEQ, _DimExpr(((t, 1),), scope) - int(t_lb),
                    abs(t_k), - sgn(t_k)))
      if t_ub < np.inf:
        # "t_ub - t >= 0"
        acc.append((Comparator.GEQ, _DimExpr(((t, -1),), scope) + int(t_ub),
                    abs(t_k), sgn(t_k)))
    # Then combine with the stored expression constraints. When
    # `only_smaller_than_t`, consider only constraints whose *leading* term is
    # `t` (so the combination leaves only structurally smaller terms);
    # otherwise scan all stored constraints for any occurrence of `t`.
    prev_constraint: set[tuple[Comparator, int, _DimExpr]]
    for prev_constraint in ([self._expr_constraints.get(t, set())] if only_smaller_than_t
                            else self._expr_constraints.values()):
      for c_eq, _, c in prev_constraint:
        # TODO: optimize this dict()
        tc_k = dict(c._sorted_terms).get(t)
        if tc_k is not None:
          # c =[c_eq] 0 AND t*tc_k appears in c.
          c_s = abs(t_k)
          c_t = - tc_k * sgn(t_k)
          acc.append((c_eq, c, c_s, c_t))
    return acc
  def combine_constraint_with_existing(self,
                                       eq: Comparator,
                                       e: _DimExpr,
                                       debug_str: str | None) -> set[tuple[Comparator, _DimExpr]]:
    """Combines a new constraint "e =[eq] 0" with the existing constraints.

    For each non-constant term of `e`, eliminates that term against the stored
    constraints and collects the derived constraints.

    Returns:
      A set of (cmp, expr) pairs, each meaning "expr =[cmp] 0".

    Raises:
      ValueError: if a derived constraint reduces to an unsatisfiable constant.
    """
    combinations: set[tuple[Comparator, _DimExpr]] = set()
    for t, t_k in e._sorted_terms:
      if t.is_constant: continue
      for (c_eq, c, c_s, t_s) in self.combine_term_with_existing(t, t_k,
                                                                 only_smaller_than_t=False,
                                                                 scope=e.scope):
        # c =[c_eq] 0 AND c_s > 0 AND t*t_k*t_s + c*c_s does not contain t
        # The combination is sound for GEQ only when t_s > 0 (it preserves
        # the inequality direction); for EQ any multiplier is sound.
        if t_s > 0 or eq == Comparator.EQ:
          new_eq = Comparator.EQ if (eq == c_eq == Comparator.EQ) else Comparator.GEQ
          new_e = _DimExpr._linear_combination(e, t_s, c, c_s, e.scope)
          if (const := _DimExpr._to_constant(new_e)) is not None:
            # A constant combination can immediately refute the constraints.
            if ((new_eq == Comparator.GEQ and const < 0) or
                (new_eq == Comparator.EQ and const != 0)):
              raise ValueError(f"Unsatisfiable constraints: {debug_str or str(e) + ' >= 0'}")
          else:
            combinations.add((new_eq, new_e))  # pyrefly: ignore[bad-argument-type]
    return combinations
  def bounds(self, e: DimSize,
             prec: BoundsPrecision,
             add_implicit_constraints: bool = False
             ) -> tuple[float, float]:
    """Returns the lower and upper bounds, or -+inf.

    Args:
      e: the expression for which to compute the bounds.
      prec: the desired precision. See comments in `BoundsPrecision`.
      add_implicit_constraints: if True, then before computing the bounds
        add the implicit constraints for the terms inside `e`.

    Returns:
      A pair (lower, upper); each element is an int, or -/+ np.inf when no
      finite bound was found.
    """
    if (const := _DimExpr._to_constant(e)) is not None:
      # A constant bounds itself exactly.
      return (const, const)
    assert isinstance(e, _DimExpr)
    # Caching bounds is tricky. Since the underlying _bounds_for_sorted_terms
    # is incomplete, and it may produce better results in the context of
    # specific queries (due to the implicit constraints), if we cache the
    # bounds computation we may stick to sub-optimal results. Also, we should
    # not use the precision as part of the cache key, because a certain result
    # may work for multiple precisions.
    if (res := self.scope._bounds_cache.get(e)) is not None:
      lb, ub, prev_prec = res
      # Reuse the cached bounds only if they satisfy the requested precision,
      # or if they were computed at an equal-or-stronger precision.
      if prec._bounds_are_sufficient(lb, ub): return (lb, ub)
      if prev_prec.value >= prec.value: return (lb, ub)
    if add_implicit_constraints:
      self.add_implicit_constraints_expr(e)
    lb, ub = self._bounds_for_sorted_terms(e.scope, e._sorted_terms, 0, prec)
    # Normalize finite bounds to ints before caching.
    lb, ub = (int(lb) if lb > -np.inf else lb,
              int(ub) if ub < np.inf else ub)
    self.scope._bounds_cache[e] = (lb, ub, prec)
    return (lb, ub)
  def _bounds_for_sorted_terms(self,
                               scope: SymbolicScope,
                               e: Sequence[tuple[_DimTerm, int]],
                               i: int,
                               prec: BoundsPrecision) -> tuple[float, float]:
    """The lower and upper bounds of e[i:].

    Recursively eliminates the leading term e[i] against the stored
    constraints, then applies special rules for MAX/MIN factors.
    See comments about soundness and `cmp_with` in the
    `shape_poly.bounds_decision` method.
    Returns (lower-bound, upper-bound)
    """
    if i >= len(e): return (0, 0)
    t, t_k = e[i]
    if t.is_constant:
      assert i == len(e) - 1  # Must be last
      return (t_k, t_k)
    lb = -np.inf
    ub = np.inf
    # Eliminate `t` against each applicable constraint and bound the remainder.
    for (c_eq, c, c_s, t_s) in self.combine_term_with_existing(t, t_k,
                                                               only_smaller_than_t=True,
                                                               scope=scope):
      # `c =[eq] 0` AND `t*t_k*t_s + c*c_s` contains only terms smaller than t
      # AND c_s > 0.
      # `rest = e[i:]*t_s + c*c_s` AND `rest_ub >= rest >= rest_lb`
      # `rest` contains only terms smaller than `t`.
      rest = _DimExpr._linear_combination_sorted_pairs(e, i, t_s,
                                                       c._sorted_terms, 0, c_s)
      rest_lb, rest_ub = self._bounds_for_sorted_terms(scope, rest, 0,
                                                       BoundsPrecision.BEST)
      if rest_ub < np.inf:
        # We have: e[i:]*t_s = rest - c*c_s <= rest_ub
        if t_s > 0:
          ub = min(ub, int(np.floor(rest_ub / t_s)))
        else:
          lb = max(lb, int(np.ceil(rest_ub / t_s)))
      if rest_lb > - np.inf and c_eq == Comparator.EQ:
        # We have: e[i:]*t_s = rest - c*c_s = rest >= rest_lb
        if t_s > 0:
          lb = max(lb, int(np.ceil(rest_lb / t_s)))
        else:
          ub = min(ub, int(np.floor(rest_lb / t_s)))
      # Stop early once the bounds satisfy the requested precision.
      if prec._bounds_are_sufficient(lb, ub): return (lb, ub)
    # Now look for special rules for factors
    if (t_f := t.to_factor()) is not None:
      if t_f.operation in [_DimFactor.MAX, _DimFactor.MIN]:
        # m_c*MAX(op1, op2) + rest_e >= max(m_c * op1 + rest_e, m_c * op2 + rest_e)
        # if m_c > 0. Similar rules for when m_c < 0 and for MIN.
        op1, op2 = t_f.operands
        rest1 = _DimExpr._linear_combination_sorted_pairs(e, i + 1, 1,
                                                          op1._sorted_terms, 0, t_k)
        rest2 = _DimExpr._linear_combination_sorted_pairs(e, i + 1, 1,
                                                          op2._sorted_terms, 0, t_k)
        rest1_lb, rest1_ub = self._bounds_for_sorted_terms(scope, rest1, 0,
                                                           BoundsPrecision.BEST)
        rest2_lb, rest2_ub = self._bounds_for_sorted_terms(scope, rest2, 0,
                                                           BoundsPrecision.BEST)
        # A MAX with positive coefficient (or MIN with negative) behaves
        # like a max of the two substituted expressions.
        like_max = (t_k > 0 if t_f.operation == _DimFactor.MAX else t_k < 0)
        if like_max:
          lb = max(lb, max(rest1_lb, rest2_lb))
          ub = min(ub, max(rest1_ub, rest2_ub))
        else:
          lb = max(lb, min(rest1_lb, rest2_lb))
          ub = min(ub, min(rest1_ub, rest2_ub))
        if prec._bounds_are_sufficient(lb, ub, ): return (lb, ub)
    return lb, ub
def add_implicit_constraints_expr(self, e: _DimExpr):
"""Adds the implicit constraints for the expression `e`"""
for t, _ in e._sorted_terms:
if t.is_constant: continue
self.add_implicit_constraints_term(t)
  def add_implicit_constraints_term(self, t: _DimTerm):
    """Adds the constraints implied by the structure of the term `t`.

    E.g., a variable is >= 1, a MOD result is bounded by its divisor, a MAX
    is at least each of its operands. Idempotent per term.
    (NOTE(review): this method may continue beyond the visible code for other
    factor operations — confirm against the full file.)
    """
    if t in self._processed_for_internal_constraints: return
    self._processed_for_internal_constraints.add(t)
    t_e = _DimExpr._from_term(t, 1, self.scope)  # m as a _DimExpr
    f = t.to_factor()
    if f is None:
      # This is a multiplication of factors. Try to compute bounds based on
      # the bounds of the factors.
      bounds = []
      for f1, f1_exp in t._factors:
        f1_t = _DimTerm.from_factor(f1, 1)
        f1_e = _DimExpr._from_term(f1_t, 1, self.scope)
        self.add_implicit_constraints_term(f1_t)
        a1_l, a1_u = self.bounds(f1_e, BoundsPrecision.BEST)
        assert a1_l <= a1_u
        bounds.append((a1_l ** f1_exp, a1_u ** f1_exp))
      # The extreme values of the product are among the products of extremes.
      candidate_bounds = [math.prod(factor_bounds)
                          for factor_bounds in itertools.product(*bounds)]
      m_l = min(*candidate_bounds)
      m_u = max(*candidate_bounds)
      self.combine_and_add_constraint(Comparator.GEQ, t_e, m_l)
      self.combine_and_add_constraint(Comparator.GEQ, m_u, t_e)
      return
    # It is a factor, is it a variable?
    if f.to_var() is not None:
      # Dimension variables are implicitly at least 1.
      self.combine_and_add_constraint(Comparator.GEQ, t_e, 1)  # f.to_var() >= 1
      return
    # Recurse into the operands before reasoning about the operation itself.
    for oper in f.operands:
      self.add_implicit_constraints_expr(oper)
    if f.operation == _DimFactor.MOD:
      op1, op2 = f.operands
      op2_b_l, op2_b_u = self.bounds(op2, BoundsPrecision.FOR_GEQ0_OR_LT0)
      if op2_b_l > 0:  # positive divisor
        # 0 <= op1 % op2 <= op2 - 1 (and <= UB(op2) - 1).
        self.combine_and_add_constraint(Comparator.GEQ, t_e, 0)  # m >= 0
        self.combine_and_add_constraint(Comparator.GEQ, op2 - 1, t_e)  # m <= op2 - 1
        self.combine_and_add_constraint(Comparator.GEQ, op2_b_u - 1, t_e)
      elif op2_b_u < 0:  # negative divisor
        # op2 + 1 <= op1 % op2 <= 0 (Python-style modulo semantics).
        self.combine_and_add_constraint(Comparator.GEQ, t_e, op2 + 1)  # m >= op2 + 1
        self.combine_and_add_constraint(Comparator.GEQ, t_e, op2_b_l + 1)
        self.combine_and_add_constraint(Comparator.GEQ, 0, t_e)  # m <= 0
      return
    if f.operation == _DimFactor.FLOORDIV:
      op1, op2 = f.operands
      (op1_l, op1_u) = self.bounds(op1, BoundsPrecision.BEST)
      (op2_l, op2_u) = self.bounds(op2, BoundsPrecision.BEST)
      def math_floor_with_inf(a: float, b: float):
        # math.floor(a / b), but aware of inf.
        # When either a or b are infinite, the result represents the limit
        # of "a // b".
        assert b != 0  # we caught division by 0 earlier
        if not np.isinf(b):  # divisor b is finite
          if not np.isinf(a):  # both dividend a and divisor b are finite
            return math.floor(a / b)
          # a is infinite, b is finite
          return -np.inf if (a >= 0) != (b >= 0) else np.inf
        elif not np.isinf(a):  # dividend a is finite and divisor b is infinite
          return -1 if (a >= 0) != (b >= 0) else 0
        else:  # both dividend and divisor are infinite
          return -np.inf if (a >= 0) != (b >= 0) else np.inf
      # Same reasoning as for multiplication: the bounds are among the cross-product
      # of the bounds.
      if op2_l <= 0 <= op2_u:
        raise InconclusiveDimensionOperation(
            f"Possible division by 0 in division by {op2}")
      candidate_bounds = [math_floor_with_inf(op1_l, op2_l),
                          math_floor_with_inf(op1_l, op2_u),
                          math_floor_with_inf(op1_u, op2_l),
                          math_floor_with_inf(op1_u, op2_u)]
      m_l = min(*candidate_bounds)
      m_u = max(*candidate_bounds)
      self.combine_and_add_constraint(Comparator.GEQ, t_e, m_l)
      self.combine_and_add_constraint(Comparator.GEQ, m_u, t_e)
      if op2_l >= 0:
        if op1_l >= 0:
          self.combine_and_add_constraint(Comparator.GEQ, t_e, 0)
        # Tie FLOORDIV and MOD together: op1 == op2 * (op1 // op2) + op1 % op2.
        mod_e = _DimExpr._from_operation(_DimFactor.MOD, op1, op2,
                                         scope=self.scope)
        if isinstance(mod_e, _DimExpr):
          self.add_implicit_constraints_expr(mod_e)
        combined = op2 * t_e + mod_e
        self.combine_and_add_constraint(Comparator.EQ, op1, combined)
      return
    if f.operation == _DimFactor.MAX:
      op1, op2 = f.operands
      op1_b_l, op1_b_u = self.bounds(op1, BoundsPrecision.BEST)
      op2_b_l, op2_b_u = self.bounds(op2, BoundsPrecision.BEST)
      # max(op1, op2) lies between max(LB1, LB2) and max(UB1, UB2), and is at
      # least each operand.
      self.combine_and_add_constraint(Comparator.GEQ, t_e, max(op1_b_l, op2_b_l))
      self.combine_and_add_constraint(Comparator.GEQ, max(op1_b_u, op2_b_u), t_e)
      self.combine_and_add_constraint(Comparator.GEQ, t_e, op1)
      self.combine_and_add_constraint(Comparator.GEQ, t_e, op2)
      return
    if f.operation == _DimFactor.MIN:
      op1, op2 = f.operands
      op1_b_l, op1_b_u = self.bounds(op1, BoundsPrecision.BEST)
      op2_b_l, op2_b_u = self.bounds(op2, BoundsPrecision.BEST)
      # min(op1, op2) lies between min(LB1, LB2) and min(UB1, UB2), and is at
      # most each operand.
      self.combine_and_add_constraint(Comparator.GEQ, t_e, min(op1_b_l, op2_b_l))
      self.combine_and_add_constraint(Comparator.GEQ, min(op1_b_u, op2_b_u), t_e)
      self.combine_and_add_constraint(Comparator.GEQ, op1, t_e)
      self.combine_and_add_constraint(Comparator.GEQ, op2, t_e)
      return