This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,28 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cluster import ClusterEnv as ClusterEnv
# Order of declaration of the cluster environments
# will dictate the order in which they will be checked.
# Therefore, if multiple environments are available and
# the user did not explicitly provide the arguments
# to :func:`jax.distributed.initialize`, the first
# available one from the list will be picked.
from .ompi_cluster import OmpiCluster as OmpiCluster
from .slurm_cluster import SlurmCluster as SlurmCluster
from .mpi4py_cluster import Mpi4pyCluster as Mpi4pyCluster
from .cloud_tpu_cluster import GkeTpuCluster as GkeTpuCluster
from .cloud_tpu_cluster import GceTpuCluster as GceTpuCluster
from .k8s_cluster import K8sCluster as K8sCluster
@@ -0,0 +1,259 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import os
import re
import socket
import time
from jax._src import clusters
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
logger = logging.getLogger(__name__)
# We use an arbitrarily chosen port for the coordinator since we cannot
# rely on communication to choose one in real time.
coordinator_port = '8482'
# HTTP status code signalling that the requested metadata attribute exists.
metadata_response_code_success = 200
def get_metadata(key):
  """Read a GCE instance attribute from the metadata server, with retries.

  Args:
    key: Name of the instance attribute to fetch.

  Returns:
    A ``(text, status_code)`` tuple holding the body and HTTP status of the
    last response received (callers must check the status code themselves).

  Raises:
    RuntimeError: if no response at all was obtained after all attempts.
  """
  import requests  # pytype: disable=import-error
  # Based on https://github.com/tensorflow/tensorflow/pull/40317
  gce_metadata_endpoint = 'http://' + os.environ.get(
      'GCE_METADATA_IP', 'metadata.google.internal')
  retry_count = 0
  retry_seconds = 0.500
  api_resp = None
  while retry_count < 6:
    api_resp = requests.get(
        f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
        headers={'Metadata-Flavor': 'Google'}, timeout=60)
    if api_resp.status_code == 200:
      break
    retry_count += 1
    time.sleep(retry_seconds)
  # Only reachable as None if the loop body never completed an attempt.
  if api_resp is None:
    raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
  return api_resp.text, api_resp.status_code
def get_tpu_env_value_from_metadata(key) -> str | None:
  """Look up ``key`` inside the 'tpu-env' instance metadata blob.

  The blob is a newline-separated list of entries such as
  ``MEGASCALE_NUM_SLICES: '2'``. If the key occurs multiple times the last
  occurrence wins (matching the historical behavior).

  Returns:
    The unquoted value for ``key``, or None when the key is absent.
  """
  metadata_value = None
  tpu_env_data = get_metadata('tpu-env')[0]
  for key_value_pair in tpu_env_data.split('\n'):
    # Typical line is MEGASCALE_NUM_SLICES: '2'
    if ':' in key_value_pair:
      # Split on the first ':' only; values may themselves contain colons.
      # str.split avoids re.split's positional-maxsplit form, which is
      # deprecated as of Python 3.13 (and the pattern is a plain literal).
      row_key, value = key_value_pair.split(':', 1)
      if row_key.strip() == key:
        metadata_value = value.strip().strip("'")
  return metadata_value
def get_tpu_env_value(key) -> str | None:
  """Return ``key`` from the process environment, falling back to the
  TPU 'tpu-env' metadata when the environment variable is unset."""
  env_value = os.environ.get(key, None)
  if env_value is not None:
    return env_value
  # Not present in the environment; consult the metadata server instead.
  return get_tpu_env_value_from_metadata(key)
