hand
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .cluster import ClusterEnv as ClusterEnv
|
||||
|
||||
# Order of declaration of the cluster environments
|
||||
# will dictate the order in which they will be checked.
|
||||
# Therefore, if multiple environments are available and
|
||||
# the user did not explicitly provide the arguments
|
||||
# to :func:`jax.distributed.initialize`, the first
|
||||
# available one from the list will be picked.
|
||||
from .ompi_cluster import OmpiCluster as OmpiCluster
|
||||
from .slurm_cluster import SlurmCluster as SlurmCluster
|
||||
from .mpi4py_cluster import Mpi4pyCluster as Mpi4pyCluster
|
||||
from .cloud_tpu_cluster import GkeTpuCluster as GkeTpuCluster
|
||||
from .cloud_tpu_cluster import GceTpuCluster as GceTpuCluster
|
||||
from .k8s_cluster import K8sCluster as K8sCluster
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,259 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import time
|
||||
from jax._src import clusters
|
||||
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# We use an arbitrarily chosen port for the coordinator since we cannot
|
||||
# rely on communication to choose one in real time.
|
||||
coordinator_port = '8482'
|
||||
|
||||
metadata_response_code_success = 200
|
||||
|
||||
def get_metadata(key):
  """Fetch an instance attribute from the GCE metadata server, with retries.

  Args:
    key: name of the instance attribute to read.

  Returns:
    A ``(text, status_code)`` tuple from the final HTTP response. The status
    code may be non-200 if every retry received an error response; callers
    are expected to check it.

  Raises:
    RuntimeError: if no HTTP response was obtained at all.
  """
  import requests  # pytype: disable=import-error
  # Based on https://github.com/tensorflow/tensorflow/pull/40317
  gce_metadata_endpoint = 'http://' + os.environ.get(
      'GCE_METADATA_IP', 'metadata.google.internal')

  retry_count = 0
  retry_seconds = 0.500
  api_resp = None

  while retry_count < 6:
    api_resp = requests.get(
        f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
        headers={'Metadata-Flavor': 'Google'}, timeout=60)
    # Use the module-level constant rather than a magic 200 for consistency
    # with the callers that compare against it.
    if api_resp.status_code == metadata_response_code_success:
      break
    retry_count += 1
    time.sleep(retry_seconds)

  if api_resp is None:
    raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
  return api_resp.text, api_resp.status_code
|
||||
|
||||
def get_tpu_env_value_from_metadata(key) -> str | None:
  """Look up `key` in the instance's 'tpu-env' metadata blob.

  The blob is newline-separated lines of the form ``SOME_KEY: 'value'``.
  Returns the unquoted value for `key`, or None if the key is absent.
  If the key appears more than once, the last occurrence wins (preserving
  the original scan-to-the-end behavior).
  """
  metadata_value = None
  tpu_env_data = get_metadata('tpu-env')[0]
  key_value_pairs = tpu_env_data.split('\n')
  for key_value_pair in key_value_pairs:
    # Typical line is MEGASCALE_NUM_SLICES: '2'
    if ':' in key_value_pair:
      # str.split with maxsplit=1 keeps any ':' inside the value intact.
      # (A regex is unnecessary here, and passing maxsplit positionally to
      # re.split is deprecated as of Python 3.13.)
      row_key, value = key_value_pair.split(':', 1)
      row_key = row_key.strip()
      if row_key == key:
        metadata_value = value.strip().strip("'")
  return metadata_value
