hand
This commit is contained in:
@@ -0,0 +1,374 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import enum
|
||||
import hashlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import cast as type_cast
|
||||
|
||||
from jax._src import config
|
||||
from jax._src.lib import jaxlib_extension_version
|
||||
from jax._src.lib import version_str as jaxlib_version_str
|
||||
from jax._src.lib import _jax
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir import passmanager as pm
|
||||
import numpy as np
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Registry of extra command-line flag prefixes (beyond the built-in "--xla")
# whose matching argv entries are folded into the cache key.
_extra_flag_prefixes: list[str] = []


def add_flag_prefixes(flag_prefixes: list[str]) -> None:
  """Add flag prefixes to include in the cache key. Call prior to get()."""
  # Extend in place so any existing reference to the list observes the update.
  _extra_flag_prefixes.extend(flag_prefixes)


def clear_flag_prefixes() -> None:
  """Clear flag prefixes added by add_flag_prefixes()."""
  global _extra_flag_prefixes
  _extra_flag_prefixes = []


def get_flag_prefixes() -> list[str]:
  """Return flag prefixes added by add_flag_prefixes()."""
  return _extra_flag_prefixes
|
||||
|
||||
|
||||
def custom_hook() -> str:
  """Custom hook for any addition to the cache key.

  Called on every get() invocation; the returned string is hashed into the
  cache key. The default contributes nothing — deployments may replace this
  module-level function to mix extra state into the key.
  """
  return ""
|
||||
|
||||
|
||||
class IgnoreCallbacks(enum.IntEnum):
  """Policy for stripping Python callback pointers from precompiled IR."""

  # Keep every callback pointer in the IR untouched.
  NO = 1
  # Strip every callback pointer from the IR.
  ALL = 2
  # Strip only the custom_partitioning callback pointer from the IR.
  CUSTOM_PARTITIONING = 3
|
||||
|
||||
|
||||
def get(
    module: ir.Module,
    devices: np.ndarray,
    compile_options: xla_client.CompileOptions,
    backend: xla_client.Client,
    compression_algorithm: str = "zstandard",
    ignore_callbacks: IgnoreCallbacks = IgnoreCallbacks.NO,
) -> str:
  """Creates a hashed string to use as a key to the compilation cache.

  Creates a cache key that is a hex-encoded string of a unique hash based on
  the arguments. The hex-encoded digest is 64 characters long (SHA-256) and
  is prefixed with the module's symbol name.

  Args:
    module: the input program
    devices: an array of accelerator devices that the program will run on
    compile_options: options passed to the XLA compiler
    backend: description of the platform (e.g., TPU version)
    compression_algorithm: a string representing the compression algorithm used
      for the executable before persisting in the cache
    ignore_callbacks: whether to remove the all callback pointer from the
      computation.

  Typical return value example:
   'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
  """
  # Each entry is (name, fn); fn folds one component into the running hash.
  # The name is used only for debug logging.  The order of entries matters:
  # changing it changes every cache key.
  entries = [
      (
          "computation",
          lambda hash_obj: _hash_computation(
              hash_obj, module, ignore_callbacks
          ),
      ),
      (
          "jax_lib version",
          lambda hash_obj: hash_obj.update(
              bytes(jaxlib_version_str.encode("utf-8"))
          ),
      ),
      (
          "backend version",
          lambda hash_obj: _hash_platform(hash_obj, backend)
      ),
      (
          "XLA flags",
          lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes()),
      ),
      (
          "compile_options",
          lambda hash_obj: _hash_serialized_compile_options(
              hash_obj,
              compile_options,
              # In case of GPU multi-process tasks we need to strip device
              # assignment to use cache key as invariant between processes.
              strip_device_assignment=(backend.platform == "gpu"),
          ),
      ),
      (
          "accelerator_config",
          lambda hash_obj: _hash_accelerator_config(hash_obj, devices),
      ),
      (
          "compression",
          lambda hash_obj: _hash_string(hash_obj, compression_algorithm),
      ),
      ("custom_hook", lambda hash_obj: _hash_string(hash_obj, custom_hook())),
  ]

  hash_obj = hashlib.sha256()
  for name, hashfn in entries:
    hashfn(hash_obj)
    _log_cache_key_hash(hash_obj, name, hashfn)
  # Prefix the digest with the module's symbol name for human readability of
  # cache entries (e.g. 'jit__psum-<digest>').
  sym_name = module.operation.attributes['sym_name']
  module_name = ir.StringAttr(sym_name).value
  return module_name + "-" + hash_obj.digest().hex()