class BaseTpuCluster(clusters.ClusterEnv):
  """Abstract cluster supports both single and multislice TPU environments.

  If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology.
  Concrete extensions of this class must implement methods for generating a list
  of within-slice workers and a within-slice process ID.
  `get_coordinator_address` must return the address of the host with
  process ID 0 (as returned by `get_process_id`), since the coordinator service
  is started on the host with process ID = 0.
  """
  # NOTE: the docstring above must be the first statement in the class body
  # to actually become __doc__ (it previously sat below `name` and was
  # silently discarded).

  name: str = "tpu"

  @classmethod
  def is_env_present(cls) -> bool:
    """Override this method to return True if the environment is present."""
    return False

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
    """Returns 'host:port' for the JAX distributed coordinator service."""
    # For both GCE via QueuedResources and GKE via JobSet, the
    # Megascale coordinator address is set as the host with process id = 0,
    # so can be used as the jax distributed system coordinator.
    coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS')
    if not coordinator_address:
      # For both GCE (QueuedResources and TPUVM create) and GKE via Job API,
      # the workers lists are sorted by process ID so the first one can
      # be used as the jax distributed system coordinator.
      coordinator_address = cls._get_worker_list_in_slice()[0]
    # Strip any port baked into the worker entry; the coordinator port is
    # chosen separately below.
    coordinator_address = coordinator_address.split(':')[0]
    logger.debug("TPU Cluster using coordinator address: %s", coordinator_address)
    cls.wait_for_coordinator(coordinator_address, timeout_secs)
    port = override_coordinator_port or coordinator_port
    return f'{coordinator_address}:{port}'

  @classmethod
  def wait_for_coordinator(cls, coordinator_address, timeout_secs):
    """Polls DNS until `coordinator_address` resolves or `timeout_secs` elapse.

    Raises:
      RuntimeError: if the address never became resolvable in time.
    """
    # The coordinator may not be up before the other hosts try to
    # communicate with it. We check for its existence with retries.
    coordinator_found = False
    max_time = time.time() + timeout_secs
    coordinator_retry_secs = 5
    while not coordinator_found and time.time() < max_time:
      try:
        socket.gethostbyname(coordinator_address)
        coordinator_found = True
        logger.debug("Found coordinator with address %s", coordinator_address)
      except socket.gaierror:
        logger.debug(
            "Failed to recognize coordinator address %s"
            " retrying...", coordinator_address
        )
        time.sleep(coordinator_retry_secs)
    if not coordinator_found:
      raise RuntimeError(f"Failed to recognize coordinator address {coordinator_address}")

  @classmethod
  def get_process_count(cls) -> int:
    """Total number of processes: workers per slice times number of slices."""
    processes_per_slice = len(cls._get_worker_list_in_slice())
    num_slices = cls._get_num_slices()
    total_process_count = processes_per_slice * num_slices
    logger.debug("Total process count of %s = %s processes per slice and %s slices", total_process_count, processes_per_slice, num_slices)
    return total_process_count

  @classmethod
  def get_process_id(cls) -> int:
    """Global process ID: the within-slice ID offset by the slice index."""
    process_id_in_slice = cls._get_process_id_in_slice()
    slice_id = cls._get_slice_id()
    processes_per_slice = len(cls._get_worker_list_in_slice())
    process_id = process_id_in_slice + slice_id * processes_per_slice
    logger.debug("Process ID of %s generated by within-slice id %s and slice id %s", process_id, process_id_in_slice, slice_id)
    return process_id

  @staticmethod
  def _get_num_slices() -> int:
    """Slice count; 1 when MEGASCALE_NUM_SLICES is unset (single slice)."""
    num_slices = get_tpu_env_value('MEGASCALE_NUM_SLICES')
    if not num_slices:
      return 1
    return int(num_slices)

  @staticmethod
  def _get_slice_id() -> int:
    """Index of this slice; 0 when MEGASCALE_SLICE_ID is unset."""
    slice_id = get_tpu_env_value('MEGASCALE_SLICE_ID')
    if not slice_id:
      return 0
    return int(slice_id)

  @staticmethod
  def _get_process_id_in_slice() -> int:
    """Returns a process ID that is unique within slice."""
    raise NotImplementedError()

  @staticmethod
  def _get_worker_list_in_slice() -> list[str]:
    """Returns a list of worker endpoints/hostnames within slice."""
    raise NotImplementedError()