|
||||
|
||||
def get_tpu_env_value(key) -> str | None:
|
||||
# First try to get the value from the environment.
|
||||
value = os.environ.get(key, None)
|
||||
if value is None:
|
||||
# If not found, try to get it from the metadata.
|
||||
value = get_tpu_env_value_from_metadata(key)
|
||||
return value
|
||||
|
||||
class BaseTpuCluster(clusters.ClusterEnv):
|
||||
|
||||
name: str = "tpu"
|
||||
|
||||
"""Abstract cluster supports both single and multislice TPU environments.
|
||||
|
||||
If MEGASCALE_COORDINATOR_ADDRESS is not set, we assume single slice topology.
|
||||
Concrete extensions of this class must implement methods for generating a list
|
||||
of within-slice workers and a within-slice process ID.
|
||||
`get_coordinator_address` must return the address of the host with
|
||||
process ID 0 (as returned by `get_process_id`), since the coordinator service
|
||||
is started on the host with process ID = 0.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_env_present(cls) -> bool:
|
||||
"""Override this method to return True if the environment is present."""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
|
||||
# For both GCE via QueuedResources and GKE via JobSet, the
|
||||
# Megascale coordinator address is set as the host with process id = 0,
|
||||
# so can be used as the jax distributed system coordinator.
|
||||
coordinator_address = get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS')
|
||||
if not coordinator_address:
|
||||
# For both GCE (QueuedResources and TPUVM create) and GKE via Job API,
|
||||
# the workers lists are sorted by process ID so the first one can
|
||||
# be used as the jax distributed system coordinator.
|
||||
coordinator_address = cls._get_worker_list_in_slice()[0]
|
||||
coordinator_address = coordinator_address.split(':')[0]
|
||||
logger.debug("TPU Cluster using coordinator address: %s", coordinator_address)
|
||||
cls.wait_for_coordinator(coordinator_address, timeout_secs)
|
||||
port = override_coordinator_port or coordinator_port
|
||||
return f'{coordinator_address}:{port}'
|
||||
|
||||
@classmethod
|
||||
def wait_for_coordinator(cls, coordinator_address, timeout_secs):
|
||||
# The coordinator may not be up before the other hosts try to
|
||||
# communicate with it. We check for its existence with retries.
|
||||
coordinator_found = False
|
||||
max_time = time.time() + timeout_secs
|
||||
coordinator_retry_secs = 5
|
||||
while not coordinator_found and time.time() < max_time:
|
||||
try:
|
||||
socket.gethostbyname(coordinator_address)
|
||||
coordinator_found = True
|
||||
logger.debug("Found coordinator with address %s", coordinator_address)
|
||||
except socket.gaierror:
|
||||
logger.debug(
|
||||
"Failed to recognize coordinator address %s"
|
||||
" retrying...", coordinator_address
|
||||
)
|
||||
time.sleep(coordinator_retry_secs)
|
||||
if not coordinator_found:
|
||||
raise RuntimeError(f"Failed to recognize coordinator address {coordinator_address}")
|
||||
|
||||
@classmethod
|
||||
def get_process_count(cls) -> int:
|
||||
processes_per_slice = len(cls._get_worker_list_in_slice())
|
||||
num_slices = cls._get_num_slices()
|
||||
total_process_count = processes_per_slice * num_slices
|
||||
logger.debug("Total process count of %s = %s processes per slice and %s slices", total_process_count, processes_per_slice, num_slices)
|
||||
return total_process_count
|
||||
|
||||
@classmethod
|
||||
def get_process_id(cls) -> int:
|
||||
process_id_in_slice = cls._get_process_id_in_slice()
|
||||
slice_id = cls._get_slice_id()
|
||||
processes_per_slice = len(cls._get_worker_list_in_slice())
|
||||
process_id = process_id_in_slice + slice_id * processes_per_slice
|
||||
logger.debug("Process ID of %s generated by within-slice id %s and slice id %s", process_id, process_id_in_slice, slice_id)
|
||||
return process_id
|
||||
|
||||
@staticmethod
|
||||
def _get_num_slices() -> int:
|
||||
num_slices = get_tpu_env_value('MEGASCALE_NUM_SLICES')
|
||||
if not num_slices:
|
||||
return 1
|
||||
return int(num_slices)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_slice_id() -> int:
|
||||
slice_id = get_tpu_env_value('MEGASCALE_SLICE_ID')
|
||||
if not slice_id:
|
||||
return 0
|
||||
return int(slice_id)
|
||||
|
||||
@staticmethod
|
||||
def _get_process_id_in_slice() -> int:
|
||||
"""Returns a process ID that is unique within slice."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def _get_worker_list_in_slice() -> list[str]:
|
||||
"""Returns a list of worker endpoints/hostnames within slice."""
|
||||
raise NotImplementedError()
|
||||
|
||||
class GceTpuCluster(BaseTpuCluster):
  """TPU cluster on plain GCE TPU VMs, detected via the metadata server."""

  name: str = "gcetpu"

  @classmethod
  def is_env_present(cls) -> bool:
    if not running_in_cloud_tpu_vm:
      logger.debug("Did not detect cloud TPU VM")
      return False
    if os.environ.get("TPU_SKIP_MDS_QUERY") is not None:
      logger.debug("TPU_SKIP_MDS_QUERY is set to True, so it's probably not a GCE TPU cluster.")
      return False
    metadata_response, metadata_code = get_metadata('agent-worker-number')
    if metadata_code == metadata_response_code_success:
      logger.debug("Gce Tpu Cluster detected for Jax Distributed System")
      return True
    logger.debug("Did not detect Gce Tpu Cluster since agent-worker-number is not set in metadata")
    logger.debug("Metadata code: %s", metadata_code)
    logger.debug("Metadata response: %s", metadata_response)
    return False

  @staticmethod
  def _get_process_id_in_slice() -> int:
    # The metadata server exposes this VM's worker index directly.
    return int(get_metadata('agent-worker-number')[0])

  @staticmethod
  def _get_worker_list_in_slice() -> list[str]:
    """Hostname/IP per worker, parsed from 'worker-network-endpoints'."""
    addrs = []
    for worker in get_metadata('worker-network-endpoints')[0].split(','):
      # Two known formats: the new one is a bare hostname; the old one is a
      # "name:id:ip" triple, from which the ip field is used.
      fields = worker.split(':')
      if len(fields) == 1:
        addrs.append(fields[0])
      elif len(fields) == 3:
        addrs.append(fields[2])
      else:
        raise ValueError(f'unsupported worker-network-endpoints format: {worker}')
    return addrs