|
||||
|
||||
|
||||
def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
  """Debug-log the standalone and cumulative hashes for one key entry.

  Args:
    hash_obj: the running (cumulative) hash object.
    last_serialized: name of the entry that was just hashed.
    hashfn: the entry's hash function, re-run on a fresh hasher to obtain the
      hash of this entry alone.
  """
  if not logger.isEnabledFor(logging.DEBUG):
    return
  # Hash of just this entry, computed on a fresh hasher.
  standalone_hash_obj = hashlib.sha256()
  hashfn(standalone_hash_obj)
  logger.debug(
      "get_cache_key hash of serialized %s: %s",
      last_serialized,
      standalone_hash_obj.digest().hex(),
  )
  # Cumulative hash of everything serialized so far.
  logger.debug(
      "get_cache_key hash after serializing %s: %s",
      last_serialized,
      hash_obj.digest().hex(),
  )
|
||||
|
||||
|
||||
def _remove_callbacks(m: ir.Module, ignore_callbacks: IgnoreCallbacks):
  """Removes callback pointers from precompiled IR.

  Python function pointers are not deterministic across executions, so they
  are replaced by a constant placeholder before the module is hashed.
  """
  if ignore_callbacks == IgnoreCallbacks.NO:
    return m

  def _scrub_backend_config(op: ir.Operation) -> ir.WalkResult:
    # Only ops carrying a call_target_name attribute are candidates.
    if "call_target_name" not in op.attributes:
      return ir.WalkResult.ADVANCE
    target = op.attributes["call_target_name"]
    assert isinstance(target, ir.StringAttr)
    strip_as_callback = (
        ignore_callbacks == IgnoreCallbacks.ALL
        and target.value.endswith("callback")
    )
    is_custom_partitioning = target.value == "CustomSPMDPartitioning"
    if op.name == "stablehlo.custom_call" and (
        strip_as_callback or is_custom_partitioning
    ):
      # Replace the nondeterministic pointer with a fixed marker.
      op.attributes["backend_config"] = ir.StringAttr.get("REMOVED")
    return ir.WalkResult.ADVANCE

  m.operation.walk(_scrub_backend_config)
  return m
|
||||
|
||||
|
||||
def _serialize_ir(m: ir.Module, ignore_callbacks: IgnoreCallbacks) -> bytes:
  """Serialize the module to MLIR bytecode, optionally scrubbing callbacks.

  When callbacks must be removed, the scrubbing is done on a clone so the
  caller's module is never mutated.
  """
  if ignore_callbacks != IgnoreCallbacks.NO:
    cloned = type_cast(ir.Module, m.operation.clone())
    m = _remove_callbacks(cloned, ignore_callbacks)
  buf = io.BytesIO()
  m.operation.write_bytecode(file=buf)
  return buf.getvalue()
|
||||
|
||||
|
||||
def _canonicalize_ir(
    m_original: ir.Module, ignore_callbacks: IgnoreCallbacks
) -> bytes:
  """Strip debug info from a clone of the module, then serialize it.

  Debug/location metadata can vary across runs, so it is removed to keep the
  cache key dependent only on the program itself.
  """
  with m_original.context:
    cloned = type_cast(ir.Module, m_original.operation.clone())
    pm.PassManager.parse("builtin.module(strip-debuginfo)").run(
        cloned.operation
    )
    return _serialize_ir(cloned, ignore_callbacks)
|
||||
|
||||
|
||||
def _hash_computation(hash_obj, module, ignore_callbacks: IgnoreCallbacks):
  """Fold the serialized (and possibly canonicalized) module into hash_obj."""
  # If metadata is allowed in the key, serialize as-is; otherwise strip debug
  # info first via canonicalization.
  include_metadata = config.compilation_cache_include_metadata_in_key.value
  serialize = _serialize_ir if include_metadata else _canonicalize_ir
  hash_obj.update(serialize(module, ignore_callbacks))