class GceTpuCluster(BaseTpuCluster):
  """TPU cluster on GCE VMs, detected through the instance metadata server."""

  name: str = "gcetpu"

  @classmethod
  def is_env_present(cls) -> bool:
    # Guard clauses: must be a cloud TPU VM, and metadata queries must not
    # have been explicitly disabled.
    if not running_in_cloud_tpu_vm:
      logger.debug("Did not detect cloud TPU VM")
      return False
    if os.environ.get("TPU_SKIP_MDS_QUERY") is not None:
      logger.debug("TPU_SKIP_MDS_QUERY is set to True, so it's probably not a GCE TPU cluster.")
      return False
    metadata_response, metadata_code = get_metadata('agent-worker-number')
    if metadata_code != metadata_response_code_success:
      logger.debug("Did not detect Gce Tpu Cluster since agent-worker-number is not set in metadata")
      logger.debug("Metadata code: %s", metadata_code)
      logger.debug("Metadata response: %s", metadata_response)
      return False
    logger.debug("Gce Tpu Cluster detected for Jax Distributed System")
    return True

  @staticmethod
  def _get_process_id_in_slice() -> int:
    worker_number, _ = get_metadata('agent-worker-number')
    return int(worker_number)

  @staticmethod
  def _get_worker_list_in_slice() -> list[str]:
    endpoints, _ = get_metadata('worker-network-endpoints')
    addrs = []
    for worker in endpoints.split(','):
      # worker-network-endpoints can have one of two formats. In the new
      # format it is a list of hostnames; in the old format it is a list
      # of name:id:ip triples.
      parts = worker.split(':')
      if len(parts) == 1:
        addrs.append(parts[0])
      elif len(parts) == 3:
        addrs.append(parts[2])
      else:
        raise ValueError(f'unsupported worker-network-endpoints format: {worker}')
    return addrs
class GkeTpuCluster(BaseTpuCluster):
  """TPU cluster on GKE, detected through worker-list environment variables."""

  name: str = "gketpu"

  @classmethod
  def is_env_present(cls) -> bool:
    if running_in_cloud_tpu_vm and cls._get_worker_host_names_env_var() is not None:
      logger.debug("Gke Tpu Cluster detected for Jax Distributed System")
      return True
    else:
      if not running_in_cloud_tpu_vm:
        logger.debug("Did not detect cloud TPU VM")
      else:
        logger.debug("Did not detect TPU GKE cluster since neither "
                     "TPU_PROCESS_ADDRESSES nor TPU_WORKER_HOSTNAMES is set.")
      return False

  @staticmethod
  def _get_process_id_in_slice() -> int:
    """Returns this worker's index within the slice from TPU_WORKER_ID."""
    worker_id = os.environ.get('TPU_WORKER_ID')
    if worker_id is None:
      # Previously this fell through to int(str(None)), which raised a
      # cryptic "invalid literal for int(): 'None'" — fail clearly instead.
      raise RuntimeError('TPU_WORKER_ID is not set; cannot determine the '
                         'process ID within the slice.')
    return int(worker_id)

  @staticmethod
  def _get_worker_host_names_env_var() -> str | None:
    """
    Retrieves the list of worker hostnames from environment variables.
    Checks 'TPU_PROCESS_ADDRESSES' first, then 'TPU_WORKER_HOSTNAMES'.
    Returns None if neither environment variable is set.
    """
    worker_hostnames = os.environ.get('TPU_PROCESS_ADDRESSES', None)
    if worker_hostnames is not None:
      return worker_hostnames
    return os.environ.get('TPU_WORKER_HOSTNAMES', None)

  @staticmethod
  def _get_worker_list_in_slice() -> list[str]:
    """
    Returns a list of worker endpoints/hostnames within slice.
    """
    worker_hostnames = GkeTpuCluster._get_worker_host_names_env_var()
    if worker_hostnames is None:
      # Previously str(None) silently produced the bogus list ['None'].
      raise RuntimeError('Neither TPU_PROCESS_ADDRESSES nor '
                         'TPU_WORKER_HOSTNAMES is set; cannot list workers.')
    return worker_hostnames.split(',')
