This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,22 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python API."""
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax.experimental.colocated_python.api import (
colocated_cpu_devices as colocated_cpu_devices,
colocated_python as colocated_python,
colocated_python_class as colocated_python_class,
)
@@ -0,0 +1,192 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python top-level API."""
from __future__ import annotations
import collections
from typing import Any, overload
from collections.abc import Callable, Sequence
import jax
from jax._src import api_util
from jax._src import util
from jax.experimental.colocated_python.func import make_callable
from jax.experimental.colocated_python.obj import wrap_class
import numpy as np
@overload
def colocated_cpu_devices(
    devices_or_mesh: Sequence[jax.Device],
) -> Sequence[jax.Device]:
  ...


@overload
def colocated_cpu_devices(
    devices_or_mesh: jax.sharding.Mesh,
) -> jax.sharding.Mesh:
  ...


def colocated_cpu_devices(devices_or_mesh):
  """Returns CPU devices (or a CPU mesh) colocated with the given input.

  Each accelerator device typically has an associated CPU device on the same
  host; with multiple accelerators per host there can be one CPU device per
  accelerator in 1:1 correspondence. This function resolves that association.

  Ordering is preserved: the CPU device at output index i corresponds to the
  input device at index i. A CPU input device maps to itself. Given a mesh,
  the result is a mesh of the same shape built from the colocated CPU devices.

  Args:
    devices_or_mesh: A tuple of devices or a mesh.

  Returns:
    A tuple of devices or a mesh that has the colocated CPU devices.
  """
  if isinstance(devices_or_mesh, jax.sharding.Mesh):
    return _colocated_cpu_mesh_cached(devices_or_mesh)
  # Normalize to a tuple so the lookup is hashable for the cache.
  devices = (
      devices_or_mesh
      if isinstance(devices_or_mesh, tuple)
      else tuple(devices_or_mesh)
  )
  try:
    return _colocated_cpu_devices_cached(devices)
  except (ValueError, AttributeError):
    # Colocation-id based matching is unavailable; fall back to pairing by
    # process index against the CPU backend.
    return _colocated_cpu_devices_cached_fallback_to_cpu_backend(devices)
@util.cache(max_size=1024, trace_context_in_key=False)
def _colocated_cpu_devices_cached(
    devices: tuple[jax.Device, ...],
) -> Sequence[jax.Device]:
  """Maps each device to its colocated CPU device using colocation ids.

  Args:
    devices: Non-empty tuple of devices to find colocated CPU devices for.

  Returns:
    CPU devices, one per input device, in the same order.

  Raises:
    ValueError: If the client defines no CPU devices, or if a device has zero
      or more than one colocated CPU device.
  """
  # Group the client's CPU devices by colocation id so each input device can
  # be matched to its unique CPU counterpart.
  cpu_devices_by_colocation_id = collections.defaultdict(list)
  for device in devices[0].client._get_all_devices():
    if device.device_kind == "cpu":
      cpu_devices_by_colocation_id[device.colocation_id].append(device)
  if not cpu_devices_by_colocation_id:
    raise ValueError("No CPU devices found")

  colocated_cpu_devices = []
  for device in devices:
    matches = cpu_devices_by_colocation_id[device.colocation_id]
    if not matches:
      raise ValueError(f"Device {device} has no colocated devices")
    elif len(matches) > 1:
      raise ValueError(
          f"Ambiguous colocated devices; device {device} has"
          # Bug fix: the message previously read "colocated devices: f[...]"
          # because of a stray "f" typed inside the f-string ("f{matches}").
          f" {len(matches)} colocated devices: {matches}"
      )
    colocated_cpu_devices.append(matches[0])
  return colocated_cpu_devices
@util.cache(max_size=1024, trace_context_in_key=False)
def _colocated_cpu_devices_cached_fallback_to_cpu_backend(
    devices: tuple[jax.Device, ...],
) -> Sequence[jax.Device]:
  """Fallback colocated-CPU lookup that pairs devices by process index.

  Used when colocation-id based matching fails. Each input device is paired
  with a CPU device from the same process, consuming local CPU devices in
  order.

  Raises:
    ValueError: If a process runs out of local CPU devices.
  """
  # TODO(hyeontaek): Remove this fallback path once a PjRt-IFRT backend defines
  # CPU devices by its own instead of using a separate CPU backend.
  if devices[0].device_kind == "cpu":
    # Use the devices from the backend of an original device if it defines CPU
    # devices.
    cpu_backend_devices = [d for d in devices[0].client._get_all_devices()
                           if d.device_kind == "cpu"]
  else:
    # PjRt-IFRT on a non-CPU platform currently defines CPU devices on a
    # separate CPU backend.
    cpu_backend_devices = jax.devices(backend="cpu")
  # Bucket CPU devices by the process that owns them.
  cpu_device_map = collections.defaultdict(list)
  for d in cpu_backend_devices:
    cpu_device_map[d.process_index].append(d)
  # Reverse each local CPU device list to make it cheaper to pop.
  for process_index in cpu_device_map.keys():
    cpu_device_map[process_index].reverse()
  cpu_devices = []
  for d in devices:
    try:
      # Pop from the end; since lists were reversed, this consumes CPU devices
      # in their original front-to-back order per process.
      cpu_devices.append(cpu_device_map[d.process_index].pop())
    except IndexError:
      raise ValueError(
          f"Process {d.process_index} does not have enough local CPU devices")
  return cpu_devices
@util.cache(max_size=1024, trace_context_in_key=False)
def _colocated_cpu_mesh_cached(mesh: jax.sharding.Mesh) -> jax.sharding.Mesh:
  """Returns a CPU mesh that is similar to the given mesh but has colocated CPU devices."""
  # Finding colocated CPU devices reuses the cache of `colocated_cpu_devices`
  # called with devices. `_colocated_cpu_mesh` itself is also cached to avoid
  # creating a new `Mesh` object repeatedly.
  flat_cpu_devices = colocated_cpu_devices(tuple(mesh.devices.flat))
  # Rebuild a mesh with the same shape, axis names, and axis types, but over
  # the colocated CPU devices.
  return jax.sharding.Mesh(
      np.array(flat_cpu_devices).reshape(mesh.axis_sizes),
      mesh.axis_names,
      axis_types=mesh.axis_types,
  )
def colocated_python(fun: Callable[..., Any]):
  """Wraps `fun` so that it runs on the same devices as its arguments.

  The wrapper is a serializable colocated Python callable that may execute on
  remote hosts holding the argument devices. It exposes `specialize` and
  `__call__`; see their docstrings and
  https://docs.jax.dev/en/latest/notebooks/colocated-python.html for examples.

  Args:
    fun: An original function to wrap as an I/O callable.

  Returns:
    Colocated Python callable with no initial specialization.
  """
  sourceinfo = api_util.fun_sourceinfo(fun)
  signature = api_util.fun_signature(fun)
  return make_callable(fun, sourceinfo, signature)
def colocated_python_class(cls: type[object]) -> type[object]:
  """Wraps a class so its methods run on the same devices as the arguments.

  The returned wrapper type mirrors the methods of `cls` and can be
  instantiated on JAX. The real object is constructed lazily on the host
  owning the argument devices when a wrapper method is first invoked; it
  lives as long as the wrapper does and is torn down asynchronously when the
  wrapper is destroyed. If no method is ever called, no real object is
  created.

  Args:
    cls: The class to wrap as a colocated Python object.

  Returns:
    Wrapper class.
  """
  sourceinfo = api_util.fun_sourceinfo(cls)
  return wrap_class(cls, sourceinfo)