|
||||
|
||||
class GkeTpuCluster(BaseTpuCluster):
  """TPU cluster under GKE, detected via TPU worker environment variables."""

  name: str = "gketpu"

  @classmethod
  def is_env_present(cls) -> bool:
    on_tpu_vm = running_in_cloud_tpu_vm
    if on_tpu_vm and cls._get_worker_host_names_env_var() is not None:
      logger.debug("Gke Tpu Cluster detected for Jax Distributed System")
      return True
    if not on_tpu_vm:
      logger.debug("Did not detect cloud TPU VM")
    else:
      logger.debug("Did not detect TPU GKE cluster since neither "
                   "TPU_PROCESS_ADDRESSES nor TPU_WORKER_HOSTNAMES is set.")
    return False

  @staticmethod
  def _get_process_id_in_slice() -> int:
    # NOTE(review): if TPU_WORKER_ID is unset this raises ValueError on
    # int("None") rather than a clearer error — behavior preserved as-is.
    return int(str(os.environ.get('TPU_WORKER_ID')))

  @staticmethod
  def _get_worker_host_names_env_var() -> str | None:
    """
    Retrieves the list of worker hostnames from environment variables.

    Checks 'TPU_PROCESS_ADDRESSES' first, then 'TPU_WORKER_HOSTNAMES'.
    Returns None if neither environment variable is set.
    """
    primary = os.environ.get('TPU_PROCESS_ADDRESSES', None)
    if primary is not None:
      return primary
    return os.environ.get('TPU_WORKER_HOSTNAMES', None)

  @staticmethod
  def _get_worker_list_in_slice() -> list[str]:
    """
    Returns a list of worker endpoints/hostnames within slice.
    """
    # str() mirrors the original behavior even when both env vars are absent
    # (the literal string "None" is then split, yielding ['None']).
    raw = str(GkeTpuCluster._get_worker_host_names_env_var())
    return raw.split(',')