@@ -0,0 +1,130 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import os
import logging
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
# Module-level logger for cluster auto-detection diagnostics.
logger = logging.getLogger(__name__)
class ClusterEnv:
  """Interface for defining a cluster environment.

  To enable auto bootstrapping (aka :func:`jax.distributed.initialize()`),
  cluster environments need to derive from :class:`ClusterEnv` and implement
  :func:`is_env_present`, :func:`get_coordinator_address`,
  :func:`get_process_count`, and :func:`get_process_id`.
  :class:`ClusterEnv` subclasses are automatically detected when imported.
  """

  # Identifier used to request this environment explicitly through the
  # `cluster_detection_method` argument.
  name: str = ""

  # Registry of every ClusterEnv subclass, in import order. Declaration
  # order matters: the first env reporting itself present is picked.
  _cluster_types: list[type[ClusterEnv]] = []

  # Opt-in-only environments are skipped during automatic detection and
  # used only when selected by name. Override this in derived classes if
  # necessary.
  opt_in_only_method: bool = False

  def __init_subclass__(cls, **kwargs):
    # Auto-register every subclass so detection can iterate over them.
    super().__init_subclass__(**kwargs)
    cls._cluster_types.append(cls)

  @classmethod
  def auto_detect_unset_distributed_params(cls,
                                           coordinator_address: str | None,
                                           num_processes: int | None,
                                           process_id: int | None,
                                           local_device_ids: Sequence[int] | None,
                                           cluster_detection_method: str | None,
                                           initialization_timeout: int | None,
                                           ) -> tuple[str | None, int | None, int | None,
                                                      Sequence[int] | None]:
    """Fills in any still-None distributed parameters from a detected cluster.

    Arguments supplied as non-None are returned unchanged.
    """
    # First, we check the spec detection method because it will ignore
    # submitted values if it succeeds.
    if cluster_detection_method is not None:
      env = next((env for env in cls._cluster_types
                  if env.name == cluster_detection_method), None)
      if env is None:
        logger.error("Automatic Distributed initialization can not proceed:"
                     " %s is not supported.", cluster_detection_method)
      elif not env.is_env_present():
        logger.error("Automatic Distributed initialization can not proceed:"
                     " %s is supported but not functional in this environment.",
                     cluster_detection_method)
        # NOTE(review): `env` remains set here, so the code below still
        # attempts to use the non-functional environment — confirm whether
        # it should instead be reset to None.
    else:
      # No explicit method requested: pick the first auto-detectable
      # (non-opt-in) environment that is present. Declaration order of
      # the registered cluster types sets the priority.
      env = next((env for env in cls._cluster_types
                  if not env.opt_in_only_method and env.is_env_present()), None)

    if env:
      logger.debug('Initializing distributed JAX environment via %s', env.__name__)
      if coordinator_address is None:
        coordinator_port = os.environ.get("JAX_COORDINATOR_PORT")
        coordinator_address = env.get_coordinator_address(
            timeout_secs=initialization_timeout,
            override_coordinator_port=coordinator_port)
      if num_processes is None:
        num_processes = env.get_process_count()
      if process_id is None:
        process_id = env.get_process_id()
      # Never automatically set local_device_ids on TPUs
      # Defaults to single process per device if local_process_id is available.
      # This only runs if we're in a managed distributed environment.
      # Otherwise local_device_ids will remain unset,
      # which will default to all devices being visible.
      if (local_device_ids is None and not running_in_cloud_tpu_vm and
          (pid := env.get_local_process_id()) is not None):
        local_device_ids = [pid]
    else:
      logger.debug('Could not find a known environment for initializing distributed JAX. '
                   'Known environments: %s', ', '.join(e.__name__ for e in cls._cluster_types))

    return (coordinator_address, num_processes, process_id, local_device_ids)

  @classmethod
  def is_env_present(cls) -> bool:
    """Returns True if process is running in this cluster environment.
    """
    raise NotImplementedError("ClusterEnv subclasses must implement is_env_present")

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
    """Returns address and port used by JAX to bootstrap.

    Process id 0 will open a tcp socket at "hostname:port" where
    all the processes will connect to initialize the distributed JAX service.
    The selected port needs to be free.
    :func:`get_coordinator_address` needs to return the same hostname and port on all the processes.

    Returns:
      "hostname:port"
    """
    raise NotImplementedError("ClusterEnv subclasses must implement get_coordinator_address")

  @classmethod
  def get_process_count(cls) -> int:
    """Returns the total number of processes in the cluster."""
    raise NotImplementedError("ClusterEnv subclasses must implement get_process_count")

  @classmethod
  def get_process_id(cls) -> int:
    """Returns the global index of the current process within the cluster."""
    raise NotImplementedError("ClusterEnv subclasses must implement get_process_id")

  @classmethod
  def get_local_process_id(cls) -> int | None:
    """ Get index of current process inside a host.

    The method is only useful to support single device per process.
    In that case, each process will see a local device whose ID is
    the same as its local process ID.
    If None, JAX will not restrict the visible devices.
    """
    return None