|
||||
|
||||
|
||||
def _hash_devices(hash_obj, devices: np.ndarray) -> None:
  """Hash the device kind of every device in the array into hash_obj."""
  for kind in (device.device_kind for device in devices.flat):
    _hash_string(hash_obj, kind)
|
||||
|
||||
|
||||
def _hash_accelerator_config(hash_obj, accelerators: np.ndarray):
  """Hash the accelerator topology, falling back to per-device kinds.

  Args:
    hash_obj: hash object to update.
    accelerators: array of accelerator devices the program will run on.
  """
  # Idiom: build the device list directly rather than via an append loop.
  accelerator_devices = list(accelerators.flat)
  try:
    topology = xla_client.get_topology_for_devices(accelerator_devices)
    hash_obj.update(
        topology.fingerprint().to_bytes(8, byteorder="big")
        if jaxlib_extension_version >= 423
        else topology.serialize()  # pyrefly: ignore[not-callable]
    )
  except _jax.JaxRuntimeError as ex:
    # Fall back for those backends that do not support serialized
    # PjRtTopologyDescription as yet.
    logger.info("get (_hash_accelerator_config): unable to hash "
                "accelerator config, falling back to hashing "
                "devices %s (type %s)", ex, type(ex))
    _hash_devices(hash_obj, accelerators)
|
||||
|
||||
# LINT.IfChange(xla_flags)
# XLA flags that only affect debugging/dump output or process-local setup and
# therefore must not influence the compilation cache key.  Kept in sync (via
# the LINT markers) with the debug_options fields cleared in
# _hash_serialized_compile_options below.
xla_flags_to_exclude_from_cache_key = [
    "--xla_dump_compress_protos",
    "--xla_dump_module_metadata",
    "--xla_dump_max_hlo_modules",
    "--xla_dump_include_timestamp",
    "--xla_dump_hlo_pass_re",
    "--xla_dump_hlo_module_re",
    "--xla_dump_hlo_snapshots",
    "--xla_dump_fusion_visualization",
    "--xla_dump_hlo_as_url",
    "--xla_dump_hlo_as_proto",
    "--xla_dump_hlo_as_text",
    "--xla_dump_hlo_as_long_text",
    "--xla_dump_hlo_as_html",
    "--xla_dump_hlo_as_dot",
    "--xla_dump_to",
    "--xla_force_host_platform_device_count",
    "--xla_dump_disable_metadata",
    "--xla_dump_hlo_pipeline_re",
    "--xla_tpu_sdc_checker_streamz_metric",
    "--xla_tpu_sdc_checker_enable_sdc_event_callbacks",
    "--xla_tpu_sdc_checker_enable_coresweep_ng_callbacks",
    "--xla_tpu_sdc_checker_no_logging_if_callbacks_are_present",
    "--xla_gpu_cuda_data_dir",
    "--xla_gpu_experimental_autotune_cache_mode",
]

