Files
2026-05-06 19:47:31 +07:00

203 lines
6.3 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python object API implementation."""
from __future__ import annotations
from collections.abc import Callable
import inspect
import random
import threading
from typing import Any
import jax
from jax._src import api_util
from jax._src import tree_util
from jax._src import util
from jax._src.traceback_util import api_boundary
from jax.experimental.colocated_python import func
from jax.experimental.colocated_python import obj_backend
class _InstanceRegistry:
  """Registry of object instances.

  Maps each instance uid to the set of devices on which that instance is live.
  Every access to the underlying dict is serialized by an internal lock.
  """

  def __init__(self) -> None:
    self._lock = threading.Lock()
    self._storage: dict[int, set[jax.Device]] = {}
    # Private RNG seeded from OS entropy. Drawing from the module-level
    # `random` state would let user code (e.g. calling `random.seed(0)` before
    # each object construction) make `new_instance()` produce duplicate uids.
    self._rng = random.Random()

  def new_instance(self) -> int:
    """Returns a new unique identifier for an instance on the controller.

    The identifier is a random non-negative 63-bit integer that is not
    currently registered. The new instance starts with an empty device set.
    """
    with self._lock:
      uid = self._rng.getrandbits(63)
      # Collisions are astronomically unlikely; retry instead of relying on an
      # `assert` (stripped under `python -O`) so a live entry is never
      # overwritten.
      while uid in self._storage:
        uid = self._rng.getrandbits(63)
      self._storage[uid] = set()
    return uid

  def update_devices(self, uid: int, device_set: set[jax.Device]) -> None:
    """Adds `device_set` to the devices on which instance `uid` is live.

    Raises:
      KeyError: If `uid` is not registered (never created or already popped).
    """
    with self._lock:
      self._storage[uid] |= device_set

  def pop_instance(self, uid: int) -> set[jax.Device]:
    """Removes the instance and returns the set of devices on which it is live.

    Raises:
      KeyError: If `uid` is not registered.
    """
    with self._lock:
      return self._storage.pop(uid)


# Process-wide registry shared by the device-tracking helpers and wrapped
# classes in this module.
SINGLETON_INSTANCE_REGISTRY = _InstanceRegistry()
@util.cache(max_size=4096)
def _update_instance_devices(
    uid: int, shardings: tuple[jax.sharding.Sharding, ...]
) -> None:
  """Marks every device used by `shardings` as live for instance `uid`.

  Memoized front end to `SINGLETON_INSTANCE_REGISTRY.update_devices()`: a
  `(uid, shardings)` pair that is still in the cache does not touch the
  registry again.
  """
  devices = {
      device for sharding in shardings for device in sharding.device_set
  }
  SINGLETON_INSTANCE_REGISTRY.update_devices(uid, devices)
def _make_method(
    cls: type[object],
    cls_sourceinfo: str | None,
    uid: int,
    init_args: tuple[Any, ...],
    init_kwargs: dict[str, Any],
    method_name: str,
    original_method: Callable[..., Any],
    func_maker: func._CachedColocatedFunctionMaker,
):
  """Builds the controller-side wrapper for one method of a colocated object.

  Args:
    cls: The user class being wrapped.
    cls_sourceinfo: Source information for `cls`, forwarded to
      `func_maker.make_callable`.
    uid: Identifier of the wrapped instance.
    init_args: Positional arguments used to construct `cls` on the backend.
    init_kwargs: Keyword arguments used to construct `cls` on the backend.
    method_name: Name of the method invoked on the backend object.
    original_method: The method as found on `cls`; supplies the signature and
      the metadata copied onto the returned wrapper.
    func_maker: Factory that turns the backend caller into a colocated Python
      callable.

  Returns:
    A callable that forwards its arguments to `method_name` of the backend
    object, records the devices used by each call for instance `uid`, and
    exposes `.specialize(...)`, which returns another wrapper of the same kind.
  """

  class MethodCallerAtBackend:
    """Calls `method_name` on the backend object stored under `uid`.

    The object is fetched (or created from `init_args`/`init_kwargs`) from the
    backend object store on the first call.
    """

    def __init__(self):
      self._lock = threading.Lock()

    def __reduce__(self):
      # Reconstructed as a fresh, argument-less instance: the lock and any
      # already retrieved `obj` are not carried over.
      return type(self), ()

    def _first_call(self):
      def initializer():
        return obj_backend._ConsumableRef(cls(*init_args, **init_kwargs))

      retrieved = obj_backend.SINGLETON_OBJECT_STORE.get_or_create(
          uid, initializer
      )
      self.obj = retrieved()

    def __call__(self, *args, **kwargs):
      # The lock keeps concurrent first calls from both running `_first_call`.
      with self._lock:
        if not hasattr(self, 'obj'):
          self._first_call()
      return getattr(self.obj, method_name)(*args, **kwargs)

    def __del__(self):
      if not hasattr(self, 'obj'):
        # It is possible that no one has ever consumed the _ConsumableRef. So
        # consume it now.
        obj_backend.SINGLETON_OBJECT_STORE.get_or_create(
            uid, lambda: obj_backend._ConsumableRef(None)
        )()

  # Colocated Python callable for the controller. (Named so as not to shadow
  # the `callable` builtin.)
  colocated_callable = func_maker.make_callable(
      MethodCallerAtBackend(),
      cls_sourceinfo,
      api_util.fun_signature(original_method),
  )

  # Outer wrapper of the method for the controller. It tracks devices that have
  # been used with any method call.
  def make_method_wrapper(target):
    @api_boundary
    def method_wrapper(*args, **kwargs):
      # TODO(hyeontaek): Instead of inspecting argument/result shardings, get
      # shardings from final specialization of the function. This may require
      # lowering `_update_instance_devices` into the function API.
      args_leaves = tree_util.tree_leaves((args, kwargs))
      args_shardings_leaves = tuple(
          func._get_spec(x).sharding for x in args_leaves
      )
      if args_shardings_leaves:
        _update_instance_devices(uid, args_shardings_leaves)

      result = target(*args, **kwargs)

      # If args had any array, we can skip incorporating devices from the result
      # because results will not use any new devices.
      if not args_shardings_leaves:
        result_leaves = tree_util.tree_leaves(result)
        result_shardings_leaves = tuple(
            func._get_spec(x).sharding for x in result_leaves
        )
        # Mirror the argument path: an array-free result adds no devices, so
        # skip the registry (and a useless cache entry) entirely.
        if result_shardings_leaves:
          _update_instance_devices(uid, result_shardings_leaves)

      return result

    def specialize(*args, **kwargs):
      return make_method_wrapper(target.specialize(*args, **kwargs))

    method_wrapper = util.wraps(original_method)(method_wrapper)
    method_wrapper.specialize = specialize  # pyrefly: ignore[missing-attribute]
    return method_wrapper

  return make_method_wrapper(colocated_callable)
def wrap_class(
    cls: type[object],
    cls_sourceinfo: str | None,
) -> type[object]:
  """Returns a controller-side stand-in class for `cls`.

  Instantiating the returned class never calls `cls` itself. Instead it
  registers a fresh instance uid and, for every plain function attribute of
  `cls` other than `__init__`, `__del__`, `__reduce__` and `__reduce_ex__`,
  attaches a colocated method wrapper to the new instance under the same name.
  The returned class copies `cls.__name__` and `cls.__doc__`.
  """
  # WrappedClass defines lazy initialization and colocated deletion logic.
  # WrappedClass is not serializable even if the original class may be
  # serializable. So these members of `cls` are never forwarded.
  skipped_names = ('__init__', '__del__', '__reduce__', '__reduce_ex__')

  class WrappedClass:

    @util.wraps(cls.__init__)
    def __init__(self, *init_args, **init_kwargs) -> None:
      instance_uid = SINGLETON_INSTANCE_REGISTRY.new_instance()
      self._colocated_python_uid = instance_uid
      self.func_maker = func._CachedColocatedFunctionMaker(instance_uid)

      for name in dir(cls):
        member = getattr(cls, name)
        if not inspect.isfunction(member) or name in skipped_names:
          continue
        wrapper = _make_method(
            cls,
            cls_sourceinfo,
            instance_uid,
            init_args,
            init_kwargs,
            name,
            member,
            self.func_maker,
        )
        # TODO(hyeontaek): Support method specialization similar to function
        # specialization.
        setattr(self, name, wrapper)

  WrappedClass.__name__ = cls.__name__
  WrappedClass.__doc__ = cls.__doc__
  return WrappedClass