@@ -0,0 +1,285 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from contextlib import contextmanager
from functools import cache
from itertools import chain
import logging
import numpy as np
import os
import socket
import time
import textwrap
import warnings
from jax._src import clusters
# Module-level logger shared by the retry helper and K8sCluster.
logger = logging.getLogger(__name__)
def retry(
    func=None,
    initial_delay=0,
    wait=np.logspace(-1, 1, 5) * np.random.rand(5),
    exceptions=Exception,
):
  """Decorator retrying a callable with a jittered backoff schedule.

  Usable both bare (``@retry``) and with arguments (``@retry(...)``).

  Args:
    func: The callable to wrap when used as a bare decorator.
    initial_delay: Seconds slept before the very first attempt.
    wait: Sequence of sleep intervals (seconds) preceding each retry.
      NOTE: the default jittered schedule is evaluated once at import
      time, so every decorated function shares the same default delays.
    exceptions: Exception type(s) that trigger a retry; anything else
      propagates immediately.

  Raises:
    RuntimeError: chained to the last failure once all attempts are spent.
  """
  def retry_decorator(func):
    from functools import wraps  # local import keeps the block self-contained

    @wraps(func)  # preserve __name__/__doc__ of the wrapped callable
    def retry_driver(*args, **kwargs):
      # Retry the function call with exponential backoff
      for i, t in enumerate(chain([initial_delay], wait)):
        # Lazy %-style args avoid formatting cost when debug is disabled.
        logger.debug(
            "Trying %s in %.2f seconds, attempt %s/%s",
            func.__name__, t, i, len(wait)
        )
        time.sleep(t)
        try:
          return func(*args, **kwargs)
        except exceptions as e:
          if i == len(wait):
            raise RuntimeError('Retry failed with all attempts exhausted') from e
        finally:
          logger.debug(
              "Finished %s after %s attempts", func.__name__, i + 1
          )
    return retry_driver

  # Distinguish @retry (func given) from @retry(...) (arguments given).
  if func is None:
    return retry_decorator
  else:
    return retry_decorator(func)