|
||||
@@ -0,0 +1,130 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import os
|
||||
import logging
|
||||
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ClusterEnv:
  """Interface for defining a cluster environment.

  To enable auto bootstrapping (aka :func:`jax.distributed.initialize()`),
  cluster environments need to derive from :class:`ClusterEnv` and implement
  :func:`is_env_present`, :func:`get_coordinator_address`,
  :func:`get_process_count`, and :func:`get_process_id`.
  :class:`ClusterEnv` subclasses are automatically detected when imported.
  """

  # Identifier matched against `cluster_detection_method` during detection.
  name: str = ""

  # Registry of every imported subclass, in import (i.e. priority) order.
  _cluster_types: list[type[ClusterEnv]] = []
  opt_in_only_method: bool = False  # Override this in derived classes if necessary

  def __init_subclass__(cls, **kwargs):
    # Auto-register each subclass so detection can simply iterate the list.
    super().__init_subclass__(**kwargs)
    cls._cluster_types.append(cls)


  @classmethod
  def auto_detect_unset_distributed_params(cls,
                                           coordinator_address: str | None,
                                           num_processes: int | None,
                                           process_id: int | None,
                                           local_device_ids: Sequence[int] | None,
                                           cluster_detection_method: str | None,
                                           initialization_timeout: int | None,
                                           ) -> tuple[str | None, int | None, int | None,
                                                      Sequence[int] | None]:
    """Fill in any still-None distributed-init parameters from a detected
    cluster environment; values the caller supplied are left untouched.
    """
    # First, we check the spec detection method because it will ignore submitted values
    # if it succeeds.
    if cluster_detection_method is not None:
      env = next( (env for env in cls._cluster_types if env.name == cluster_detection_method), None )
      if env is None:
        logger.error(f"Automatic Distributed initialization can not proceed:"
                     f" {cluster_detection_method} is not supported.")
      elif not env.is_env_present():
        # NOTE(review): `env` is left truthy here, so the block below still
        # runs against a non-functional environment — confirm this is intended.
        logger.error(f"Automatic Distributed initialization can not proceed:"
                     f" {cluster_detection_method} is supported but not functional in this environment.")
    else:
      env = next((env for env in cls._cluster_types if env.opt_in_only_method == False and env.is_env_present()), None)

    # Above: I have wrapped the env selection in a conditional to go through
    # opt-in methods first (currently only mpi4py) but to check all possible options
    # otherwise. Passing no cluster_detection_method results in the default, original behavior.

    if env:
      logger.debug('Initializing distributed JAX environment via %s', env.__name__)
      if coordinator_address is None:
        # JAX_COORDINATOR_PORT lets the user pin the port without pinning the host.
        coordinator_port = os.environ.get("JAX_COORDINATOR_PORT")
        coordinator_address = env.get_coordinator_address(timeout_secs=initialization_timeout, override_coordinator_port=coordinator_port)
      if num_processes is None:
        num_processes = env.get_process_count()
      if process_id is None:
        process_id = env.get_process_id()
      # Never automatically set local_device_ids on TPUs
      # Defaults to single process per device if local_process_id is available.
      # This only runs if we're in a managed distributed environment.
      # Otherwise local_device_ids will remain unset,
      # which will default to all devices being visible.
      if (local_device_ids is None and not running_in_cloud_tpu_vm and
          (pid := env.get_local_process_id()) is not None):
        local_device_ids = [pid]
    else:
      logger.debug('Could not find a known environment for initializing distributed JAX. '
                   'Known environments: %s', ', '.join(e.__name__ for e in cls._cluster_types))
    return (coordinator_address, num_processes, process_id, local_device_ids)

  @classmethod
  def is_env_present(cls) -> bool:
    """Returns True if process is running in this cluster environment.
    """
    raise NotImplementedError("ClusterEnv subclasses must implement is_env_present")

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
    """Returns address and port used by JAX to bootstrap.

    Process id 0 will open a tcp socket at "hostname:port" where
    all the processes will connect to initialize the distributed JAX service.
    The selected port needs to be free.
    :func:`get_coordinator_address` needs to return the same hostname and port on all the processes.

    Returns:
      "hostname:port"
    """
    raise NotImplementedError("ClusterEnv subclasses must implement get_coordinator_address")

  @classmethod
  def get_process_count(cls) -> int:
    """Returns the total number of processes in the cluster."""
    raise NotImplementedError("ClusterEnv subclasses must implement get_process_count")

  @classmethod
  def get_process_id(cls) -> int:
    """Returns this process's index within the cluster."""
    raise NotImplementedError("ClusterEnv subclasses must implement get_process_id")

  @classmethod
  def get_local_process_id(cls) -> int | None:
    """ Get index of current process inside a host.

    The method is only useful to support single device per process.
    In that case, each process will see a local device whose ID is
    the same as its local process ID.
    If None, JAX will not restrict the visible devices.
    """
    return None
|
||||
@@ -0,0 +1,285 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from functools import cache
|
||||
from itertools import chain
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
import textwrap
|
||||
import warnings
|
||||
from jax._src import clusters
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def retry(
    func=None,
    initial_delay=0,
    wait=np.logspace(-1, 1, 5) * np.random.rand(5),
    exceptions=Exception,
):
  """Retry decorator with a fixed schedule of (jittered) back-off delays.

  Usable both bare (``@retry``) and with arguments (``@retry(...)``).

  Args:
    func: the function to wrap when used without parentheses.
    initial_delay: seconds to sleep before the very first attempt.
    wait: sequence of sleep durations preceding each retry; the default is
      a randomly jittered logarithmic schedule, sampled once at import time.
    exceptions: exception type(s) that trigger a retry; anything else
      propagates immediately.

  Raises:
    RuntimeError: when every attempt raised one of `exceptions`; chained
      from the last such exception.
  """
  from functools import wraps  # local: module only imports functools.cache

  def retry_decorator(func):
    # wraps() preserves __name__/__doc__ of the wrapped function — without it
    # stacked decorators and the debug messages below see 'retry_driver'.
    @wraps(func)
    def retry_driver(*args, **kwargs):
      log = logging.getLogger(__name__)
      # Attempt i sleeps wait[i-1] first (attempt 0 sleeps initial_delay),
      # so there are len(wait) + 1 attempts in total.
      for i, t in enumerate(chain([initial_delay], wait)):
        # Lazy %-formatting: no string work unless DEBUG is enabled.
        log.debug("Trying %s in %.2f seconds, attempt %d/%d",
                  func.__name__, t, i, len(wait))
        time.sleep(t)
        try:
          return func(*args, **kwargs)
        except exceptions as e:
          if i == len(wait):  # last attempt has been exhausted
            raise RuntimeError('Retry failed with all attempts exhausted') from e
        finally:
          log.debug("Finished %s after %d attempts", func.__name__, i + 1)
    return retry_driver

  if func is None:
    return retry_decorator
  else:
    return retry_decorator(func)