# The same exclusions keyed by bare flag name (leading dashes stripped), used
# to filter CompileOptions env_option_overrides.
env_override_flags_to_exclude_from_cache_key = {
    x.strip("-") for x in xla_flags_to_exclude_from_cache_key
}
# LINT.ThenChange(:debug_options)
|
||||
|
||||
def _hash_serialized_compile_options(hash_obj, compile_options_obj,
                                     strip_device_assignment=False):
  """Hash a canonicalized serialization of CompileOptions into hash_obj.

  Args:
    hash_obj: hash object to update.
    compile_options_obj: the CompileOptions to hash; never mutated.
    strip_device_assignment: if True, replace the device assignment with a
      trivial iota assignment of the same shape so that the key is invariant
      across processes (used for GPU multi-process tasks).
  """
  # Do not mess with the original CompileOptions object since it is passed to
  # the compiler. Create a deep copy for the purpose of cache key generation.
  compile_options_copy = copy.deepcopy(compile_options_obj)

  # Certain debug options do not affect the compile result and thus, should not
  # be part of the cache key as their inclusion will result in unnecessary cache
  # misses. Clear them here by setting bool values to False, ints to 0, and
  # strings to empty. The exact values used to clear are not relevant as long
  # as the same values are used every time for each field.
  debug_options = compile_options_copy.executable_build_options.debug_options
  # LINT.IfChange(debug_options)
  debug_options.xla_force_host_platform_device_count = 0
  debug_options.xla_dump_to = ""
  debug_options.xla_dump_hlo_module_re = ""
  debug_options.xla_dump_hlo_pass_re = ""
  debug_options.xla_dump_hlo_as_text = False
  debug_options.xla_dump_hlo_as_proto = False
  debug_options.xla_dump_hlo_as_dot = False
  debug_options.xla_dump_hlo_as_url = False
  debug_options.xla_dump_hlo_as_html = False
  debug_options.xla_dump_fusion_visualization = False
  debug_options.xla_dump_hlo_snapshots = False
  # NOTE(review): the two fields below look numeric but are cleared with
  # False (== 0); only constancy matters here — confirm against the proto.
  debug_options.xla_dump_max_hlo_modules = False
  debug_options.xla_dump_module_metadata = False
  debug_options.xla_dump_compress_protos = False
  debug_options.xla_dump_hlo_as_long_text = False
  debug_options.xla_dump_disable_metadata = False
  debug_options.xla_dump_hlo_pipeline_re = ""
  debug_options.xla_gpu_experimental_autotune_cache_mode = 0

  # Optional way to specify the cuda install path to be used by the compiler.
  # This could possibly affect the cuda version compiled with, but this should
  # already be included in the platform information (and might not be reflected
  # by the cuda path regardless, since this only hashes on the directory name
  # and not the contents). It can also cause spurious cache misses if the cuda
  # path changes across runs despite being the same version, so we clear it
  # here.
  debug_options.xla_gpu_cuda_data_dir = ""
  # LINT.ThenChange(:xla_flags)

  # Drop env-override flags that are excluded from the cache key (same set as
  # the excluded XLA flags, without leading dashes).
  compile_options_copy.env_option_overrides = [
      flag_value
      for flag_value in compile_options_copy.env_option_overrides
      if flag_value[0] not in env_override_flags_to_exclude_from_cache_key
  ]
  if strip_device_assignment and compile_options_copy.device_assignment:
    # Replace the concrete device ids with a 0..N-1 iota of identical shape so
    # different processes compute identical keys.
    replica_count = compile_options_copy.device_assignment.replica_count()
    computation_count = compile_options_copy.device_assignment.computation_count()
    compile_options_copy.device_assignment = xla_client.DeviceAssignment.create(
        np.arange(replica_count * computation_count).reshape(  # pyrefly: ignore[bad-argument-type]
            [replica_count, computation_count])
    )
  return hash_obj.update(compile_options_copy.SerializeAsString())
|
||||
|
||||
|
||||
def _hash_platform(hash_obj, backend):
  """Mix the backend's platform name and platform version into hash_obj."""
  for piece in (backend.platform, backend.platform_version):
    _hash_string(hash_obj, piece)
|
||||
|
||||
|
||||
def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]):
  """Hash the sorted, filtered collection of XLA-relevant flags.

  Flags come from the XLA_FLAGS and LIBTPU_INIT_ARGS environment variables,
  plus argv entries starting with "--xla" or a registered extra prefix.
  Flags known not to affect compilation are skipped.
  """
  collected: list[str] = []
  for env_name in ("XLA_FLAGS", "LIBTPU_INIT_ARGS"):
    env_value = os.getenv(env_name)
    if env_value:
      collected.extend(env_value.split())

  # str.startswith accepts a tuple of prefixes, checked as a single OR.
  prefixes = ("--xla", *extra_flag_prefixes)
  collected.extend(arg for arg in sys.argv if arg.startswith(prefixes))

  # N.B. all XLA flags that take an argument must use '=' and not a space
  # (e.g. --xla_force_host_platform_device_count=8) (I think).
  for flag in sorted(collected):
    if flag.split("=")[0] in xla_flags_to_exclude_from_cache_key:
      logger.debug("Not including XLA flag in cache key: %s", flag)
    else:
      logger.debug("Including XLA flag in cache key: %s", flag)
      _hash_string(hash_obj, flag)
|
||||
|
||||
|
||||
def _hash_string(hash_obj, str_var):
|
||||
hash_obj.update(str_var.encode("utf-8").strip())
|
||||
Reference in New Issue
Block a user