class K8sCluster(clusters.ClusterEnv):
  """Bootstraps JAX distributed for Kubernetes Jobs (optionally in a JobSet).

  Topology is discovered through the Kubernetes API, so the `kubernetes`
  package must be installed and the pod's service account must be able to
  read pod, service, and job metadata.
  """

  # Use an arbitrarily chosen port for the coordinator since we cannot
  # rely on communication to choose one in real time.
  _coordinator_port = '55527'

  @classmethod
  def is_env_present(cls) -> bool:
    """Returns True inside a k8s pod with the `kubernetes` client importable.

    As a side effect, initializes the API client handles used by the other
    classmethods below.
    """
    if 'KUBERNETES_SERVICE_HOST' in os.environ:
      try:
        import kubernetes as k8s  # pytype: disable=import-error
      except (ImportError, ModuleNotFoundError):
        warnings.warn(
            '\n'.join([
                textwrap.fill(
                    "Kubernetes environment detected, but the `kubernetes` package "
                    "is not installed to enable automatic bootstrapping in this "
                    "environment. To enable automatic bootstrapping, please install "
                    "jax with the [k8s] extra. For example:"),
                " pip install jax[k8s]",
                " pip install jax[k8s,<MORE-EXTRAS...>]",
            ])
        )
        return False
      k8s.config.load_incluster_config()
      cls._core_api = k8s.client.CoreV1Api()
      cls._batch_api = k8s.client.BatchV1Api()
      cls._ApiException = k8s.client.exceptions.ApiException
      return True
    else:
      return False

  @classmethod
  @contextmanager
  def _handle_api_exception(cls):
    """Re-raises Kubernetes API errors as RuntimeError with guidance."""
    try:
      yield
    except cls._ApiException as e:
      err_msg = [f"Kubernetes API Error: {e.status} - {e.reason}"]
      if e.status == 403:
        err_msg.append(textwrap.fill(
            "It appears that the Kubernetes service account (SA) associated with "
            "this job does not have the permission for pod introspection. Please "
            "either grant the default SA permission to read pod info, or create a "
            "dedicated service account with the permission and associated with "
            "the job. For an example on setting up the service account, see the "
            "example/k8s directory in the JAX repo. For more details, please refer to "
            "https://docs.jax.dev/en/latest/multi_process.html#kubernetes-example",
            width=80
        ))
      raise RuntimeError('\n'.join(err_msg)) from e

  @classmethod
  @cache
  def _namespace(cls):
    """Returns the namespace of the current pod (cached)."""
    # Use a context manager so the file handle is closed deterministically
    # instead of leaking until garbage collection.
    with open(
        '/var/run/secrets/kubernetes.io/serviceaccount/namespace'
    ) as f:
      return f.read().strip()

  @classmethod
  @cache
  # in case of latency for core DNS to update pod IP to etcd/API server
  @retry(exceptions=ValueError)
  def _pod(cls):
    """Returns this pod's API object, located by its IP (cached)."""
    hostname = os.getenv('HOSTNAME')
    if hostname is None:
      raise RuntimeError("expected HOSTNAME env variable to be defined")
    ip = socket.gethostbyname(hostname)
    with cls._handle_api_exception():
      # Exactly one pod should own this IP; the retry decorator above
      # absorbs the transient ValueError raised while the API server's
      # view of the pod IP is still stale.
      [pod] = cls._core_api.list_namespaced_pod(
          namespace=cls._namespace(),
          field_selector=f'status.podIP={ip}'
      ).items
    return pod

  @classmethod
  @cache
  def _job(cls):
    """Returns the Job API object that owns this pod (cached)."""
    with cls._handle_api_exception():
      return cls._batch_api.read_namespaced_job(
          name=cls._pod().metadata.labels['job-name'], namespace=cls._namespace()
      )

  @classmethod
  @cache
  def _headless_svc(cls):
    """Returns a headless Service targeting this pod, or None (cached)."""
    with cls._handle_api_exception():
      services = cls._core_api.list_namespaced_service(cls._namespace()).items
      pod_labels = cls._pod().metadata.labels or {}
      for svc in services:
        if svc.spec.cluster_ip == "None":  # if headless service
          svc_selector = svc.spec.selector or {}
          # The service targets this pod iff every selector entry matches
          # one of the pod's labels.
          if all(pod_labels.get(k) == v for k, v in svc_selector.items()):
            return svc
    # returns None if no headless service targets the current pod
    return None

  @classmethod
  @cache
  def _controller(cls):
    """Returns the pod's single managing controller owner-reference."""
    # https://github.com/kubernetes/apimachinery/blob/7b4292b/pkg/apis/meta/v1/types.go#L235
    # states that there cannot be more than one managing controller.
    for owner in cls._pod().metadata.owner_references:
      if owner.controller is True:
        return owner
    raise RuntimeError(
        'Cannot automatically initialize distributed workload: '
        f'pod {cls._pod().metadata.name} does not have a controller.'
    )

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
    """Returns 'host:port' of the pod with completion index 0.

    Raises:
      RuntimeError: if the pod's controller is not a Job, or a standalone
        Job lacks the headless service required for pod-to-pod DNS.
    """
    controller = cls._controller()
    job = cls._job()
    pod = cls._pod()
    if controller.kind == 'Job':
      # if job belongs to a jobset
      if 'jobset.sigs.k8s.io/jobset-name' in job.metadata.labels:
        coordinator_hostname = '{job_name}-0.{subdomain}'.format(
            job_name=job.metadata.name,
            subdomain=job.metadata.labels['jobset.sigs.k8s.io/jobset-name']
        )
      # if job is standalone
      else:
        # check if the job is associated with a headless service, which is
        # necessary for pods to communicate with each other
        if pod.spec.subdomain is None:
          # check if a headless service exists but not specified as subdomain
          svc = cls._headless_svc()
          err_msg = (
              "Pods within a job need a headless service in order to "
              "communicate with each other. "
          )
          if svc:
            err_msg += (
                f"A headless service '{svc.metadata.name}' is found that "
                "targets this job, but it is not specified as the job subdomain. "
                "Please add the following to the job specification: "
            )
            fix_msg = [
                "```",
                "kind: Job",
                "spec:",
                "  ...",
                "  template:",
                "    spec:",
                f"      subdomain: {svc.metadata.name}",
                "```",
            ]
          else:
            err_msg += "To fix, add the following to the job specification:"
            fix_msg = [
                "```",
                "apiVersion: v1",
                "kind: Service",
                "metadata:",
                "  name: jaxpods",
                "spec:",
                "  publishNotReadyAddresses: true",
                "  clusterIP: None",
                "  selector:",
                f"    job-name: {job.metadata.name}",
                "---",
                "kind: Job",
                "spec:",
                "  ...",
                "  template:",
                "    spec:",
                "      subdomain: jaxpods",
                "```",
            ]
          raise RuntimeError('\n'.join([textwrap.fill(err_msg)] + fix_msg))
        coordinator_hostname = '{job_name}-0.{subdomain}'.format(
            job_name=job.metadata.name,
            subdomain=pod.spec.subdomain
        )
      if timeout_secs:
        # Ensure host pod is up before trying to communicate
        # Retry in case of cached NXDOMAIN DNS failure (30 secs default)
        @retry(
            initial_delay=0.5,
            wait=np.logspace(-1, 1.5, 8) * np.random.rand(8),
            exceptions=socket.gaierror
        )
        def wait_for_host(hostname):
          socket.gethostbyname(hostname)
        wait_for_host(coordinator_hostname)
      port = override_coordinator_port or cls._coordinator_port
      return '{hostname}:{port}'.format(
          hostname=coordinator_hostname,
          port=port
      )
    else:
      raise RuntimeError(
          'In K8s, cluster automatic bootstrap only supports Job/JobSet.'
      )

  @classmethod
  def get_process_count(cls) -> int:
    """Total process count, taken from the Job's parallelism."""
    # https://kubernetes.io/docs/concepts/workloads/controllers/job/#controlling-parallelism
    return cls._job().spec.parallelism

  @classmethod
  def get_process_id(cls) -> int:
    """This pod's completion index within an Indexed Job."""
    # https://kubernetes.io/docs/concepts/workloads/controllers/job/#completion-mode
    try:
      return int(os.environ['JOB_COMPLETION_INDEX'])
    except KeyError:
      raise RuntimeError(
          'To enable automatic bootstrap in a K8s cluster, '
          'jobs must be indexed by setting `completionMode: "Indexed"`.'
      )