|
||||
|
||||
|
||||
class K8sCluster(clusters.ClusterEnv):
  """Cluster environment for Kubernetes Jobs/JobSets (in-cluster bootstrap)."""

  # Use an arbitrarily chosen port for the coordinator since we cannot
  # rely on communication to choose one in real time.
  _coordinator_port = '55527'

  @classmethod
  def is_env_present(cls) -> bool:
    # KUBERNETES_SERVICE_HOST is injected into every pod by Kubernetes.
    if 'KUBERNETES_SERVICE_HOST' in os.environ:
      try:
        import kubernetes as k8s  # pytype: disable=import-error
      except (ImportError, ModuleNotFoundError):
        # K8s detected but the client library is missing: warn and opt out
        # rather than fail, so manual initialization still works.
        warnings.warn(
            '\n'.join([
                textwrap.fill(
                    "Kubernetes environment detected, but the `kubernetes` package "
                    "is not installed to enable automatic bootstrapping in this "
                    "environment. To enable automatic bootstrapping, please install "
                    "jax with the [k8s] extra. For example:"),
                "  pip install jax[k8s]",
                "  pip install jax[k8s,<MORE-EXTRAS...>]",
            ])
        )
        return False

      # Cache API clients (and the exception type) on the class for the
      # other classmethods below.
      k8s.config.load_incluster_config()
      cls._core_api = k8s.client.CoreV1Api()
      cls._batch_api = k8s.client.BatchV1Api()
      cls._ApiException = k8s.client.exceptions.ApiException
      return True
    else:
      return False

  @classmethod
  @contextmanager
  def _handle_api_exception(cls):
    """Re-raise Kubernetes API errors as RuntimeError with RBAC guidance."""
    try:
      yield
    except cls._ApiException as e:
      err_msg = [f"Kubernetes API Error: {e.status} - {e.reason}"]
      if e.status == 403:
        # 403 almost always means the pod's service account lacks RBAC
        # permission to read pod/job metadata; explain how to fix it.
        err_msg.append(textwrap.fill(
            "It appears that the Kubernetes service account (SA) associated with "
            "this job does not have the permission for pod introspection. Please "
            "either grant the default SA permission to read pod info, or create a "
            "dedicated service account with the permission and associated with "
            "the job. For an example on setting up the service account, see the "
            "example/k8s directory in the JAX repo. For more details, please refer to "
            "https://docs.jax.dev/en/latest/multi_process.html#kubernetes-example",
            width=80
        ))
      raise RuntimeError('\n'.join(err_msg)) from e

  @classmethod
  @cache
  def _namespace(cls):
    # Namespace of the current pod, as mounted into every pod by Kubernetes.
    return open(
        '/var/run/secrets/kubernetes.io/serviceaccount/namespace'
    ).read().strip()

  @classmethod
  @cache
  # in case of latency for core DNS to update pod IP to etcd/API server
  @retry(exceptions=ValueError)
  def _pod(cls):
    """Return the V1Pod describing the current pod (cached after success)."""
    hostname = os.getenv('HOSTNAME')
    if hostname is None:
      raise RuntimeError("expected HOSTNAME env variable to be defined")
    ip = socket.gethostbyname(hostname)
    with cls._handle_api_exception():
      # Exactly one pod must match our IP; the single-element destructuring
      # raises ValueError otherwise, which triggers the @retry above.
      [pod] = cls._core_api.list_namespaced_pod(
          namespace=cls._namespace(),
          field_selector=f'status.podIP={ip}'
      ).items
    return pod

  @classmethod
  @cache
  def _job(cls):
    """Return the Job owning this pod, found via its 'job-name' label."""
    with cls._handle_api_exception():
      return cls._batch_api.read_namespaced_job(
          name=cls._pod().metadata.labels['job-name'], namespace=cls._namespace()
      )

  @classmethod
  @cache
  def _headless_svc(cls):
    """Find a headless Service whose selector matches this pod, if any."""
    with cls._handle_api_exception():
      services = cls._core_api.list_namespaced_service(cls._namespace()).items

    pod_labels = cls._pod().metadata.labels or {}
    for svc in services:
      if svc.spec.cluster_ip == "None":  # if headless service
        svc_selector = svc.spec.selector or {}
        # A service targets this pod iff every selector key/value matches.
        if all(pod_labels.get(k) == v for k, v in svc_selector.items()):
          return svc

    # returns None if no headless service targets the current pod
    return None

  @classmethod
  @cache
  def _controller(cls):
    # https://github.com/kubernetes/apimachinery/blob/7b4292b/pkg/apis/meta/v1/types.go#L235
    # states that there cannot be more than one managing controller.
    for owner in cls._pod().metadata.owner_references:
      if owner.controller is True:
        return owner

    raise RuntimeError(
        'Cannot automatically initialize distributed workload: '
        f'pod {cls._pod().metadata.name} does not have a controller.'
    )

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
    """Return "host:port" of process 0, derived from Job/JobSet DNS naming."""
    controller = cls._controller()
    job = cls._job()
    pod = cls._pod()
    if controller.kind == 'Job':
      # if job belongs to a jobset
      if 'jobset.sigs.k8s.io/jobset-name' in job.metadata.labels:
        # JobSet provides a headless service named after the jobset, so
        # completion index 0 is reachable at "<job>-0.<jobset-name>".
        coordinator_hostname = '{job_name}-0.{subdomain}'.format(
            job_name=job.metadata.name,
            subdomain=job.metadata.labels['jobset.sigs.k8s.io/jobset-name']
        )
      # if job is standalone
      else:
        # check if the job is associated with a headless service, which is
        # necessary for pods to communicate with each other
        if pod.spec.subdomain is None:
          # check if a headless service exists but not specified as subdomain
          svc = cls._headless_svc()
          err_msg = (
              "Pods within a job need a headless service in order to "
              "communicate with each other. "
          )
          if svc:
            err_msg += (
                f"A headless service '{svc.metadata.name}' is found that "
                "targets this job, but it is not specified as the job subdomain. "
                "Please add the following to the job specification: "
            )
            fix_msg = [
                "```",
                "kind: Job",
                "spec:",
                "  ...",
                "  template:",
                "    spec:",
                f"      subdomain: {svc.metadata.name}",
                "```",
            ]
          else:
            err_msg += "To fix, add the following to the job specification:"
            fix_msg = [
                "```",
                "apiVersion: v1",
                "kind: Service",
                "metadata:",
                "  name: jaxpods",
                "spec:",
                "  publishNotReadyAddresses: true",
                "  clusterIP: None",
                "  selector:",
                f"    job-name: {job.metadata.name}",
                "---",
                "kind: Job",
                "spec:",
                "  ...",
                "  template:",
                "    spec:",
                "      subdomain: jaxpods",
                "```",
            ]

          raise RuntimeError('\n'.join([textwrap.fill(err_msg)] + fix_msg))

        coordinator_hostname = '{job_name}-0.{subdomain}'.format(
            job_name=job.metadata.name,
            subdomain=pod.spec.subdomain
        )

      if timeout_secs:
        # Ensure host pod is up before trying to communicate
        # Retry in case of cached NXDOMAIN DNS failure (30 secs default)
        @retry(
            initial_delay=0.5,
            wait=np.logspace(-1, 1.5, 8) * np.random.rand(8),
            exceptions=socket.gaierror
        )
        def wait_for_host(hostname):
          socket.gethostbyname(hostname)

        wait_for_host(coordinator_hostname)

      port = override_coordinator_port or cls._coordinator_port
      return '{hostname}:{port}'.format(
          hostname=coordinator_hostname,
          port=port
      )

    else:
      raise RuntimeError(
          'In K8s, cluster automatic bootstrap only supports Job/JobSet.'
      )

  @classmethod
  def get_process_count(cls) -> int:
    # https://kubernetes.io/docs/concepts/workloads/controllers/job/#controlling-parallelism
    return cls._job().spec.parallelism

  @classmethod
  def get_process_id(cls) -> int:
    # https://kubernetes.io/docs/concepts/workloads/controllers/job/#completion-mode
    try:
      return int(os.environ['JOB_COMPLETION_INDEX'])
    except KeyError:
      raise RuntimeError(
          'To enable automatic bootstrap in a K8s cluster, '
          'jobs must be indexed by setting `completionMode: "Indexed"`.'
      )
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from jax._src import clusters
|
||||
import socket
|
||||
|
||||
from importlib.util import find_spec
|
||||
|
||||
|
||||
class Mpi4pyCluster(clusters.ClusterEnv):
|
||||
|
||||
|
||||
name: str = "mpi4py"
|
||||
opt_in_only_method: bool = True
|
||||
|
||||
@classmethod
|
||||
def is_env_present(cls) -> bool:
|
||||
|
||||
# Relies on mpi4py:
|
||||
return find_spec("mpi4py") is not None
|
||||
|
||||
@classmethod
|
||||
def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
|
||||
|
||||
# Using mpi4py, figure out rank 0 and it's hostname.
|
||||
# Then broadcast the hostname and port.
|
||||
|
||||
|
||||
from mpi4py import MPI # pyrefly: ignore[missing-import]
|
||||
# Get the global communicator:
|
||||
COMM_WORLD = MPI.COMM_WORLD
|
||||
|
||||
# On rank 0, get the hostname:
|
||||
|
||||
if COMM_WORLD.Get_rank() == 0:
|
||||
# Order all the hostnames, and find unique ones
|
||||
hostname = socket.gethostname()
|
||||
|
||||
if override_coordinator_port:
|
||||
port_id = override_coordinator_port
|
||||
else:
|
||||
# Apparently, we want to pick a port in an ephemeral range...
|
||||
port_id = str(hash(hostname) % 2**12 + (65535 - 2**12 + 1))
|
||||
|
||||
hostname = f'{hostname}:{port_id}'
|
||||
|
||||
else:
|
||||
hostname = "None"
|
||||
|
||||
|
||||
|
||||
# Broadcast the host_ip to all ranks:
|
||||
hostname = COMM_WORLD.bcast(hostname, root=0)
|
||||
|
||||
|
||||
return hostname
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_process_count(cls) -> int:
|
||||
from mpi4py import MPI # pytype: disable=import-error
|
||||
return int(MPI.COMM_WORLD.Get_size())
|
||||
|
||||
@classmethod
|
||||
def get_process_id(cls) -> int:
|
||||
from mpi4py import MPI # pytype: disable=import-error
|
||||
return int(MPI.COMM_WORLD.Get_rank())
|
||||
|
||||
@classmethod
|
||||
def get_local_process_id(cls) -> int | None:
|
||||
|
||||
# Using mpi4py, split the global communicator into sub communicators
|
||||
# based on hostname. mpi will assign them ranks and that will allow
|
||||
# a selection of the local process ID.
|
||||
from mpi4py import MPI # pytype: disable=import-error
|
||||
COMM_WORLD = MPI.COMM_WORLD
|
||||
|
||||
# This is the alternative method that is simpler:
|
||||
new_comm = COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
|
||||
|
||||
|
||||
# The rank in the new communicator - which is host-local only - IS the local rank:
|
||||
return int(new_comm.Get_rank())
|
||||
@@ -0,0 +1,66 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from jax._src import clusters
|
||||
|
||||
# OMPI_MCA_orte_hnp_uri is set only when processes are launched via
# mpirun or mpiexec, so its presence identifies an Open MPI (ORTE) job.
_ORTE_URI = 'OMPI_MCA_orte_hnp_uri'
_PROCESS_COUNT = 'OMPI_COMM_WORLD_SIZE'
_PROCESS_ID = 'OMPI_COMM_WORLD_RANK'
_LOCAL_PROCESS_ID = 'OMPI_COMM_WORLD_LOCAL_RANK'