@@ -0,0 +1,709 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python function API implementation."""
from __future__ import annotations
from collections.abc import Callable, Sequence
import dataclasses
import inspect
import random
import threading
from typing import Any
import uuid
import weakref
import jax
from jax._src import api
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import pxla
from jax._src.lib import xla_client as xc
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
from jax.experimental.colocated_python import func_backend
from jax.experimental.colocated_python.serialization import _deserialize, _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs
from jax.extend.backend import register_backend_cache as jax_register_backend_cache
from jax.extend.ifrt_programs import ifrt_programs
ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct]
@dataclasses.dataclass(frozen=True, slots=True)
class FunctionInfo:
  """User function wrapped by colocated_python."""

  # The original user function.
  fun: Callable[..., Any]
  # Source info of `fun` (e.g. from `api_util.fun_sourceinfo`), if available.
  fun_sourceinfo: str | None
  # Signature of `fun`, if available.
  fun_signature: inspect.Signature | None
@dataclasses.dataclass(frozen=True, slots=True)
class Specialization:
  """Specialization for a colocated_python function.

  All fields default to None (unspecified). Fields are effectively
  write-once: `update` refuses an override for a field that is already set.
  """

  # PyTree structure of the `(args, kwargs)` input.
  in_specs_treedef: tree_util.PyTreeDef | None = None
  # Flattened input leaves as shape/dtype/sharding specs.
  in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
  # Optional function computing output specs from input specs.
  out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None
  # PyTree structure of the output.
  out_specs_treedef: tree_util.PyTreeDef | None = None
  # Flattened output leaves as shape/dtype/sharding specs.
  out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
  # Devices to execute the function on.
  devices: xc.DeviceList | None = None

  def update(
      self,
      *,
      in_specs_treedef: tree_util.PyTreeDef | None = None,
      in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
      out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
      out_specs_treedef: tree_util.PyTreeDef | None = None,
      out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
      devices: Sequence[jax.Device] | xc.DeviceList | None = None,
  ):
    """Creates a new specialization with overrides.

    Each keyword may be supplied only if the corresponding field is not
    already set on `self`.

    Raises:
      ValueError: If an override targets a field that is already specified.
    """
    if in_specs_treedef is None:
      in_specs_treedef = self.in_specs_treedef
    elif self.in_specs_treedef is not None:
      raise ValueError("in_specs already specified")
    if in_specs_leaves is None:
      in_specs_leaves = self.in_specs_leaves
    elif self.in_specs_leaves is not None:
      raise ValueError("in_specs already specified")

    if out_specs_fn is None:
      out_specs_fn = self.out_specs_fn
    elif self.out_specs_fn is not None:
      raise ValueError("out_specs_fn already specified")

    if out_specs_treedef is None:
      out_specs_treedef = self.out_specs_treedef
    elif self.out_specs_treedef is not None:
      raise ValueError("out_specs already specified")
    if out_specs_leaves is None:
      out_specs_leaves = self.out_specs_leaves
    elif self.out_specs_leaves is not None:
      raise ValueError("out_specs already specified")

    if devices is None:
      devices = self.devices
    elif self.devices is not None:
      raise ValueError("devices already specified")
    elif not isinstance(devices, xc.DeviceList):
      # Normalize a plain device sequence into a DeviceList.
      devices = xc.DeviceList(tuple(devices))

    return Specialization(
        in_specs_treedef,
        in_specs_leaves,
        out_specs_fn,
        out_specs_treedef,
        out_specs_leaves,
        devices,
    )
def _get_spec(x: Any) -> api.ShapeDtypeStruct:
"""Extracts a spec for a value, which must be a JAX Array."""
# TODO(hyeontaek): Allow Python values and automatically apply `shard_arg`
# with a suitable sharding and layout.
if not isinstance(x, jax.Array):
raise ValueError(
"colocated_python only supports jax.Array as input and output, but got"
f" {type(x)}."
)
return api.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None:
"""Returns a representative device list from function call arguments."""
device_list_set: set[xc.DeviceList] = set()
for x in args:
sharding = getattr(x, "sharding", None)
if sharding is not None:
device_list_set.add(x.sharding._internal_device_list)
if not device_list_set:
return None
if len(device_list_set) != 1:
raise ValueError(
"All arguments must use the same device list, but got"
f" multiple device lists: {device_list_set}."
)
return device_list_set.pop()
def _compile_to_executable(
    name: str,
    fun: Callable[..., Any],
    in_specs_treedef: tree_util.PyTreeDef,
    in_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
    out_specs_treedef: tree_util.PyTreeDef,
    out_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
    devices: xc.DeviceList,
) -> Callable[..., Any]:
  """Compiles a Python function into a runtime executable.

  Serializes `fun` with its specialization, wraps it as a colocated Python
  IFRT program, and compiles it on the devices' client. The returned callable
  executes the loaded executable on flattened arguments and unflattens the
  results using `out_specs_treedef`.
  """
  fun_and_specialization = (
      fun,
      in_specs_treedef,
      in_specs_leaves,
      out_specs_treedef,
      out_specs_leaves,
      devices,
  )
  pickled_function = _serialize(fun_and_specialization)
  program = ifrt_programs.make_colocated_python_program(
      name, pickled_function, devices, in_specs_leaves, out_specs_leaves
  )
  ifrt_client = devices[0].client
  # Build avals and shardings for the output so results can be handled into
  # jax.Arrays.
  out_sdss = tuple(
      jax.core.ShapedArray(sds.shape, sds.dtype) for sds in out_specs_leaves
  )
  out_shardings = tuple(sds.sharding for sds in out_specs_leaves)
  try:
    compile_options = ifrt_programs.make_colocated_python_compile_options()
    loaded_executable = ifrt_client.compile_ifrt_program(
        program, compile_options
    )
    out_handlers = pxla.global_avals_to_results_handler(
        out_sdss, out_shardings, committed=True
    ).handlers

    def call(*args, **kwargs):
      # Flatten (args, kwargs) into leaves in the same order as the input
      # specs used at compile time.
      args_leaves = tree_util.tree_leaves((args, kwargs))
      execute_result = loaded_executable.execute_sharded(
          args_leaves, with_tokens=False
      )
      results = execute_result.consume_with_handlers(out_handlers)
      return tree_util.tree_unflatten(out_specs_treedef, results)

    return call
  except jax.errors.JaxRuntimeError as e:
    # TODO(hyeontaek): Implement colocated Python support in McJAX and remove
    # this fallback path.
    if "PjRtCompiler requires an HloProgram" in str(e):
      # Fall back to running the deserialized function directly in-process.
      return _deserialize(pickled_function)[0]
    raise
def _make_output_specs_and_push_result_fun(
    info: FunctionInfo,
    specialization: Specialization,
    uid: int,
) -> Callable[..., Any]:
  """Creates a function that computes output specs and pushes the result to the result store.

  The compiled function runs `info.fun`, stores the flattened result leaves
  in the backend result store under `uid`, and returns the serialized output
  specs for the controller to deserialize.
  """
  assert specialization.in_specs_treedef is not None
  assert specialization.in_specs_leaves is not None
  assert specialization.out_specs_treedef is None
  assert specialization.out_specs_leaves is None
  assert specialization.devices is not None

  devices = specialization.devices

  def lowered_fun(*args, **kwargs) -> jax.Array:
    result = info.fun(*args, **kwargs)
    result_leaves, out_treedef = tree_util.tree_flatten(result)
    out_spec_leaves = tuple(_get_spec(x) for x in result_leaves)
    # Keep the actual result on the backend; it is retrieved later by the
    # function produced by `_make_pop_result_fun` with the same `uid`.
    func_backend.SINGLETON_RESULT_STORE.push(uid, result_leaves)
    return _serialize_specs(out_treedef, out_spec_leaves, devices)

  # The output of this lowered function is the serialized specs payload, not
  # the user function's output.
  out_specs_leaves, out_specs_treedef = tree_util.tree_flatten(
      _make_specs_for_serialized_specs(specialization.devices),
  )
  name = getattr(info.fun, "__name__", "unknown")
  name = f"{name}_output_specs_and_push_result"
  return _compile_to_executable(
      name=name,
      fun=lowered_fun,
      in_specs_treedef=specialization.in_specs_treedef,
      in_specs_leaves=specialization.in_specs_leaves,
      out_specs_treedef=out_specs_treedef,
      out_specs_leaves=tuple(out_specs_leaves),
      devices=specialization.devices,
  )
def _make_pop_result_fun(
    info: FunctionInfo,
    specialization: Specialization,
    uid: int,
) -> Callable[..., Any]:
  """Builds an executable that pops the result stored on the backend under `uid`."""
  assert specialization.out_specs_treedef is not None
  assert specialization.out_specs_leaves is not None
  assert specialization.devices is not None

  out_specs_treedef = specialization.out_specs_treedef

  def pop_result():
    # Retrieve the leaves stashed by the output-specs phase and rebuild the
    # original result structure.
    leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid)
    return tree_util.tree_unflatten(out_specs_treedef, leaves)

  # The popping function takes no arguments: empty args and kwargs.
  empty_in_leaves, empty_in_treedef = tree_util.tree_flatten(((), {}))
  fun_name = getattr(info.fun, "__name__", "unknown")
  return _compile_to_executable(
      name=f"{fun_name}_pop_result",
      fun=pop_result,
      in_specs_treedef=empty_in_treedef,
      in_specs_leaves=tuple(empty_in_leaves),
      out_specs_treedef=specialization.out_specs_treedef,
      out_specs_leaves=specialization.out_specs_leaves,
      devices=specialization.devices,
  )
def _make_async_execution_fun(
    info: FunctionInfo,
    specialization: Specialization,
) -> Callable[..., Any]:
  """Builds an executable that runs `info.fun` with fully known specs."""
  # All spec fields must be resolved before the asynchronous path is built.
  for required in (
      specialization.in_specs_treedef,
      specialization.in_specs_leaves,
      specialization.out_specs_treedef,
      specialization.out_specs_leaves,
      specialization.devices,
  ):
    assert required is not None
  return _compile_to_executable(
      name=getattr(info.fun, "__name__", "unknown"),
      fun=info.fun,
      in_specs_treedef=specialization.in_specs_treedef,
      in_specs_leaves=specialization.in_specs_leaves,
      out_specs_treedef=specialization.out_specs_treedef,
      out_specs_leaves=specialization.out_specs_leaves,
      devices=specialization.devices,
  )
def _uncached_get_specialized_func(
    info: FunctionInfo,
    specialization: Specialization,
) -> Callable[..., Any]:
  """Returns a specialized function for the given specialization.

  The returned function lazily builds an asynchronous execution function on
  first call. If output specs are unknown, the first call either runs the
  function synchronously to discover them (stashing the result in the backend
  result store and popping it afterwards), or derives them from
  `specialization.out_specs_fn` if provided.
  """
  util.test_event("colocated_python_func._get_specialized_func")
  assert specialization.in_specs_treedef is not None
  assert specialization.in_specs_leaves is not None
  assert specialization.devices is not None
  # Key for the backend result store used during the first synchronous call.
  uid = random.getrandbits(63)

  mutex = threading.Lock()
  # Asynchronous execution function that has known output_specs.
  async_execution_func = None

  def specialized_func(*args, **kwargs):
    """Specialized function to be executed with given args and kwargs."""
    nonlocal specialization, async_execution_func
    with mutex:
      if async_execution_func is None:
        if specialization.out_specs_treedef is None:
          if specialization.out_specs_fn is None:
            # Output specs unknown and no out_specs_fn: run the function
            # synchronously once to discover the output specs.
            output_specs_and_push_result_fun = (
                _make_output_specs_and_push_result_fun(
                    info, specialization, uid
                )
            )
            serialized_out_specs = output_specs_and_push_result_fun(
                *args, **kwargs
            )
            # Waits for the output_specs. This may block.
            out_specs_treedef, out_specs_leaves = _deserialize_specs(
                serialized_out_specs
            )
            # Subsequent calls would use async_execution_func with discovered
            # output_specs.
            specialization = specialization.update(
                out_specs_treedef=out_specs_treedef,
                out_specs_leaves=out_specs_leaves,
            )
            async_execution_func = _make_async_execution_fun(
                info, specialization
            )
            # Hold the PyExecutable until async_execution_fun is called at
            # least once, so the number of _OBJECT_STORE references at the
            # backend does not drop to 0.
            async_execution_func.output_specs_and_push_result_fun = (  # pyrefly: ignore[missing-attribute]
                output_specs_and_push_result_fun
            )
            # Return the result that was pushed to the backend result store.
            return _make_pop_result_fun(info, specialization, uid)()
          else:
            # Compute out_specs using out_specs_fn and inputs.
            args_specs, kwargs_specs = tree_util.tree_map(
                _get_spec, (args, kwargs)
            )
            out_specs = specialization.out_specs_fn(*args_specs, **kwargs_specs)
            out_specs_leaves, out_specs_treedef = tree_util.tree_flatten(
                out_specs
            )
            specialization = specialization.update(
                out_specs_treedef=out_specs_treedef,
                out_specs_leaves=tuple(out_specs_leaves),
            )
            async_execution_func = _make_async_execution_fun(
                info, specialization
            )
            # Fall-through.
        else:
          # Output specs already known: build the async path directly.
          async_execution_func = _make_async_execution_fun(info, specialization)
          # Fall-through.
    # Asynchronous execution runs outside of the mutex to allow concurrent
    # execution for inline executors.
    result = async_execution_func(*args, **kwargs)
    with mutex:
      # The backend reference is no longer needed once the async path has run.
      async_execution_func.output_specs_and_push_result_fun = None  # pyrefly: ignore[missing-attribute]
    return result

  return specialized_func
class _SpecializedCollection:
  """Collection of specialized functions for a single unspecialized function.

  The `get()` method retrieves the specialized function for the provided input
  spec, either by looking up a cache or by compiling the specialized function.

  Looking up a cache with an input spec as a key can be slow, because
  `Sharding`'s equivalence comparison is slow. Instead, we maintain two caches
  for the same value: we use the ID of the sharding object (via `WeakSpec`) as
  the key in one cache, and the corresponding strong references to the sharding
  object (via `StrongSpec`) as the key in another cache. Looking up the
  `WeakSpec`-keyed cache is fast. Note that the ID integer in the `WeakSpec`
  cache will remain valid as long as a strong-ref exists in the `StrongSpec`
  cache.

  The `StrongSpec`-keyed cache is unbounded, while the `WeakSpec`-keyed cache
  is LRU(1): if there is a miss in the `WeakSpec` cache but a hit in the
  `StrongSpec` cache, the strong-ref in the `StrongSpec` cache and the ID
  integer in the `WeakSpec` cache are both updated.
  """

  @dataclasses.dataclass(slots=True, unsafe_hash=True)
  class WeakSpec:
    """WeakSpec stores just the `id()` of the input spec sharding."""

    dtypes: tuple[jax.numpy.dtype, ...]
    shapes: tuple[tuple[int, ...], ...]
    sharding_ids: tuple[int, ...]
    treedef: tree_util.PyTreeDef

    # Explicit __init__ (takes precedence over the dataclass-generated one):
    # derives all fields from the argument leaves.
    def __init__(
        self, args_leaves: Sequence[jax.Array], treedef: tree_util.PyTreeDef
    ):
      self.dtypes = tuple(x.dtype for x in args_leaves)
      self.shapes = tuple(x.shape for x in args_leaves)
      # `id()` avoids invoking `Sharding`'s slow equality comparison.
      self.sharding_ids = tuple(id(x.sharding) for x in args_leaves)
      self.treedef = treedef

  @dataclasses.dataclass(slots=True, unsafe_hash=True)
  class StrongSpec:
    """StrongSpec stores the full input spec sharding."""

    in_specs_treedef: tree_util.PyTreeDef | None = None
    in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None

    def __init__(
        self, args_leaves: Sequence[jax.Array], pytreedef: tree_util.PyTreeDef
    ):
      self.in_specs_leaves = tuple(_get_spec(x) for x in args_leaves)
      self.in_specs_treedef = pytreedef

  def __init__(self):
    CompiledId = int
    # Fast path: weak spec (sharding ids) -> compiled id.
    self._weak_to_id: dict[_SpecializedCollection.WeakSpec, CompiledId] = {}
    self._id_to_weak: dict[CompiledId, _SpecializedCollection.WeakSpec] = {}
    # Slow path: strong spec (full shardings) -> compiled id.
    self._strong_to_id: dict[_SpecializedCollection.StrongSpec, CompiledId] = {}
    self._id_to_compiled: dict[CompiledId, Callable[..., Any]] = {}
    self._counter = 0
    self._mu = threading.Lock()

  def get(
      self,
      args_leaves: Sequence[jax.Array],
      pytreedef: tree_util.PyTreeDef,
      func_info: FunctionInfo,
      specialization: Specialization,
  ) -> Callable[..., Any]:
    """Returns the compiled specialized function for the given input leaves."""
    # TODO(hyeontaek): Allow Python values in args_leaves, similar to the todo
    # in _get_spec().
    # Attempt fast-path cache hit.
    # NOTE(review): this read is performed outside `self._mu`; it appears to
    # rely on the atomicity of dict reads under the GIL — confirm intended.
    weak_spec = _SpecializedCollection.WeakSpec(args_leaves, pytreedef)
    compiled_id = self._weak_to_id.get(weak_spec)
    if compiled_id is not None:
      return self._id_to_compiled[compiled_id]
    with self._mu:
      # Attempt slow-path cache hit.
      strong_spec = _SpecializedCollection.StrongSpec(args_leaves, pytreedef)
      compiled_id = self._strong_to_id.pop(strong_spec, None)
      if compiled_id is not None:
        # Update the caches so that the fast-path cache stores the `id()` of the
        # shardings presented by the current invocation.
        old_weak = self._id_to_weak.pop(compiled_id)
        del self._weak_to_id[old_weak]
        self._strong_to_id[strong_spec] = compiled_id
        self._weak_to_id[weak_spec] = compiled_id
        self._id_to_weak[compiled_id] = weak_spec
        return self._id_to_compiled[compiled_id]
      # Cache-miss: compile.
      if specialization.devices is None:
        # Devices were not specialized; infer them from the arguments.
        result = _uncached_get_specialized_func(
            func_info,
            specialization.update(
                in_specs_treedef=strong_spec.in_specs_treedef,
                in_specs_leaves=strong_spec.in_specs_leaves,
                devices=_infer_devices_from_args(args_leaves),
            ),
        )
      else:
        result = _uncached_get_specialized_func(
            func_info,
            specialization.update(
                in_specs_treedef=strong_spec.in_specs_treedef,
                in_specs_leaves=strong_spec.in_specs_leaves,
            ),
        )
      # Install the newly compiled function in all caches.
      compiled_id = self._counter
      self._counter += 1
      self._weak_to_id[weak_spec] = compiled_id
      self._strong_to_id[strong_spec] = compiled_id
      self._id_to_weak[compiled_id] = weak_spec
      self._id_to_compiled[compiled_id] = result
      return result
class _JaxSecondLevelCaches:
  """Manages second-level caches registered as a single cache with JAX.

  Thread-safe: `_callbacks` is only read or mutated while holding `_lock`.
  (Previously the lock was created but never acquired, leaving concurrent
  register/clear calls racy.)
  """

  def __init__(self, name: str):
    self._lock = threading.Lock()
    # uid -> callback that clears that second-level cache.
    self._callbacks: dict[int, Callable[..., Any]] = {}
    jax_register_backend_cache(self, name)

  def cache_clear(self):
    """Meant to be invoked by JAX internals."""
    # Snapshot and reset under the lock; run callbacks outside it so a slow
    # or re-entrant callback cannot deadlock with register/remove.
    with self._lock:
      callbacks = list(self._callbacks.values())
      self._callbacks.clear()
    for callback in callbacks:
      callback()

  def register_second_level(
      self, uid: int, cache_clear_callback: Callable[..., Any]
  ):
    """Registers (or replaces) the clear callback for `uid`."""
    with self._lock:
      self._callbacks[uid] = cache_clear_callback

  def remove_second_level(self, uid: int):
    """Deregisters the clear callback for `uid`, if present."""
    with self._lock:
      self._callbacks.pop(uid, None)
class _CachedColocatedFunctionMaker:
  """Function maker for colocated Python functions.

  Generated functions are stored (cached) indefinitely so that they can be
  reused, until the cache is dropped.
  """

  # Class-wide registry shared by all makers; registered with JAX so JAX
  # cache clearing also clears these second-level caches.
  JAX_CACHE = _JaxSecondLevelCaches("colocated_python_specialized_func_cache")

  def __init__(self, held_by: int | None):
    # Unique id under which this maker's clear callback is registered.
    self.held_by = held_by if held_by is not None else uuid.uuid4().int
    # NOTE(review): `clear_caches` closes over these local lists rather than
    # `self`, presumably so the registry does not keep `self` alive — confirm.
    specialized_collections: list[_SpecializedCollection] = []
    specialized_functions: list[Callable[..., Any]] = []

    def clear_caches():
      specialized_collections.clear()
      specialized_functions.clear()

    _CachedColocatedFunctionMaker.JAX_CACHE.register_second_level(
        self.held_by,
        clear_caches,
    )
    self.specialized_collections = specialized_collections
    self.specialized_functions = specialized_functions

  def __del__(self):
    # Drop cached functions and deregister the clear callback.
    self.specialized_collections.clear()
    self.specialized_functions.clear()
    try:
      _CachedColocatedFunctionMaker.JAX_CACHE.remove_second_level(self.held_by)
    except AttributeError:
      # Ignore error during python finalization.
      pass

  def _make_callable(
      self,
      info: FunctionInfo,
      specialization: Specialization,
  ):
    """Internal implementation of make_callable."""

    def specialize(
        in_specs: ShapeDtypeStructTree | None = None,
        out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
        devices: Sequence[jax.Device] | None = None,
    ):
      """Returns a colocated Python callable with extra specialization.

      Args:
        in_specs: Optionally specifies the expected input specs. Input specs
          are expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)`
          of a function call.
        out_specs_fn: Optionally specifies a function that computes the output
          specs from input specs. If unspecified, colocated Python will
          compute the output specs during the very first execution, and this
          execution will be synchronous.
        devices: Optionally specifies the devices to execute the function on.
          Must be provided if `in_specs` has no leaves because devices cannot
          be inferred from input specs or arguments.

      Returns:
        A colocated Python callable with extra specialization.
      """
      # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if
      # `out_specs_fn(in_specs)` returns at least one leaf that we can use for
      # inferring `devices`.
      if in_specs is None:
        in_specs_leaves, in_specs_treedef = None, None
      else:
        in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(
            in_specs
        )
        in_specs_leaves = tuple(in_specs_leaves_list)
      return self._make_callable(
          info,
          specialization.update(
              in_specs_treedef=in_specs_treedef,
              in_specs_leaves=in_specs_leaves,
              out_specs_fn=out_specs_fn,
              devices=devices,
          ),
      )

    # Caches for a collection of specialized functions or a specialized function
    # itself. The latter is used as a performance optimization when the input
    # spec is explicitly specified and can skip a collection lookup. The caches
    # use weakrefs so that we avoid creating cyclic references.
    specialized_collections_wref: Callable[..., Any] = lambda: None
    specialized_functions_wref: Callable[..., Any] = lambda: None
    wref_mu = threading.Lock()

    @api_boundary
    def __call__(*args, **kwargs):
      """Executes the given Python function on the same devices as the arguments or as specialized.

      If the callable has not been specialized with output shapes and
      shardings (see `specialize` above), the very first call will run
      synchronously to discover output shapes and shardings, and will run
      asynchronously after.

      If specialized with output shapes and shardings, every execution of the
      callable will be asynchronous.
      """
      args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs))

      no_input = len(args_leaves) == 0
      if no_input and specialization.devices is None:
        raise ValueError(
            "No devices found. colocated_python function without input"
            " arguments must be first specialized with devices."
        )

      fully_specified_in_spec = (
          specialization.in_specs_treedef is not None
          and specialization.in_specs_leaves is not None
      )
      if not fully_specified_in_spec and not no_input:
        # We need to handle input polymorphism
        nonlocal specialized_collections_wref
        with wref_mu:
          collection: _SpecializedCollection = specialized_collections_wref()
          if collection is None:
            collection = _SpecializedCollection()
            self.specialized_collections.append(collection)
            specialized_collections_wref = weakref.ref(collection)
        # Execution happens outside `wref_mu`; `collection` holds a strong
        # reference for the duration of the call.
        result = collection.get(
            args_leaves, in_specs_treedef, info, specialization
        )(*args, **kwargs)
        del collection
        return result

      # No input polymorphism -- exactly one compiled function is possible.
      with wref_mu:
        nonlocal specialized_functions_wref
        func: Callable[..., Any] = specialized_functions_wref()
        if func is None:
          if fully_specified_in_spec and specialization.devices is not None:
            # Everything known: compile directly.
            func = _uncached_get_specialized_func(info, specialization)
          elif fully_specified_in_spec:
            # Input specs known; infer devices from the actual arguments.
            func = _uncached_get_specialized_func(
                info,
                specialization.update(
                    devices=_infer_devices_from_args(args_leaves)
                ),
            )
          elif no_input:
            # No arguments: input specs are trivially empty.
            func = _uncached_get_specialized_func(
                info,
                specialization.update(
                    in_specs_leaves=tuple(),
                    in_specs_treedef=in_specs_treedef,
                ),
            )
          self.specialized_functions.append(func)
          specialized_functions_wref = weakref.ref(func)
      # Execution happens outside `wref_mu`; `func` holds a strong reference
      # for the duration of the call.
      result = func(*args, **kwargs)
      del func
      return result

    __call__ = wraps(info.fun)(__call__)
    __call__.specialize = specialize  # pyrefly: ignore[missing-attribute]
    return __call__

  def make_callable(
      self,
      fun: Callable[..., Any],
      fun_sourceinfo: str | None,
      fun_signature: inspect.Signature | None,
  ):
    """Makes a colocated Python callable."""
    return self._make_callable(
        FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization()
    )
# Maker shared by the module-level `make_callable` entry point.
_DEFAULT_FUNCTION_MAKER = _CachedColocatedFunctionMaker(None)


def make_callable(
    fun: Callable[..., Any],
    fun_sourceinfo: str | None,
    fun_signature: inspect.Signature | None,
):
  """Makes a colocated Python callable using the default function maker."""
  maker = _DEFAULT_FUNCTION_MAKER
  return maker.make_callable(fun, fun_sourceinfo, fun_signature)
@@ -0,0 +1,44 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Backend for colocated_python.func."""
from __future__ import annotations
import threading
from collections.abc import Sequence
import jax
class _ResultStore:
"""Temporarily stores results from synchronous execution of functions."""
def __init__(self) -> None:
self._lock = threading.Lock()
self._storage: dict[int, Sequence[jax.Array]] = {}
def push(self, uid: int, out: Sequence[jax.Array]) -> None:
with self._lock:
if uid in self._storage:
raise ValueError(f"uid {uid} already exists")
self._storage[uid] = out
def pop(self, uid: int) -> Sequence[jax.Array]:
with self._lock:
if uid not in self._storage:
raise ValueError(f"uid {uid} does not exist")
return self._storage.pop(uid)
SINGLETON_RESULT_STORE = _ResultStore()
@@ -0,0 +1,202 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python object API implementation."""
from __future__ import annotations
from collections.abc import Callable
import inspect
import random
import threading
from typing import Any
import jax
from jax._src import api_util
from jax._src import tree_util
from jax._src import util
from jax._src.traceback_util import api_boundary
from jax.experimental.colocated_python import func
from jax.experimental.colocated_python import obj_backend
class _InstanceRegistry:
"""Registry of object instances."""
def __init__(self) -> None:
self._lock = threading.Lock()
self._storage: dict[int, set[jax.Device]] = {}
def new_instance(self) -> int:
"""Returns a new unique identifier for an instance on the controller."""
uid = random.getrandbits(63)
with self._lock:
assert uid not in self._storage
self._storage[uid] = set()
return uid
def update_devices(self, uid: int, device_set: set[jax.Device]) -> None:
"""Updates the set of devices on which it is live."""
with self._lock:
self._storage[uid] |= device_set
def pop_instance(self, uid: int) -> set[jax.Device]:
"""Removes the instance and returns the set of devices on which it is live."""
with self._lock:
return self._storage.pop(uid)
SINGLETON_INSTANCE_REGISTRY = _InstanceRegistry()
@util.cache(max_size=4096)
def _update_instance_devices(
    uid: int, shardings: tuple[jax.sharding.Sharding, ...]
) -> None:
  """Caching version of _InstanceRegistry.update_devices().

  The cache makes repeated calls with the same (uid, shardings) pair no-ops,
  so per-call sharding inspection stays cheap.
  """
  # Union all shard device sets in one pass. `set().union()` with no
  # arguments yields an empty set, so an empty `shardings` tuple is handled
  # naturally.
  merged_devices = set().union(*(s.device_set for s in shardings))
  SINGLETON_INSTANCE_REGISTRY.update_devices(uid, merged_devices)
def _make_method(
    cls: type[object],
    cls_sourceinfo: str | None,
    uid: int,
    init_args: tuple[Any, ...],
    init_kwargs: dict[str, Any],
    method_name: str,
    original_method: Callable[..., Any],
    func_maker: func._CachedColocatedFunctionMaker,
):
  """Builds the controller-side wrapper for one method of a colocated class.

  The returned wrapper, when called on the controller, dispatches
  `method_name` on a backend-side instance of `cls` (created lazily on first
  call from `init_args`/`init_kwargs`) and records the devices touched by the
  call in the instance registry.

  Args:
    cls: The user's original class.
    cls_sourceinfo: Source info string for error reporting, or None.
    uid: Unique instance id from SINGLETON_INSTANCE_REGISTRY.
    init_args: Positional args captured from the controller-side constructor.
    init_kwargs: Keyword args captured from the controller-side constructor.
    method_name: Name of the method to dispatch on the backend object.
    original_method: The unbound original method (used for signature/wraps).
    func_maker: Shared maker that turns backend callables into colocated
      Python callables.

  Returns:
    A controller-side callable with a `specialize` attribute.
  """

  class MethodCallerAtBackend:
    """Callable shipped to the backend that lazily materializes the object
    and dispatches `method_name` on it."""

    def __init__(self):
      # Guards lazy initialization of `self.obj` on the backend.
      self._lock = threading.Lock()

    def __reduce__(self):
      # Serialize only the type: the lock and the cached `obj` are
      # per-process state and are rebuilt on the backend.
      return type(self), ()

    def _first_call(self):
      # Creates (or fetches, if another method caller got there first) the
      # backend object for this uid, then consumes the _ConsumableRef so the
      # object's lifetime is tied to this caller.
      def initializer():
        return obj_backend._ConsumableRef(cls(*init_args, **init_kwargs))

      retrieved = obj_backend.SINGLETON_OBJECT_STORE.get_or_create(
          uid, initializer
      )
      self.obj = retrieved()

    def __call__(self, *args, **kwargs):
      with self._lock:
        if not hasattr(self, 'obj'):
          self._first_call()
      # NOTE(review): the dispatch below appears to run outside the lock, so
      # concurrent method calls are not serialized here — confirm intended.
      return getattr(self.obj, method_name)(*args, **kwargs)

    def __del__(self):
      if not hasattr(self, 'obj'):
        # It is possible that no one has ever consumed the _ConsumableRef. So
        # consume it now.
        obj_backend.SINGLETON_OBJECT_STORE.get_or_create(
            uid, lambda: obj_backend._ConsumableRef(None)
        )()

  # Colocated Python callable for the controller.
  callable = func_maker.make_callable(
      MethodCallerAtBackend(),
      cls_sourceinfo,
      api_util.fun_signature(original_method),
  )

  # Outer wrapper of the method for the controller. It tracks devices that have
  # been used with any method call.
  def make_method_wrapper(callable):
    @api_boundary
    def method_wrapper(*args, **kwargs):
      # TODO(hyeontaek): Instead of inspecting argument/result shardings, get
      # shardings from final specialization of the function. This may require
      # lowering `_update_instance_devices` into the function API.
      args_leaves = tree_util.tree_leaves((args, kwargs))
      args_shardings_leaves = tuple(
          func._get_spec(x).sharding for x in args_leaves
      )
      if args_shardings_leaves:
        _update_instance_devices(uid, args_shardings_leaves)
      result = callable(*args, **kwargs)
      # If args had any array, we can skip incorporating devices from the result
      # because results will not use any new devices.
      if not args_shardings_leaves:
        result_leaves = tree_util.tree_leaves(result)
        result_shardings_leaves = tuple(
            func._get_spec(x).sharding for x in result_leaves
        )
        _update_instance_devices(uid, result_shardings_leaves)
      return result

    # Specializing returns a new wrapper so device tracking also applies to
    # the specialized callable.
    def specialize(*args, **kwargs):
      return make_method_wrapper(callable.specialize(*args, **kwargs))

    method_wrapper = util.wraps(original_method)(method_wrapper)
    method_wrapper.specialize = specialize  # pyrefly: ignore[missing-attribute]
    return method_wrapper

  method_wrapper = make_method_wrapper(callable)
  return method_wrapper
def wrap_class(
    cls: type[object],
    cls_sourceinfo: str | None,
) -> type[object]:
  """Returns a controller-side proxy class for `cls`.

  Instantiating the returned class does not construct `cls` locally; instead
  it registers a new instance uid and installs, per public method of `cls`, a
  colocated wrapper that will lazily construct the object on the backend and
  dispatch calls to it.

  Args:
    cls: The user's class to wrap.
    cls_sourceinfo: Source info string for error reporting, or None.

  Returns:
    A new class mimicking `cls`'s name/docstring whose instances proxy method
    calls to a colocated backend object.
  """

  class WrappedClass:

    @util.wraps(cls.__init__)
    def __init__(self, *init_args, **init_kwargs) -> None:
      # Constructor arguments are captured here and replayed on the backend
      # at the first method call; `cls` is NOT constructed on the controller.
      uid = self._colocated_python_uid = (
          SINGLETON_INSTANCE_REGISTRY.new_instance()
      )
      self.func_maker = func._CachedColocatedFunctionMaker(uid)
      for attr_name in dir(cls):
        original_member = getattr(cls, attr_name)
        # Only plain functions become proxied methods; other attributes
        # (properties, class attributes, etc.) are not wrapped.
        if not inspect.isfunction(original_member):
          continue
        # WrappedClass defines lazy initialization and colocated deletion logic.
        # WrappedClass is not serializable even if the original class may be
        # serializable.
        if attr_name in ('__init__', '__del__', '__reduce__', '__reduce_ex__'):
          continue
        method = _make_method(
            cls,
            cls_sourceinfo,
            uid,
            init_args,
            init_kwargs,
            attr_name,
            original_member,
            self.func_maker,
        )
        # TODO(hyeontaek): Support method specialization similar to function
        # specialization.
        setattr(self, attr_name, method)

  # NOTE(review): nothing in this block calls
  # SINGLETON_INSTANCE_REGISTRY.pop_instance; confirm instance cleanup
  # happens elsewhere.
  WrappedClass.__name__ = cls.__name__
  WrappedClass.__doc__ = cls.__doc__
  return WrappedClass
@@ -0,0 +1,119 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Backend for colocated_python.obj."""
from __future__ import annotations
from collections.abc import Callable
import dataclasses
import threading
from typing import Any
import weakref
@dataclasses.dataclass
class _ConsumableRef:
"""Stores a strong ref initially, but switches to a weak ref once consumed.
We consider a _ConsumableRef to have been consumed once the __call__ method
has been called once, and the _ConsumableRef is no longer storing a strong
ref. We consider a _ConsumableRef to have expired once it has been consumed
and the resulting weak ref has expired.
Pickling and unpickling an unexpired _ConsumableRef will create a new,
unconsumed _ConsumableRef. Pickling and unpickling an expired _ConsumableRef
will create an expired _ConsumableRef.
"""
strong_ref: Any | None = None
weak_ref: weakref.ref | None = None
_mutex = threading.Lock()
def __init__(self, obj: Any) -> None:
self.strong_ref = obj
def __call__(self, *args, **kwargs):
with self._mutex:
if self.strong_ref is not None:
assert self.weak_ref is None
result = self.strong_ref
self.strong_ref = None
self.weak_ref = weakref.ref(result)
return result
elif self.weak_ref is not None:
return self.weak_ref()
else:
return None
def __reduce__(self):
with self._mutex:
if self.strong_ref is not None:
return type(self), (self.strong_ref,)
elif self.weak_ref is not None:
return type(self), (self.weak_ref(),)
else:
return type(self), (None,)
@dataclasses.dataclass(frozen=True)
class _ObjectState:
  """Immutable lifecycle state of one entry in _ObjectStore."""

  # True while a thread is still running the entry's initializer.
  is_being_initialized: bool
  # Exception raised by the initializer, if it failed; re-raised to waiters.
  exc: Exception | None = None
  # The successfully initialized object (None while initializing or failed).
  obj: Any = None
class _ObjectStore:
  """Stores live objects.

  TODO(madthanu): Currently the dictionary never removes entries that are
  expired refs.
  """

  def __init__(self) -> None:
    # Condition variable: guards _storage and signals completion (success or
    # failure) of in-flight initializations.
    self._lock = threading.Condition()
    self._storage: dict[int, _ObjectState] = {}

  def get_or_create(self, uid: int, initializer: Callable[[], Any]) -> Any:
    """Returns the object associated with the given uid, or creates it if it does not exist.

    Exactly one caller runs `initializer`; concurrent callers for the same
    uid block until it finishes. If the initializer raised, the stored
    exception is re-raised to every subsequent caller.
    """
    with self._lock:
      if uid in self._storage:
        while True:
          state = self._storage[uid]
          if state.is_being_initialized:
            # Another thread is initializing the object. Wait for it to finish.
            self._lock.wait()
          else:
            break
        if state.exc is not None:
          raise state.exc
        return state.obj
      # Claim the uid before releasing the lock so other threads wait
      # instead of double-initializing.
      self._storage[uid] = _ObjectState(is_being_initialized=True)
    # The initializer deliberately runs outside the lock: it may be slow, and
    # unrelated uids must not be blocked by it.
    try:
      obj = initializer()
    except Exception as exc:
      with self._lock:
        self._storage[uid] = _ObjectState(is_being_initialized=False, exc=exc)
        self._lock.notify_all()
      raise
    with self._lock:
      self._storage[uid] = _ObjectState(is_being_initialized=False, obj=obj)
      self._lock.notify_all()
    return obj
SINGLETON_OBJECT_STORE = _ObjectStore()
@@ -0,0 +1,348 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python serialization utilities."""
from __future__ import annotations
import base64
import collections
from collections.abc import Callable, Sequence
import functools
import io
import threading
from typing import Any
try:
import cloudpickle # pyrefly: ignore[missing-import]
except ImportError:
cloudpickle = None
import jax
from jax._src import api
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
import numpy as np
DeviceList = xc.DeviceList
class _CommonObjectState(threading.local):
  """Tracks repeated objects within a single `_serialize()` or `_deserialize()`.

  It is common for `_serialize(x)` to be called with `x` being a nested
  container or capturing other objects in a closure, with many references
  pointing to only a few unique objects. The logic below
  (`_make_reduce_func_with_common_obj`) avoids duplicating object serialization
  by reducing a reference handle instead of the full object when an equal object
  is repeatedly seen.
  """

  def __init__(self):
    # Map from a common object key to its ID. Any objects with a matching key
    # will use the common object ID instead of the full object during
    # serialization.
    self.common_obj_index: dict[Any, int] | None = None
    # Common object that has been reconstructed when their key was seen for the
    # first time during deserialization.
    self.common_obj: list[Any] | None = None


# threading.local: each thread's (de)serialization bookkeeping is independent,
# so concurrent _serialize()/_deserialize() calls on different threads do not
# interfere.
_common_obj_state = _CommonObjectState()
def _wrapped_unreduce_func_with_new_common_obj(
    common_obj_id, unreduce_func, unreduce_args):
  """Unreduces a new common object.

  Reconstructs the object and appends it to the thread-local common-object
  list so later `_wrapped_unreduce_func_with_existing_common_obj` calls can
  resolve the same id.
  """
  registry = _common_obj_state.common_obj
  assert registry is not None
  obj = unreduce_func(*unreduce_args)
  # Ids are assigned in serialization order; deserialization must observe
  # objects in that same order for the positions to line up.
  assert len(registry) == common_obj_id, (
      f"Expected {common_obj_id} common objects, but got"
      f" {len(registry)}. This can happen if serialization"
      " and deserialization of objects happened in different orders."
  )
  registry.append(obj)
  return obj
def _wrapped_unreduce_func_with_existing_common_obj(common_obj_id):
  """Unreduces a common object that has already appeared."""
  assert _common_obj_state.common_obj is not None
  # The id indexes into the list built in serialization order by
  # _wrapped_unreduce_func_with_new_common_obj.
  return _common_obj_state.common_obj[common_obj_id]
def _make_reduce_func_with_common_obj(
    reduce_func: Callable[[Any], tuple[Any, Any]],
) -> Callable[[Any], tuple[Any, Any]]:
  """Wraps a reduce function to serialize a common object once.

  The first time an object (by equality/hash) is seen within a single
  `_serialize()` call, the full reduction is emitted together with a freshly
  assigned id; subsequent occurrences reduce to just the id. Objects must
  therefore be hashable and equality-comparable.
  """

  @functools.wraps(reduce_func)
  def wrapped_reduce_func(obj):
    assert _common_obj_state.common_obj_index is not None
    common_obj_id = _common_obj_state.common_obj_index.get(obj)
    if common_obj_id is None:
      unreduced_func, unreduced_args = reduce_func(obj)
      # Assign the next sequential id; deserialization appends reconstructed
      # objects in this same order.
      common_obj_id = len(_common_obj_state.common_obj_index)
      _common_obj_state.common_obj_index[obj] = common_obj_id
      return _wrapped_unreduce_func_with_new_common_obj, (
          common_obj_id, unreduced_func, unreduced_args)
    else:
      return _wrapped_unreduce_func_with_existing_common_obj, (common_obj_id,)

  return wrapped_reduce_func
def _collect_cpu_devices(
    cpu_device_map: dict[int, jax.Device], devices: Sequence[jax.Device]
) -> None:
  """Adds CPU devices from `devices` into `cpu_device_map`, keyed by id.

  Raises:
    ValueError: If two CPU devices share the same id.
  """
  for d in devices:
    if d.device_kind != "cpu":
      continue
    if d.id in cpu_device_map:
      raise ValueError(
          f"Multiple CPU devices with id {d.id} found:"
          f" {cpu_device_map[d.id]} and {d}"
      )
    cpu_device_map[d.id] = d


@util.cache(max_size=None)
def _get_cpu_device_map() -> dict[int, jax.Device]:
  """Returns a map from a device id to a matching device."""
  cpu_device_map: dict[int, jax.Device] = {}
  # TODO(hyeontaek): We should look up CPU devices for a specific CPU backend.
  # When deserializing a device on the controller, the backend should be the one
  # associated with colocated_python. When deserializing on the colocated_python
  # executor, it should be the CPU backend visible to the user function running
  # under colocated_python.

  # Look for CPU devices in the default backend.
  _collect_cpu_devices(
      cpu_device_map, xb.local_devices()[0].client._get_all_devices()
  )
  if cpu_device_map:
    return cpu_device_map

  # Fall back to searching CPU devices in all backends.
  for backend in xb.backends().values():
    _collect_cpu_devices(cpu_device_map, backend._get_all_devices())
  return cpu_device_map
def _lookup_cpu_device(
cpu_device_map: dict[int, jax.Device], device_id: int
) -> jax.Device:
"""Returns a CPU device with the given device ID."""
d = cpu_device_map.get(device_id)
if d is None:
raise ValueError(
f"Invalid device ID {device_id}. Device list must contain only CPU"
" devices."
)
return d
@_make_reduce_func_with_common_obj
def _reduce_mesh(
    mesh: jax.sharding.Mesh,
) -> tuple[Callable[..., jax.sharding.Mesh], Any]:
  """Reduces a Mesh to device ids plus axis metadata for pickling.

  Devices themselves are not pickled; they are looked up by id on the
  deserializing side (see _unreduce_mesh).
  """
  # Preserve the mesh's array shape while mapping each device to its id.
  mesh_device_ids = np.vectorize(lambda d: d.id, otypes=[int])(mesh.devices)
  return _unreduce_mesh, (mesh_device_ids, mesh.axis_names, mesh.axis_types)
def _unreduce_mesh(
    mesh_device_ids: np.ndarray, axis_names: Any, axis_types: Any
) -> jax.sharding.Mesh:
  """Reconstructs a Mesh from device ids by looking up local CPU devices."""
  cpu_device_map = _get_cpu_device_map()
  # Map every id back to a device, preserving the mesh's array shape.
  mesh_devices = np.vectorize(
      functools.partial(_lookup_cpu_device, cpu_device_map)
  )(mesh_device_ids)
  return jax.sharding.Mesh(mesh_devices, axis_names, axis_types)
@_make_reduce_func_with_common_obj
def _reduce_named_sharding(
    sharding: jax.sharding.NamedSharding,
) -> tuple[Callable[..., jax.sharding.NamedSharding], Any]:
  """Reduces a NamedSharding; its Mesh is reduced (and de-duplicated) in turn."""
  assert isinstance(sharding.mesh, jax.sharding.Mesh), "Only Mesh is supported"
  # _reduce_mesh is itself common-object-aware, so a mesh shared by many
  # shardings is serialized only once.
  reduced_mesh = _reduce_mesh(sharding.mesh)
  return _unreduce_named_sharding, (
      reduced_mesh, sharding.spec, sharding.memory_kind)
def _unreduce_named_sharding(reduced_mesh, spec, memory_kind):
  """Reconstructs a NamedSharding from a reduced (unreduce_func, args) mesh."""
  unreduce_mesh_func, unreduce_mesh_args = reduced_mesh
  mesh = unreduce_mesh_func(*unreduce_mesh_args)
  return jax.NamedSharding(mesh, spec, memory_kind=memory_kind)
@_make_reduce_func_with_common_obj
def _reduce_device_list(
    device_list: DeviceList,
) -> tuple[Callable[..., DeviceList], Any]:
  """Reduces a DeviceList to the ids of its devices."""
  ids = [device.id for device in device_list]
  return _unreduce_device_list, (ids,)
def _unreduce_device_list(device_ids: Sequence[int]) -> DeviceList:
  """Reconstructs a DeviceList by looking up CPU devices for `device_ids`.

  Raises:
    ValueError: If any id has no matching CPU device.
  """
  cpu_device_map = _get_cpu_device_map()
  # A plain comprehension instead of np.vectorize: np.vectorize (without
  # `otypes`) raises ValueError on an empty input, and no ndarray structure
  # is needed — DeviceList only wants a tuple of devices.
  devices = tuple(
      _lookup_cpu_device(cpu_device_map, device_id) for device_id in device_ids
  )
  return DeviceList(devices)
@_make_reduce_func_with_common_obj
def _reduce_single_device_sharding(
    sharding: jax.sharding.SingleDeviceSharding,
) -> tuple[Callable[..., jax.sharding.SingleDeviceSharding], Any]:
  """Reduces a SingleDeviceSharding to its device id and memory kind."""
  # next(iter(...)) instead of set.pop(): pop() mutates the set it is called
  # on, which is unsafe if `device_set` ever returns a shared internal set.
  device = next(iter(sharding.device_set))
  return _unreduce_single_device_sharding, (
      device.id,
      sharding.memory_kind)
def _unreduce_single_device_sharding(
    device_id: int, memory_kind: str | None
) -> jax.sharding.SingleDeviceSharding:
  """Reconstructs a SingleDeviceSharding on the matching local CPU device."""
  device = _lookup_cpu_device(_get_cpu_device_map(), device_id)
  return jax.sharding.SingleDeviceSharding(device, memory_kind=memory_kind)
def _serialize(obj: Any) -> bytes:
  """Serializes callables and input/output spec objects.

  DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
  colocated_python.

  This module contains utility functions used internally for implementing
  `colocated_python` when it ships callables and input/output specs through
  IFRT. The pickled data is produced and consumed in an ephemeral fashion
  without any persistence, and it does not expect any version compatibility
  (which cloudpickle does not guarantee). Furthermore, serialization and
  deserialization is expected to be done on machine(s) that are controlled by a
  single tenant, which allows unpickling done during deserialization to be
  trusted.

  Raises:
    ModuleNotFoundError: If cloudpickle is not available.
  """
  if cloudpickle is None:
    raise ModuleNotFoundError('No module named "cloudpickle"')

  class _CustomPickler(cloudpickle.Pickler):
    # Device-bound types get custom reducers (devices are re-looked-up by id
    # on the other side); everything else falls through to cloudpickle.
    dispatch_table = collections.ChainMap(
        {jax.sharding.Mesh: _reduce_mesh},
        {jax.sharding.NamedSharding: _reduce_named_sharding},  # pyrefly: ignore[bad-argument-type]
        {DeviceList: _reduce_device_list},  # pyrefly: ignore[bad-argument-type]
        {jax.sharding.SingleDeviceSharding: _reduce_single_device_sharding},  # pyrefly: ignore[bad-argument-type]
        cloudpickle.CloudPickler.dispatch_table,  # pyrefly: ignore[bad-argument-type]
    )
    # NOTE(review): aliasing `dispatch` appears to target the older pickler
    # dispatch attribute for cloudpickle compatibility — confirm.
    dispatch = dispatch_table

  # The thread-local common-object index enables de-duplication of repeated
  # meshes/shardings within this single dump; it must not be nested.
  assert _common_obj_state.common_obj_index is None, (
      "_serialize() expects no recursive calls")
  _common_obj_state.common_obj_index = {}
  try:
    with io.BytesIO() as file:
      _CustomPickler(file).dump(obj)
      return file.getvalue()
  finally:
    _common_obj_state.common_obj_index = None
def _deserialize(serialized: bytes) -> Any:
  """Deserializes callables and input/output spec objects.

  DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
  colocated_python. See serialize() for details.

  Raises:
    ModuleNotFoundError: If cloudpickle is not available.
  """
  if cloudpickle is None:
    raise ModuleNotFoundError('No module named "cloudpickle"')

  # The thread-local common-object list mirrors the index built by
  # _serialize(); ids resolve into it in serialization order. Must not nest.
  assert _common_obj_state.common_obj is None, (
      "_deserialize() expects no recursive calls")
  _common_obj_state.common_obj = []
  try:
    return cloudpickle.loads(serialized)
  finally:
    _common_obj_state.common_obj = None
def _make_specs_for_serialized_specs(
    devices: DeviceList,
) -> api.ShapeDtypeStruct:
  """Makes output specs for serialized specs."""
  # A scalar string value, fully replicated across every device in `devices`.
  replicated = jax.sharding.NamedSharding(
      jax.sharding.Mesh(tuple(devices), ("x",)),
      jax.sharding.PartitionSpec(),
  )
  return api.ShapeDtypeStruct(
      shape=(), dtype=np.dtypes.StringDType(), sharding=replicated
  )
def _serialize_specs(
    specs_treedef: tree_util.PyTreeDef,
    specs_leaves: tuple[api.ShapeDtypeStruct, ...],
    devices: DeviceList,
) -> jax.Array:
  """Serializes the output specs into a jax.Array of string type.

  DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
  colocated_python. See serialize() for details.
  """
  if not hasattr(np.dtypes, "StringDType"):
    # NOTE(review): the message presumably means "update numpy to 2.0.0 or
    # later" — confirm wording.
    raise TypeError(
        "Serializing Colocated Python requires StringDType. Please use"
        " numpy to 2.0.0 or later, or explicitly provide an output spec"
        " function."
    )
  # Pickle the (treedef, leaves) pair, then base64-encode so the payload is
  # plain ASCII text inside a StringDType scalar.
  s_bytes = _serialize((specs_treedef, specs_leaves))
  s_str = base64.b64encode(s_bytes).decode("ascii")
  s_np_array = np.array(s_str, dtype=np.dtypes.StringDType())

  # TODO(jmudigonda): Revisit this when JAX supports HLO sharding for making
  # jax.Array via make_array_from_single_device_arrays. We should then use a
  # sharding that spans all the execution devices - not just the addressable
  # ones.
  addressable_devices = devices.addressable_device_list
  mesh = jax.sharding.Mesh(tuple(addressable_devices), ("x",))
  replicated_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec()
  )
  # Replicate: put an identical copy of the scalar on each addressable device
  # and assemble them into one replicated jax.Array.
  out_arrays = [
      jax.device_put(s_np_array, device) for device in addressable_devices
  ]
  return jax.make_array_from_single_device_arrays(
      arrays=out_arrays,
      sharding=replicated_sharding,
      shape=(),
  )
def _deserialize_specs(
    serialized_specs: jax.Array,
) -> tuple[tree_util.PyTreeDef, tuple[api.ShapeDtypeStruct, ...]]:
  """Deserializes the specs from the serialized specs.

  DO NOT USE THIS FUNCTION EXCEPT FOR THE INTERNAL IMPLEMENTATION OF
  colocated_python. See serialize() for details.
  """
  # The array is replicated, so any single addressable shard carries the
  # full base64-encoded payload.
  shard = serialized_specs.addressable_shards[0].data
  encoded = shard.item().encode("ascii")
  return _deserialize(base64.b64decode(encoded))