@@ -0,0 +1,96 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from jax._src import clusters
import socket
from importlib.util import find_spec
class Mpi4pyCluster(clusters.ClusterEnv):
  """Opt-in cluster environment that coordinates bootstrap over MPI."""

  name: str = "mpi4py"
  opt_in_only_method: bool = True

  @classmethod
  def is_env_present(cls) -> bool:
    # Usable whenever the mpi4py package can be imported.
    return find_spec("mpi4py") is not None

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
    """Rank 0 picks its own 'host:port' and broadcasts it to all ranks."""
    from mpi4py import MPI  # pyrefly: ignore[missing-import]

    world = MPI.COMM_WORLD
    if world.Get_rank() == 0:
      host = socket.gethostname()
      if override_coordinator_port:
        chosen_port = override_coordinator_port
      else:
        # Map the hostname into the ephemeral port range
        # [(65535 - 2^12 + 1), 65535].
        chosen_port = str(hash(host) % 2**12 + (65535 - 2**12 + 1))
      announced = f'{host}:{chosen_port}'
    else:
      announced = "None"
    # Every rank receives rank 0's coordinator address.
    return world.bcast(announced, root=0)

  @classmethod
  def get_process_count(cls) -> int:
    """Size of the world communicator."""
    from mpi4py import MPI  # pytype: disable=import-error
    return int(MPI.COMM_WORLD.Get_size())

  @classmethod
  def get_process_id(cls) -> int:
    """Global rank of this process."""
    from mpi4py import MPI  # pytype: disable=import-error
    return int(MPI.COMM_WORLD.Get_rank())

  @classmethod
  def get_local_process_id(cls) -> int | None:
    """Host-local rank, via a shared-memory split of the world communicator."""
    from mpi4py import MPI  # pytype: disable=import-error
    # Ranks on the same node land in the same shared-memory communicator,
    # so the rank within it is exactly the local process ID.
    shared_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
    return int(shared_comm.Get_rank())