class OmpiCluster(clusters.ClusterEnv):
  """Cluster environment for jobs launched via Open MPI's mpirun/mpiexec."""

  name: str = "ompi"

  @classmethod
  def is_env_present(cls) -> bool:
    """True when running under an Open MPI launcher (ORTE HNP URI set)."""
    return _ORTE_URI in os.environ

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
    """Derive the coordinator's ``host:port`` from the ORTE HNP URI.

    Examples of orte_uri:
      1531576320.0;tcp://10.96.0.1,10.148.0.1,10.108.0.1:34911
      1314521088.0;tcp6://[fe80::b9b:ac5d:9cf0:b858,2620:10d:c083:150e::3000:2]:43370
    """
    orte_uri = os.environ[_ORTE_URI]
    if override_coordinator_port:
      port = override_coordinator_port
    else:
      # The job id is the leading integer of the URI and is always a
      # multiple of 2^12; divide it out to reduce the likelihood of port
      # conflicts between jobs, then fold into the ephemeral port range
      # [(65535 - 2^12 + 1), 65535].
      job_id = int(orte_uri.split('.', maxsplit=1)[0]) // 2**12
      port = str(job_id % 2**12 + (65535 - 2**12 + 1))
    # First address only: up to the first ',' or ':' for tcp, or the
    # first ',' or ']' inside the brackets for tcp6.
    launcher_ip_match = re.search(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)
    if launcher_ip_match is None:
      raise RuntimeError('Could not parse coordinator IP address from Open MPI environment.')
    launcher_ip = next(g for g in launcher_ip_match.groups() if g is not None)
    return f'{launcher_ip}:{port}'

  @classmethod
  def get_process_count(cls) -> int:
    """Global number of processes, as exported by the Open MPI launcher."""
    return int(os.environ[_PROCESS_COUNT])

  @classmethod
  def get_process_id(cls) -> int:
    """Global rank of this process within the MPI world."""
    return int(os.environ[_PROCESS_ID])

  @classmethod
  def get_local_process_id(cls) -> int | None:
    """Rank of this process among the processes on the same host."""
    return int(os.environ[_LOCAL_PROCESS_ID])