@@ -0,0 +1,66 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import re
from jax._src import clusters
# OMPI_MCA_orte_hnp_uri exists only when processes are launched via mpirun or mpiexec
_ORTE_URI = 'OMPI_MCA_orte_hnp_uri'
# Total number of ranks in the world communicator.
_PROCESS_COUNT = 'OMPI_COMM_WORLD_SIZE'
# Global rank of this process.
_PROCESS_ID = 'OMPI_COMM_WORLD_RANK'
# Rank of this process among the processes on the same node.
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'
class OmpiCluster(clusters.ClusterEnv):
  """Cluster environment for processes launched via Open MPI mpirun/mpiexec."""

  name: str = "ompi"

  @classmethod
  def is_env_present(cls) -> bool:
    # The HNP URI variable is only exported by mpirun/mpiexec launches.
    return _ORTE_URI in os.environ

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
    """Derives the coordinator 'host:port' from the Open MPI HNP URI."""
    # Examples of orte_uri:
    # 1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911
    # 1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370
    uri = os.environ[_ORTE_URI]
    if override_coordinator_port:
      port = override_coordinator_port
    else:
      # The jobid is always a multiple of 2^12; dividing by 2^12 reduces
      # the likelihood of a port conflict between jobs, and the modulo
      # maps it into the ephemeral range [(65535 - 2^12 + 1), 65535].
      job_id = int(uri.split('.', maxsplit=1)[0]) // 2**12
      port = str(job_id % 2**12 + (65535 - 2**12 + 1))
    ip_match = re.search(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", uri)
    if ip_match is None:
      raise RuntimeError('Could not parse coordinator IP address from Open MPI environment.')
    # Exactly one of the two alternates (tcp / tcp6) captured a group.
    launcher_ip = next(group for group in ip_match.groups() if group is not None)
    return f'{launcher_ip}:{port}'

  @classmethod
  def get_process_count(cls) -> int:
    """World size as exported by Open MPI."""
    return int(os.environ[_PROCESS_COUNT])

  @classmethod
  def get_process_id(cls) -> int:
    """Global rank as exported by Open MPI."""
    return int(os.environ[_PROCESS_ID])

  @classmethod
  def get_local_process_id(cls) -> int | None:
    """Node-local rank as exported by Open MPI."""
    return int(os.environ[_LOCAL_PROCESS_ID])
@@ -0,0 +1,70 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from jax._src import clusters
# Numeric ID of the SLURM job; also used to derive the coordinator port.
_JOBID_PARAM = 'SLURM_JOB_ID'
# Node list of the current step, possibly in compressed 'node[001-003]' form.
_NODE_LIST = 'SLURM_STEP_NODELIST'
# Total number of tasks in the job.
_PROCESS_COUNT = 'SLURM_NTASKS'
# Global rank of this task.
_PROCESS_ID = 'SLURM_PROCID'
# Node-local rank of this task.
_LOCAL_PROCESS_ID = 'SLURM_LOCALID'
# Number of nodes in the current job step.
_NUM_NODES = 'SLURM_STEP_NUM_NODES'
class SlurmCluster(clusters.ClusterEnv):
  """Cluster environment for jobs launched under the SLURM scheduler."""

  name: str = "slurm"

  @classmethod
  def is_env_present(cls) -> bool:
    # srun exports all of these for every task.
    required = (_JOBID_PARAM, _NODE_LIST, _PROCESS_COUNT,
                _PROCESS_ID, _LOCAL_PROCESS_ID)
    return all(var in os.environ for var in required)

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
    """Returns 'host:port' where host is the first node of the job."""
    if override_coordinator_port:
      port = override_coordinator_port
    else:
      # Map the job ID into the ephemeral port range
      # [(65535 - 2^12 + 1), 65535].
      port = str(int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1))

    # Extract the first hostname of the job. If we are looking for
    # 'node001', the node list may be 'node001', 'node001,host2',
    # 'node[001-0015],host2', or 'node[001,007-015],host2'.
    node_list = os.environ[_NODE_LIST]
    split_pos = next((idx for idx, ch in enumerate(node_list) if ch in ',['),
                     len(node_list))
    if split_pos == len(node_list) or node_list[split_pos] == ',':
      # Formats: 'node001' or 'node001,host2'
      return f'{node_list[:split_pos]}:{port}'
    # Formats: 'node[001-0015],host2' or 'node[001,007-015],host2' — the
    # first hostname is the prefix plus the first number in the brackets.
    prefix = node_list[:split_pos]
    bracket_body = node_list[split_pos + 1:]
    end = next((idx for idx, ch in enumerate(bracket_body) if ch in ',-'), None)
    return f'{prefix}{bracket_body[:end]}:{port}'

  @classmethod
  def get_process_count(cls) -> int:
    """Total task count of the job."""
    return int(os.environ[_PROCESS_COUNT])

  @classmethod
  def get_process_id(cls) -> int:
    """Global rank of this task."""
    return int(os.environ[_PROCESS_ID])

  @classmethod
  def get_local_process_id(cls) -> int | None:
    """Node-local rank of this task."""
    return int(os.environ[_LOCAL_PROCESS_ID])