|
||||
@@ -0,0 +1,70 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from jax._src import clusters
|
||||
|
||||
_JOBID_PARAM = 'SLURM_JOB_ID'
_NODE_LIST = 'SLURM_STEP_NODELIST'
_PROCESS_COUNT = 'SLURM_NTASKS'
_PROCESS_ID = 'SLURM_PROCID'
_LOCAL_PROCESS_ID = 'SLURM_LOCALID'
# NOTE(review): _NUM_NODES is defined but not read in this chunk — confirm
# whether other code relies on it before removing.
_NUM_NODES = 'SLURM_STEP_NUM_NODES'


class SlurmCluster(clusters.ClusterEnv):
  """Cluster environment for jobs launched under the Slurm scheduler."""

  name: str = "slurm"

  @classmethod
  def is_env_present(cls) -> bool:
    """True when all required Slurm step variables are in the environment."""
    required = (_JOBID_PARAM, _NODE_LIST, _PROCESS_COUNT,
                _PROCESS_ID, _LOCAL_PROCESS_ID)
    return all(var in os.environ for var in required)

  @classmethod
  def get_coordinator_address(cls, timeout_secs: int | None, override_coordinator_port: str | None) -> str:
    """Return ``host:port`` of the first node in the step's node list."""
    if override_coordinator_port:
      port = override_coordinator_port
    else:
      # Fold the job id into the ephemeral port range
      # [(65535 - 2^12 + 1), 65535] to reduce conflicts between jobs.
      port = str(int(os.environ[_JOBID_PARAM]) % 2**12 + (65535 - 2**12 + 1))

    # Parse the first hostname of the job. If we are looking for 'node001',
    # node_list potential formats are 'node001', 'node001,host2',
    # 'node[001-0015],host2', and 'node[001,007-015],host2'.
    node_list = os.environ[_NODE_LIST]
    bracket = node_list.find('[')
    comma = node_list.find(',')
    if bracket != -1 and (comma == -1 or bracket < comma):
      # Bracketed range: keep the prefix plus the first index inside the
      # brackets, e.g. 'node[001,007-015],host2' -> 'node001'.
      prefix = node_list[:bracket]
      inner = node_list[bracket + 1:]
      stop = len(inner)
      for pos, ch in enumerate(inner):
        if ch in ',-':
          stop = pos
          break
      return f'{prefix}{inner[:stop]}:{port}'
    # Plain hostname: everything before the first comma, if any.
    head = node_list if comma == -1 else node_list[:comma]
    return f'{head}:{port}'

  @classmethod
  def get_process_count(cls) -> int:
    """Total number of tasks in the Slurm step."""
    return int(os.environ[_PROCESS_COUNT])

  @classmethod
  def get_process_id(cls) -> int:
    """Global rank of this task within the step."""
    return int(os.environ[_PROCESS_ID])

  @classmethod
  def get_local_process_id(cls) -> int | None:
    """Rank of this task among the tasks on the same node."""
    return int(os.environ[_LOCAL_PROCESS_ID])
|
||||
Reference in New Issue
Block a user