This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,23 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax.experimental.jax2tf.jax2tf import (
convert as convert,
eval_polymorphic_shape as eval_polymorphic_shape,
dtype_of_val as dtype_of_val,
split_to_logical_devices as split_to_logical_devices,
DisabledSafetyCheck as DisabledSafetyCheck,
PolyShape as PolyShape # TODO: deprecate
)
from jax.experimental.jax2tf.call_tf import call_tf as call_tf
@@ -0,0 +1,727 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Allows JAX to call TensorFlow functions with support for autodiff.
**Experimental: please give feedback, and expect changes.**
This module introduces the function :func:`call_tf` that allows JAX to call
TensorFlow functions.
For examples and details, see
https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax.
"""
from __future__ import annotations
from collections.abc import Callable, Sequence
import dataclasses
import functools
from typing import cast, Any
from absl import logging
import jax
from jax import dlpack
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import ad_util
from jax._src import core
from jax._src import effects
from jax._src import literals
from jax._src import util
from jax._src.lib import _jax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import hlo
from jax.experimental import roofline
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
from jax._src.interpreters import mlir
import ml_dtypes
import numpy as np
import tensorflow as tf
# Use the length-checking variants so that mismatched map/zip arguments
# raise instead of being silently truncated.
map = util.safe_map
zip = util.safe_zip

TfConcreteFunction = Any

# A value suitable in a TF tracing context (see jax2tf_internal.TfVal).
TfVal = jax2tf_internal.TfVal

# The platforms for which to use DLPack to avoid copying (only works on GPU
# and CPU at the moment, and only for Array). For CPU we don't need
# DLPack, if we are careful.
_DLPACK_PLATFORMS = ("gpu",)


class UnspecifiedOutputShapeDtype:
  # Sentinel default for `call_tf(..., output_shape_dtype=...)`, used to
  # distinguish "argument not given" from an explicit value.
  pass
def call_tf(
    callable_tf: Callable,
    has_side_effects=True,
    ordered=False,
    output_shape_dtype=UnspecifiedOutputShapeDtype(),
    call_tf_graph=False,
) -> Callable:
  """Calls a TensorFlow function from JAX, with support for reverse autodiff.

  The ``callable_tf`` will be called with TensorFlow-compatible arguments (
  numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``) or pytrees thereof. The
  function must return the same type of results.

  If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
  or :func:`jax.pmap`, or a control-flow primitive) then
  ``callable_tf`` will be compiled with ``tf.function(callable_tf,
  jit_compile=True)``
  and the resulting XLA computation will be embedded in JAX's XLA computation.

  If ``call_tf`` appears outside a JAX staging context, it will be called inline
  using TensorFlow eager mode.

  The ``call_tf`` supports JAX's reverse-mode autodiff, in which case the
  ``callable_tf`` will be differentiated using ``tf.GradientTape``. This means
  that the gradient will be TensorFlow-accurate, e.g., will respect the
  custom gradients that may be defined for the code in ``callable_tf``.

  For an example and more details see the
  `README
  <https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.

  Args:
    callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
      arguments.
    has_side_effects: if True then it ensures that instances of this primitive
      are not removed or replicated by JAX optimizations such as dead-code
      elimination.
    ordered: If true, calls are modeled as having ordered effects.
    output_shape_dtype: An optional declaration of the expected shape and dtype
      of the result of the called TensorFlow function. If given it will be used
      during JAX tracing to form the abstract values of the results of the
      `call_tf`. If not given then we form a `tf.Graph` for the called
      TensorFlow function and we use the TensorFlow-inferred shapes and types.
      Must be a pytree matching the structure of the nested structure returned
      from the TensorFlow function, containing objects with `.shape` and
      `.dtype` attributes, e.g., `jax.ShapeDtypeStruct` or `jax.Array`.
    call_tf_graph: EXPERIMENTAL, DO NOT USE. We may change the name in the
      future.

  Returns: a JAX callable that can be invoked with JAX pytree arguments, in
    op-by-op mode or in a staged context. This callable can be used with JAX's
    reverse-mode autodiff (:func:`jax.grad`).
  """
  @jax.custom_vjp
  def make_call(*args_jax):
    """We wrap it all in `make_call` so that we can attach custom VJP."""
    args_flat_jax, args_treedef = tree_util.tree_flatten(args_jax)
    # Canonicalize the arguments; e.g., makes them x32 if JAX is in 32-bit mode
    def canonical_arg(v):
      # Wrap plain Python values into ndarrays so they carry a dtype.
      v = v if getattr(v, "dtype", None) else np.asarray(v)
      dtype = dtypes.canonicalize_dtype(v.dtype)
      if dtype != v.dtype:
        v = v.astype(dtype)
      return v

    args_flat_jax = tuple(map(canonical_arg, args_flat_jax))

    def make_tensorspec(a_jax):
      a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
      # Non-constant (shape-polymorphic) dimensions become None for TF.
      a_tf_shape = [d if core.is_constant_dim(d) else None for d in getattr(a_jax, "shape", ())]
      return tf.TensorSpec(a_tf_shape, a_tf_dtype)
    args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))

    if not isinstance(output_shape_dtype, UnspecifiedOutputShapeDtype):
      output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype)
      output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat)
    else:
      output_avals, output_shape_dtype_tree = None, None

    res_treedef = None  # We'll store here the result treedef
    res_tf_flat = None  # For error reporting
    # The function below will be called at least once, either in eager
    # mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf()
    def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
      args_tf = args_treedef.unflatten(args_tf_flat)
      res_tf = callable_tf(*args_tf)
      # b/279454591: When `callable_tf` is a tf function with zero outputs, it
      # returns a `StatefulPartitionedCall` (if the function is stateful) or
      # `PartitionedCall` (if the function is stateless) op instead of
      # tf.Tensors. We work around this issue by replacing the output `res_tf`
      # with an empty list.
      if isinstance(res_tf, tf.Operation):
        assert (
            res_tf.type == "StatefulPartitionedCall"
            or res_tf.type == "PartitionedCall"
        )
        t_out = res_tf.get_attr("Tout")
        # t_out should be an empty list.
        assert not t_out, (
            "The TF function returned an unexpected result, please check its"
            f" function body. res_tf = {res_tf}"
        )
        res_tf = t_out
      nonlocal res_treedef, res_tf_flat
      res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
      # The result structure must be stable across calls (eager and traced).
      assert res_treedef is None or res_treedef == res_treedef_now, (
          f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}")
      res_treedef = res_treedef_now
      if output_avals is not None:
        if res_treedef != output_shape_dtype_tree:
          raise ValueError(
              "The pytree of the TensorFlow function results does not match the "
              "pytree of the declared output_shape_dtype:\n"
              f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}")
        assert len(output_avals) == len(res_tf_flat)

      checked_res_tf_flat = [
          check_tf_result(i, r_tf, r_aval)
          for i, (r_tf, r_aval) in enumerate(
              zip(res_tf_flat,
                  (output_avals
                   if output_avals is not None
                   else (None,) * len(res_tf_flat))))]
      return checked_res_tf_flat

    # Prepare a tf.function ahead of time, to cache the concrete functions. This
    # won't be used in op-by-op execution mode.
    function_flat_tf = tf.function(
        callable_flat_tf, autograph=False, jit_compile=not call_tf_graph)

    res_jax_flat = call_tf_p.bind(
        *args_flat_jax,
        # Carry the actual function such that op-by-op call can call in TF eager mode.
        callable_flat_tf=callable_flat_tf,
        function_flat_tf=function_flat_tf,
        args_flat_sig_tf=args_flat_sig_tf,
        output_avals=output_avals,
        has_side_effects=has_side_effects,
        ordered=ordered,
        call_tf_graph=call_tf_graph,
    )

    # We must have called callable_flat_tf by now
    assert res_treedef is not None
    return res_treedef.unflatten(res_jax_flat)

  # Define the fwd and bwd custom_vjp functions
  def make_call_vjp_fwd(*args_jax):
    # Return the primal arguments as the residual
    return make_call(*args_jax), args_jax

  def make_call_vjp_bwd(residual_jax, ct_res_jax):
    args_jax = residual_jax  # residual is the primal argument

    def tf_vjp_fun(args_tf, ct_res_tf):
      """Invoke TF gradient."""

      # TF does not like us to watch non-float vars or Nones.
      def replace_non_float_or_none(arg_tf):
        if arg_tf is not None and (
            arg_tf.dtype.is_floating or arg_tf.dtype.is_complex
        ):
          return arg_tf
        else:
          # When watched, this will be ignored. When used in results it will
          # result in a floating 0. gradient, which JAX will ignore (and
          # replace it with a float0)
          return tf.zeros((), dtype=tf.float32)

      watched_args_tf = tf.nest.map_structure(
          replace_non_float_or_none, args_tf
      )
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(watched_args_tf)
        res = callable_tf(*args_tf)

      tf.nest.assert_same_structure(res, ct_res_tf)
      dres_darg = tape.gradient(
          tf.nest.map_structure(replace_non_float_or_none, res),
          sources=watched_args_tf,
          output_gradients=ct_res_tf,
          unconnected_gradients=tf.UnconnectedGradients.ZERO,
      )
      dres_darg = tree_util.tree_map(
          lambda x: x if x is None else tf.convert_to_tensor(x),
          dres_darg,
      )
      # callable_tf may mutate (the structure of) args_tf, thus we check against
      # watched_args_tf which should be structurally the same as the original
      # args_tf.
      tf.nest.assert_same_structure(dres_darg, watched_args_tf)
      return dres_darg

    # Use call_tf to call the VJP function
    ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)
    # We must make the float0s that JAX expects
    def fix_float0(arg_jax, ct_arg_jax):
      if arg_jax is None:
        return None
      arg_dtype = dtypes.result_type(arg_jax)  # May be scalar
      ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
      if ct_arg_dtype != ct_arg_jax.dtype:
        return ad_util.zeros_like_aval(core.ShapedArray(np.shape(arg_jax),
                                                        ct_arg_dtype))
      return ct_arg_jax

    ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax,
                                           is_leaf=lambda x: x is None)
    return ct_args_jax_fixed

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return util.wraps(callable_tf)(make_call)
def check_tf_result(idx: int, r_tf: TfVal, r_aval: core.ShapedArray | None) -> TfVal:
  """Checks one TF result against the optionally declared abstract value.

  Args:
    idx: position of the result in the flat result list (for error messages).
    r_tf: the value returned by the TF function for this position.
    r_aval: the declared abstract value from `output_shape_dtype`, or None
      if no declaration was given.

  Returns:
    `r_tf`, possibly with its shape refined via `tf.ensure_shape`.

  Raises:
    ValueError: if `r_tf` is not convertible to JAX, or its dtype/shape does
      not match the declared `r_aval`.
  """
  # Check that the TF function returns values of expected types. This
  # improves error reporting, preventing hard-to-diagnose errors downstream
  try:
    jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
  except Exception as e:
    msg = ("The called TF function returns a result that is not "
           f"convertible to JAX: {r_tf}.")
    raise ValueError(msg) from e

  if r_aval is None:
    return r_tf
  # We convert to TF type, and canonicalize to 32-bit if necessary
  r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype)
  # Checking shapes is trickier in presence of dynamic shapes. I wish we could
  # check at runtime that the returned shape matches the declared shape. I wish
  # that tf.ensure_shape did this, but it can only take shapes that contain None
  # not computed shapes. However, in eager mode we should be able to resolve
  # the declared shapes to constants and we get better checking.
  r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
  # We do as much checking as we can here, instead of relying on tf.ensure_shape
  # because the latter gives different errors in eager vs. compiled mode.
  # TODO(b/279454591): This strange error is from TF. Eager function suppose
  # return tf Val with concrete shape but not. Here we change exception to warn
  # and bypass it. This case need revisit on TF side.
  try:
    _ = len(r_tf.shape)
  except ValueError as e:
    # BUG FIX: the message fragments below were concatenated without
    # separating spaces, producing "the`r_tf`" and "obtained.r_tf =".
    msg = (
        "The shape check test cannot be performed because the shape of the "
        "`r_tf` tensor cannot be obtained. "
        f"r_tf = {r_tf}, r_aval = {r_aval}. "
    )
    msg += str(e)
    logging.warning(msg)
    return r_tf
  if (r_tf.dtype != r_aval_dtype_tf or
      len(r_tf.shape) != len(r_aval_shape_tf) or
      any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d
          for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))):
    msg = ("The shapes or dtypes returned by the TensorFlow function "
           "do not match the declared output_shape_dtype:\n"
           f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]")
    raise ValueError(msg)
  # At this point tf.ensure_shape does not do much, it should never throw an
  # error, albeit it may refine the shape a bit.
  return tf.ensure_shape(r_tf, r_aval_shape_tf)
# The JAX primitive that carries a call_tf invocation through tracing; it is
# bound in `make_call` above and given impl/abstract-eval/lowering rules below.
call_tf_p = core.Primitive("call_tf")
call_tf_p.multiple_results = True
# The impl will be used in op-by-op mode and calls callable_tf in TF eager mode.
def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
  """Eager (op-by-op) implementation of the call_tf primitive.

  Converts each flat JAX argument to a TF value, calls `callable_flat_tf`
  in TF eager mode, and converts the flat TF results back to JAX values.
  """
  # On GPU we use dlpack to avoid copies of data to the host.
  def _arg_jax_to_tf(arg_jax):
    if (isinstance(arg_jax, jax.Array) and
        list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
        dlpack.is_supported_dtype(arg_jax.dtype)):
      return tf.experimental.dlpack.from_dlpack(arg_jax.__dlpack__())
    # The following avoids copies to the host on CPU, always for Array
    # and even for ndarray if they are sufficiently aligned.
    # TODO(necula): on TPU this copies to the host!
    if getattr(arg_jax, 'dtype', None) == dtypes.float0:
      # float0 has no TF counterpart; use a zero-filled stand-in dtype.
      return tf.zeros(shape=arg_jax.shape,
                      dtype=jax2tf_internal._tf_np_dtype_for_float0)
    if isinstance(arg_jax, tuple(literals.typed_scalar_types)):
      # Make sure to preserve the JAX dtype for TypedInt, etc.
      return tf.constant(np.asarray(arg_jax, dtype=arg_jax.dtype))
    return tf.constant(np.asarray(arg_jax))

  args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
  with jax2tf_internal.inside_call_tf():
    # Call in TF eager mode
    res_tf_flat = callable_flat_tf(*args_tf_flat)

  def _res_tf_to_jax(res_tf: TfVal):
    res_tf, jax_dtype = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
    if isinstance(res_tf, tf.Tensor) and dlpack.is_supported_dtype(jax_dtype):
      # Zero-copy transfer back via DLPack when the result lives on a
      # supported device.
      res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
      res_jax_platform = res_tf_platform.lower()
      if res_jax_platform in _DLPACK_PLATFORMS:
        return jax.dlpack.from_dlpack(res_tf)

    # When working with a bfloat16 scalar tf.Tensor, np.asarray() can fail.
    # To handle this special case, we create a numpy copy.
    if res_tf.shape == tf.TensorShape([]) and res_tf.dtype == tf.bfloat16:
      return jax.device_put(jnp.array(res_tf.numpy()))
    else:
      return jax.device_put(np.asarray(res_tf))

  return list(map(_res_tf_to_jax, res_tf_flat))


call_tf_p.def_impl(_call_tf_impl)
@functools.lru_cache(maxsize=128)
def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf):  # -> tf.ConcreteFunction
  # Traces (and caches) the concrete function for the given argument
  # signature. inside_call_tf() disables the safety check that we are not
  # inside JAX transformations while TF traces the function.
  with jax2tf_internal.inside_call_tf():
    return function_flat_tf.get_concrete_function(*args_flat_sig_tf)
# Mark the effectful instances of call_tf
@dataclasses.dataclass(frozen=True)
class CallTfEffect(effects.Effect):
  """Effect attached to call_tf calls with `has_side_effects=True`.

  Marking the primitive with this effect ensures its instances are not
  removed or replicated by JAX optimizations such as dead-code elimination.
  """

  # A proper method instead of an assigned lambda (PEP 8 E731).
  def __str__(self):
    return "CallTfEffect"


call_tf_effect = CallTfEffect()

effects.lowerable_effects.add_type(CallTfEffect)
effects.control_flow_allowed_effects.add_type(CallTfEffect)
effects.remat_allowed_effects.add_type(CallTfEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)
class CallTfOrderedEffect(effects.Effect):
  """Effect attached to call_tf calls with `ordered=True`.

  Registered as an ordered (and shardable-ordered) effect so that such
  calls keep their relative execution order.
  """

  # A proper method instead of an assigned lambda (PEP 8 E731).
  def __str__(self):
    return "CallTfOrderedEffect"


call_tf_ordered_effect = CallTfOrderedEffect()

effects.lowerable_effects.add_type(CallTfOrderedEffect)
effects.control_flow_allowed_effects.add_type(CallTfOrderedEffect)
effects.remat_allowed_effects.add_type(CallTfOrderedEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfOrderedEffect)
effects.ordered_effects.add_type(CallTfOrderedEffect)
effects.shardable_ordered_effects.add_type(CallTfOrderedEffect)
def _call_tf_abstract_eval(
    *args_flat_avals,
    function_flat_tf,
    args_flat_sig_tf,
    has_side_effects,
    ordered,
    output_avals,
    call_tf_graph,
    **__,
):
  """Abstract evaluation rule for call_tf (used when forming a Jaxpr).

  Returns the flat output avals together with the effects of this call.
  """
  eff_set: set[effects.Effect] = set()
  if ordered:
    eff_set.add(call_tf_ordered_effect)
  elif has_side_effects:
    eff_set.add(call_tf_effect)

  # Ask TF to trace the function even when output_avals was declared: this
  # guarantees callable_flat_tf runs at least once. _get_concrete_function_tf
  # is cached, so the extra calls are cheap.
  concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
                                                        args_flat_sig_tf)

  # A tf.function with no return values.
  if not concrete_function_flat_tf.outputs:
    return (), eff_set

  if output_avals is not None:
    return output_avals, eff_set

  # No declaration given: use the TF-inferred shapes, which must be static.
  def _fully_known(s) -> bool:
    return s.rank is not None and None not in list(s)

  output_shapes = concrete_function_flat_tf.output_shapes
  if all(_fully_known(s) for s in output_shapes):
    # Convert to JAX types, canonicalizing to 32-bit if necessary.
    avals_from_tf = tuple(
        core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
        for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
                                output_shapes))
    return avals_from_tf, eff_set

  raise ValueError(
      "call_tf cannot call functions whose output has dynamic shape. "
      f"Found output shapes: {concrete_function_flat_tf.output_shapes}. "
      "Consider using the `output_shape_dtype` argument to call_tf. "
      "\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
      " for a discussion.")


call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)
def _mlir_type_to_numpy_dtype(type: ir.Type) -> np.dtype:
"""Converts an MLIR scalar type to a NumPy dtype."""
if isinstance(type, ir.IntegerType):
type = ir.IntegerType(type)
width = type.width
if width == 1:
return np.dtype(np.bool_)
elif width == 8:
return np.dtype(np.uint8 if type.is_unsigned else np.int8)
elif width == 16:
return np.dtype(np.uint16 if type.is_unsigned else np.int16)
elif width == 32:
return np.dtype(np.uint32 if type.is_unsigned else np.int32)
elif width == 64:
return np.dtype(np.uint64 if type.is_unsigned else np.int64)
else:
raise ValueError(f"Unsupported integer width: {width}")
elif isinstance(type, ir.F16Type):
return np.dtype(np.float16)
elif isinstance(type, ir.F32Type):
return np.dtype(np.float32)
elif isinstance(type, ir.F64Type):
return np.dtype(np.float64)
elif isinstance(type, ir.BF16Type):
return np.dtype(ml_dtypes.bfloat16)
elif isinstance(type, ir.ComplexType):
element_type = ir.ComplexType(type).element_type
if isinstance(element_type, ir.F32Type):
return np.dtype(np.complex64)
elif isinstance(element_type, ir.F64Type):
return np.dtype(np.complex128)
else:
raise ValueError(f"Unsupported complex element type: {element_type}")
else:
raise TypeError(f"Unsupported MLIR type for NumPy conversion: {type}")
def _call_tf_lowering(
    ctx: mlir.LoweringRuleContext,
    *args_op,
    platform,
    function_flat_tf,
    args_flat_sig_tf,
    has_side_effects,
    ordered,
    call_tf_graph,
    output_avals,
    **_,
):
  """Lowering rule for call_tf: embeds the compiled TF function in StableHLO.

  Either emits a custom call referencing an embedded tf.Graph
  (`call_tf_graph=True`) or compiles the TF function to HLO, converts it to
  StableHLO and merges it into the current module.
  """
  # We use the same TF lowering device as for the embedding JAX computation.
  # One example when this is needed is when the code refers to variables on one
  # device. Or, for sharding annotations (only supported on TPU).
  if platform in ["cpu", "tpu"]:
    tf_platform = platform.upper()
  elif platform == "cuda":
    tf_platform = "GPU"
  else:
    # BUG FIX: this message was a plain string literal missing the `f`
    # prefix, so the platform name was never interpolated.
    raise ValueError(f"platform {platform} not supported")

  concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf)

  captured_inputs = []
  if concrete_function_flat_tf.captured_inputs:
    # The function uses either captured variables or tensors.
    msg = (
        "call_tf works best with a TensorFlow function that does not capture "
        "variables or tensors from the context. "
        "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. "
        f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
    logging.warning(msg)
    for inp in concrete_function_flat_tf.captured_inputs:
      if inp.dtype == tf.resource:  # A variable; lookup by handle
        inp_vars = [v for v in concrete_function_flat_tf.variables if inp is v.handle]
        assert len(inp_vars) == 1, f"Found {inp_vars}"
        captured_inputs.append(inp_vars[0])
      else:
        captured_inputs.append(inp)

  # The following use case happens when we call_tf a restored saved model that
  # includes parameters (hence functions closing over tf.Variable), and then
  # we jax2tf.convert it with native serialization, under tf.function (or
  # for saving to saved model). The `np.asarray(inp)` fails because it thinks
  # it is in TF graph mode. The `tf.init_scope()` lifts out of function-building
  # graph scopes, and allows us to read the values of the variables
  with tf.init_scope():
    captured_ops = tuple(
        mlir.flatten_ir_values(
            mlir.ir_constant(np.asarray(inp)) for inp in captured_inputs
        )
    )

  if call_tf_graph:
    with jax2tf_internal.inside_call_tf():
      return emit_tf_embedded_graph_custom_call(
          ctx,
          concrete_function_flat_tf,
          tuple(args_op) + captured_ops,
          has_side_effects,
          ordered,
          output_avals,
      )

  def convert_to_spec(x):
    if isinstance(x, tf.TensorSpec):
      return x
    else:
      return tf.TensorSpec.from_tensor(x)

  args_tf_flat = [convert_to_spec(a) for a in args_flat_sig_tf]

  with jax2tf_internal.inside_call_tf():
    try:
      func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(
          *args_tf_flat
      )(stage="hlo_serialized", platform_name=tf_platform)
    except Exception as e:
      # BUG FIX: "can used" -> "can be used" in the message below.
      msg = ("Error compiling TensorFlow function (see below for the caught exception)." +
             "\ncall_tf can be used " +
             "in a staged context (under jax.jit, lax.scan, etc.) only with " +
             "compilable functions with static output shapes.\n" +
             "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." +
             "\n\nCaught TensorFlow exception: " + str(e))
      raise ValueError(msg) from e

  # Convert the serialized HLO to StableHLO and merge it into the module.
  stablehlo = _jax.mlir.hlo_to_stablehlo(func_tf_hlo)
  submodule = ir.Module.parse(stablehlo)
  symtab = ir.SymbolTable(submodule.operation)
  main = cast(func_dialect.FuncOp, symtab["main"])
  callee_result_types = main.type.results
  fn = mlir.merge_mlir_modules(ctx.module_context.module,
                               f"call_tf_{function_flat_tf.name}",
                               submodule,
                               dst_symtab=ctx.module_context.symbol_table)
  call = func_dialect.CallOp(callee_result_types,
                             ir.FlatSymbolRefAttr.get(fn),
                             [*args_op, *captured_ops])
  flat_results = call.results
  if ordered:
    raise NotImplementedError(
        "ordered=True is not supported in the jitted context without"
        " `call_tf_graph=True`"
    )

  outputs = []
  for op, res_type in zip(flat_results, callee_result_types):
    if not res_type.has_static_shape:
      # BUG FIX: "can used" -> "can be used" in the message below.
      msg = (
          "Compiled TensorFlow function has dynamic output shape "
          + f"{res_type}. call_tf can be used in a staged context (under jax.jit,"
          " lax.scan, etc.) only with compilable functions with static"
          " output shapes. See"
          " https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
          " for a discussion."
      )
      raise ValueError(msg)
    res_dtype = _mlir_type_to_numpy_dtype(res_type.element_type)
    # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
    jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
    if res_dtype != jax_res_dtype:
      op = hlo.ConvertOp(
          mlir.aval_to_ir_type(core.ShapedArray(res_type.shape, jax_res_dtype)),
          op,
      ).result
    outputs.append(op)
  return outputs
def _register_call_lowering(platform):
  # Registers the call_tf lowering rule for one platform. functools.partial
  # binds `platform` eagerly, so each registration gets its own value.
  mlir.register_lowering(call_tf_p, functools.partial(_call_tf_lowering,
                                                      platform=platform),
                         platform=platform)


for platform in ("cpu", "cuda", "tpu"):
  _register_call_lowering(platform)
def emit_tf_embedded_graph_custom_call(
    ctx: mlir.LoweringRuleContext,
    concrete_function_flat_tf,
    operands: Sequence[ir.Value],
    has_side_effects,
    ordered,
    output_avals,
):
  """Emits a custom call referencing a tf.Graph embedding of the TF function.

  All call_tf called function information is stored in tf.metadata.
  This includes:
  (1) The called function name: This name will be used by the runtime to execute
  the callback.
  (2) The called function index in the XLACallModule `function_list` attribute.

  Returns the list of custom-call results (without the ordering token, which
  is forwarded through `ctx` when `ordered` is true).
  """
  call_tf_concrete_function_list = jax2tf_internal.get_thread_local_state_call_tf_concrete_function_list()
  if call_tf_concrete_function_list is None:
    raise ValueError(
        "call_tf_graph=True only support exporting by jax2tf.convert currently."
    )
  # TODO(necula): It is dangerous to modify global state when lowering because
  # there are a number of lowering caches that only cache the StableHLO.
  # See call_tf_test.py:test_multi_platform_call_tf_graph.
  called_index = add_to_call_tf_concrete_function_list(
      concrete_function_flat_tf, call_tf_concrete_function_list)
  tf_backend_config = {
      "has_token_input_output": ir.BoolAttr.get(ordered),
      "called_index": mlir.i64_attr(called_index),
  }
  result_avals = ctx.avals_out if ctx.avals_out is not None else ()

  operands = list(operands)
  result_types = list(
      mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals])
  )
  if ordered:
    # Ordered effects thread a token through the custom call, prepended to
    # both operands and results.
    operands.insert(0, ctx.tokens_in.get(call_tf_ordered_effect))
    result_types.insert(0, mlir.token_type())

  custom_call = hlo.CustomCallOp(
      result_types,
      operands,
      call_target_name=ir.StringAttr.get("tf.call_tf_function"),
      has_side_effect=ir.BoolAttr.get(has_side_effects),
      api_version=mlir.i32_attr(2),
      called_computations=ir.ArrayAttr.get([]),
      backend_config=ir.StringAttr.get(""),
  )
  # Store TF metadata in unregistered attribute
  custom_call.attributes["tf.backend_config"] = ir.DictAttr.get(
      tf_backend_config
  )

  results = list(custom_call.results)
  if ordered:
    # Pop the token result and hand it back to the lowering context.
    token = results.pop(0)
    ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: token}))

  return results
def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: list[Any]) -> int:
  """Returns the index of `concrete_tf_fn` in the list, appending it if absent."""
  for pos, existing_fn in enumerate(call_tf_concrete_function_list):
    if existing_fn == concrete_tf_fn:
      return pos
  call_tf_concrete_function_list.append(concrete_tf_fn)
  return len(call_tf_concrete_function_list) - 1
# Register a roofline call so that users can use roofline on functions that
# contain call_tf. We register roofline in this file (instead of within the
# roofline module) to avoid having to import jax2tf in roofline.
roofline.register_standard_roofline(call_tf_p)
@@ -0,0 +1,915 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides JAX and TensorFlow interoperation APIs."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial
import contextlib
import math
import os
import threading
from typing import Any, Union
import warnings
from absl import logging
import numpy as np
import jax
from jax import tree_util
from jax import export
from jax._src import api
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import op_shardings
from jax._src import source_info_util
from jax._src import util
from jax._src.export import _export
from jax._src.export import shape_poly
from jax._src.lib import xla_client
import tensorflow as tf
# These don't have public equivalents.
from tensorflow.compiler.tf2xla.python import xla as tfxla
from tensorflow.compiler.xla import xla_data_pb2
try:
from tensorflow.python.compiler.xla.experimental import xla_sharding
except ModuleNotFoundError:
# This can be removed when TF 2.10 support is no longer needed.
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.eager import context as tf_context
NameStack = source_info_util.NameStack
PolyShape = shape_poly.PolyShape  # TODO: deprecate
DType = Any

DisabledSafetyCheck = export.DisabledSafetyCheck

# Use the length-checking variants so that mismatched map/zip arguments
# raise instead of being silently truncated.
map = util.safe_map
zip = util.safe_zip

# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
PrecisionType = int  # Enum xla_data.PrecisionConfig.Precision
def _is_tfval(v: TfVal) -> bool:
  """Returns True if `v` is usable as a value in a TF tracing context.

  Accepts tf.Tensor/tf.Variable directly; for anything else, probes whether
  `tf.constant` can convert it (on CPU, so that types convertible but not
  supported on accelerators are included).
  """
  if isinstance(v, (tf.Tensor, tf.Variable)):
    return True
  try:
    # Include all convertible types, even if not supported on accelerators.
    with tf.device("CPU"):
      tf.constant(v)
    return True
  except Exception:
    # BUG FIX: was a bare `except:`, which also swallowed KeyboardInterrupt
    # and SystemExit.
    return False
class _DefaultNativeSerialization:
  # Sentinel type for the default value of the `native_serialization` and
  # `enable_xla` parameters of `convert`, used to distinguish "not given"
  # from an explicit boolean.
  pass
DEFAULT_NATIVE_SERIALIZATION = _DefaultNativeSerialization()

# In order to ensure that JAX picks up the proper user-frame for source
# locations we will register the TensorFlow source path as an internal
# path with source_info_util. The typical stack when a JAX primitive
# conversion happens is:
#    jax2tf.process_primitive  (top of stack)
#    jax tracing machinery ...
#    tf.custom_gradient machinery ...
#    jax2tf.converted_fun
#    tf function machinery ...
#    user code invokes the converted function on TF tensors
#
# We need to skip over not only JAX internal frames, but TF internal frames
# also.
# We register the TensorFlow source path lazily
_has_registered_tf_source_path = False
class _ThreadLocalState(threading.local):
  """Per-thread mutable state used by jax2tf lowering and call_tf."""

  def __init__(self):
    # Keep track if we are inside a call_tf. In that context we disable the
    # safety check that we are not inside JAX transformations.
    self.inside_call_tf = False

    # Maps dimension variables to TF expressions, for non-native lowering
    self.shape_env: Sequence[tuple[str, TfVal]] = ()

    # A list collecting all tf concrete_functions called by stablehlo.custom_call
    # This is used only by native serialization (unlike all the other
    # thread-local state).
    self.call_tf_concrete_function_list: list[Any] | None = None

_thread_local_state = _ThreadLocalState()
@contextlib.contextmanager
def inside_call_tf():
  """Context manager that flags the current thread as being inside call_tf.

  The previous flag value is restored on exit, even on exception.
  """
  saved = _thread_local_state.inside_call_tf
  _thread_local_state.inside_call_tf = True
  try:
    yield
  finally:
    _thread_local_state.inside_call_tf = saved
def get_thread_local_state_call_tf_concrete_function_list() -> (
    list[Any] | None
):
  """Returns this thread's list of concrete functions collected for call_tf
  custom calls, or None when not collecting (see _ThreadLocalState)."""
  return _thread_local_state.call_tf_concrete_function_list
@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun_jax: Callable,
            *,
            polymorphic_shapes: str | PolyShape | None | Sequence[str | PolyShape | None] = None,
            polymorphic_constraints: Sequence[str] = (),
            with_gradient: bool = True,
            enable_xla: bool | _DefaultNativeSerialization = DEFAULT_NATIVE_SERIALIZATION,
            native_serialization: bool | _DefaultNativeSerialization = DEFAULT_NATIVE_SERIALIZATION,
            native_serialization_platforms: Sequence[str] | None = None,
            native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (),
            ) -> Callable:
  """Allows calling a JAX function from a TensorFlow program.

  See
  [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)
  for more details about usage and common problems.

  Args:
    fun_jax: target JAX function to be called. Its arguments and return value
      should be JAX arrays, or nested standard Python containers
      (tuple/list/dict) thereof (pytrees).
    polymorphic_shapes: Specifies input shapes to be treated polymorphically
      during lowering.

      .. warning:: The shape-polymorphic lowering is an experimental feature.
        It is meant to be sound, but it is known to reject some JAX programs
        that are shape polymorphic. The details of this feature can change.

      It should be `None` (all arguments are monomorphic), a single PolyShape
      or string (applies to all arguments), or a tuple/list of the same length
      as the function arguments. For each argument the shape specification
      should be `None` (monomorphic argument), or a Python object with the
      same pytree structure as the argument.
      See [how optional parameters are matched to
      arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).

      A shape specification for an array argument should be an object
      `PolyShape(dim0, dim1, ..., dimn)`
      where each `dim` is a dimension specification: a positive integer denoting
      a monomorphic dimension of the given size, or a string denoting a
      dimension variable assumed to range over non-zero dimension sizes, or
      the special placeholder string "_" denoting a monomorphic dimension
      whose size is given by the actual argument. As a shortcut, an Ellipsis
      suffix in the list of dimension specifications stands for a list of "_"
      placeholders.

      For convenience, a shape specification can also be given as a string
      representation, e.g.: "batch, ...", "batch, height, width, _", possibly
      with surrounding parentheses: "(batch, ...)".

      The lowering fails if it cannot ensure that it would produce the same
      sequence of TF ops for any non-zero values of the dimension variables.

      polymorphic_shapes are only supported for positional arguments; shape
      polymorphism is not supported for keyword arguments.

      See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
      for more details.
    polymorphic_constraints: a sequence of constraints on symbolic dimension
      expressions, of the form `e1 >= e2` or `e1 <= e2`.
      See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
    with_gradient: if set (default), add a tf.custom_gradient to the lowered
      function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
      TensorFlow AD is supported for the output TensorFlow function, and the
      value of the gradient will be JAX-accurate.
    enable_xla: DEPRECATED; ignored except for emitting a DeprecationWarning.
    native_serialization: DEPRECATED; ignored except for emitting a
      DeprecationWarning.
    native_serialization_platforms: Specifies the platform(s)
      for which to lower the code. Must be a tuple of
      strings, including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'.
      The default (`None`), specifies the JAX default
      backend on the machine where the lowering is done.
    native_serialization_disabled_checks: Disables the specified safety checks.
      See docstring of `DisabledSafetyCheck`.

  Returns:
    A version of `fun_jax` that expects TfVals as arguments (or
    tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
    only TensorFlow ops and thus can be called from a TensorFlow program.
  """
  # Both deprecated parameters are accepted (for backwards compatibility)
  # but only trigger a warning; their values are discarded.
  if native_serialization is not DEFAULT_NATIVE_SERIALIZATION:
    warnings.warn(
        "The `native_serialization` parameter is deprecated and "
        "will be removed in a future version of JAX.",
        DeprecationWarning, stacklevel=2)
    del native_serialization

  if enable_xla is not DEFAULT_NATIVE_SERIALIZATION:
    warnings.warn(
        "The `enable_xla` parameter is deprecated and "
        "will be removed in a future version of JAX.",
        DeprecationWarning, stacklevel=2)
    del enable_xla

  if native_serialization_platforms:
    if (not isinstance(native_serialization_platforms, (list, tuple)) or
        not all(p in ["cpu", "cuda", "rocm", "tpu"]
                for p in native_serialization_platforms)):
      raise ValueError(
          "native_serialization_platforms must be a sequence "
          "containing a subset of {'cpu', 'cuda', 'rocm', 'tpu'}. "
          f"Got: {native_serialization_platforms}")
    native_serialization_platforms = tuple(native_serialization_platforms)

  api.check_callable(fun_jax)

  def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:
    # TODO: is there a better way to check if we are inside a transformation?
    # NOTE(review): the error message below lacks a space before "Trace".
    if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
      # It is Ok to nest convert when we are inside a call_tf
      raise ValueError(
          "convert must be used outside all JAX transformations." +
          f"Trace state: {core.trace_ctx}")

    # Register TF's source path with source_info_util lazily, on the first
    # call, so that JAX source locations skip TF-internal frames.
    global _has_registered_tf_source_path
    if not _has_registered_tf_source_path:
      source_info_util.register_exclusion(os.path.dirname(tf.__file__))
      _has_registered_tf_source_path = True

    def jax_arg_spec_from_tf(a: TfVal) -> jax.ShapeDtypeStruct:
      # The shape and JAX dtype for a TF argument
      tf_arg_shape = np.shape(a)
      # Fix the shape for TF1
      tf_arg_shape = tuple(d.value
                           if isinstance(d, tf.compat.v1.Dimension) else d
                           for d in tf_arg_shape)
      _, a_jax_dtype = _tfval_to_tensor_jax_dtype(a)
      # We count on the fact that jax.ShapeDtypeStruct allows shapes that
      # contain None.
      return jax.ShapeDtypeStruct(tf_arg_shape, a_jax_dtype)

    args_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, args_tf)
    args_specs = export.symbolic_args_specs(
        args_jax_specs, polymorphic_shapes,
        constraints=polymorphic_constraints)
    # The polymorphic_shapes argument refers to positional arguments only.
    # We assume None for the kwargs.
    kwargs_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, kwargs_tf)
    kwargs_specs = export.symbolic_args_specs(
        kwargs_jax_specs, None)
    combined_args_tf = (args_tf, kwargs_tf)
    args_flat_tf: Sequence[TfVal]
    args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf)
    # Cast each flat argument to its JAX-inferred dtype and give it a name.
    args_flat_tf = tuple(
        map(preprocess_arg_tf, range(len(args_flat_tf)), args_flat_tf))

    impl = NativeSerializationImpl(
        fun_jax,
        args_specs=args_specs, kwargs_specs=kwargs_specs,
        native_serialization_platforms=native_serialization_platforms,
        native_serialization_disabled_checks=native_serialization_disabled_checks)
    try:
      impl.before_conversion()

      outs_tree: tree_util.PyTreeDef | None = None
      if with_gradient:
        @tf.custom_gradient
        def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
          # tf.custom_gradient expects (outputs, grad_fn); the grad_fn wraps
          # the JAX VJP so TF reverse-mode AD is JAX-accurate.
          nonlocal outs_tree
          outs_tf, outs_avals, outs_tree = impl.run_fun_tf(args_flat_tf)
          return (tuple(outs_tf),
                  _make_custom_gradient_fn_tf(
                      fun_jax,
                      impl=impl,
                      with_gradient=with_gradient,
                      args_specs=args_specs, kwargs_specs=kwargs_specs,
                      args_tf=args_flat_tf,
                      outs_avals=outs_avals,
                      outs_tf=outs_tf))

        outs_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
      else:
        outs_tf, _, outs_tree = impl.run_fun_tf(args_flat_tf)
        message = ("The jax2tf-converted function does not support gradients. "
                   "Use `with_gradient` parameter to enable gradients")
        # We use PreventGradient, which is propagated through a SavedModel.
        outs_flat_tf = [
            tf.raw_ops.PreventGradient(input=o, message=message)
            for o in outs_tf
        ]
    finally:
      impl.after_conversion()

    assert outs_tree is not None
    # Name the outputs for easier debugging of the generated graph.
    outs_flat_tf = [tf.identity(x, "jax2tf_out") for x in outs_flat_tf]
    out_tf = tree_util.tree_unflatten(outs_tree, outs_flat_tf)
    return out_tf

  return converted_fun_tf
class NativeSerializationImpl:
  """Lowers a JAX function with jax.export and runs it as an XlaCallModule op.

  Lifecycle: construct, `before_conversion()` (exports the function and
  installs thread-local state), then `run_fun_tf()` / `get_vjp_fun()`,
  and finally `after_conversion()` (restores the thread-local state).
  """
  def __init__(self, fun_jax, *,
               args_specs, kwargs_specs,
               native_serialization_platforms: Sequence[str] | None,
               native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]):
    # Keyword arguments for the nested `convert` of the VJP function, so the
    # gradient is serialized with the same platforms/checks as the primal.
    self.convert_kwargs = dict(native_serialization_platforms=native_serialization_platforms,
                               native_serialization_disabled_checks=native_serialization_disabled_checks)
    if hasattr(fun_jax, "trace"):
      # If we have a pjit or pmap already we do not wrap with another, and we
      # allow shardings.
      fun_jit = fun_jax
    else:
      # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
      # convert(f_jax), in which case a "jit" is implied. In that case we raise
      # an error if the lowered function contains non-replicated sharding annotations.
      fun_jit = jax.jit(fun_jax)
    self.fun_jax = fun_jit
    self.args_specs = args_specs
    self.kwargs_specs = kwargs_specs
    self.native_serialization_disabled_checks = native_serialization_disabled_checks
    self.native_serialization_platforms = native_serialization_platforms

  def before_conversion(self):
    """Exports `fun_jax` and installs a fresh call_tf concrete-function list."""
    _prev_func_list = _thread_local_state.call_tf_concrete_function_list
    _thread_local_state.call_tf_concrete_function_list = []

    def _restore_context():
      # Undo the thread-local mutation; invoked by `after_conversion`.
      _thread_local_state.call_tf_concrete_function_list = _prev_func_list

    self._restore_context = _restore_context
    # The export machinery writes the device assignment into this one-element
    # list as a side channel for jax2tf's internal use.
    _exported_device_assignment = [None]
    self.exported = _export._export_internal(
        self.fun_jax,
        platforms=self.native_serialization_platforms,
        disabled_checks=self.native_serialization_disabled_checks,
        _device_assignment_for_internal_jax2tf_use_only=_exported_device_assignment,
    )(*self.args_specs, **self.kwargs_specs)
    assert(_exported_device_assignment[0] is not None)
    self.device_assignment = _exported_device_assignment[0]

  def after_conversion(self):
    """Restores the thread-local state saved in `before_conversion`."""
    self._restore_context()

  def run_fun_tf(self,
                 args_flat_tf: Sequence[TfVal]
                 ) -> tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]:
    """Runs the exported module as TF ops.

    Returns: (flat TF outputs, output avals, output pytree structure).
    """
    results = _run_exported_as_tf(args_flat_tf, self.exported)
    return results, tuple(self.exported.out_avals), self.exported.out_tree

  def get_vjp_fun(self) -> tuple[Callable,
                                 Sequence[core.AbstractValue]]:
    """Returns the JAX VJP function and the avals of its flat inputs."""
    return _export._get_vjp_fun(self.fun_jax,
                                in_tree=self.exported.in_tree,
                                in_avals=self.exported.in_avals,
                                has_named_shardings=self.exported._has_named_shardings,
                                in_named_shardings=self.exported._in_named_shardings,
                                out_named_shardings=self.exported._out_named_shardings,
                                in_shardings_hlo=self.exported.in_shardings_hlo,
                                out_avals=self.exported.out_avals,
                                out_shardings_hlo=self.exported.out_shardings_hlo,
                                device_assignment=self.device_assignment,
                                apply_jit=True)
def dtype_of_val(val: TfVal) -> DType:
  """Computes the TensorFlow dtype using JAX's typing rules.

  If the value is a tf.Tensor, it starts with its dtype. If the value is a
  constant it uses JAX to infer its dtype. The resulting dtype follows the
  JAX type inference rules, and depends on the value of the
  JAX_ENABLE_X64 flag.

  See README.md for how 64-bit values are treated.
  """
  tensor_val, _ = _tfval_to_tensor_jax_dtype(val)
  return tensor_val.dtype
@partial(api_util.api_hook, tag="jax2tf_eval_polymorphic_shape")
def eval_polymorphic_shape(fun_jax: Callable,
                           *,
                           polymorphic_shapes=None) -> Callable:
  """Evaluates the output shape in presence of shape polymorphism.

  This is done without lowering or executing the function, same as for
  `jax.eval_shape`.

  Args:
    fun_jax: target JAX function to be called. Its arguments and return value
      should be JAX arrays, or nested standard Python containers
      (tuple/list/dict) thereof (pytrees).
    polymorphic_shapes: Specifies input shapes to be treated polymorphically
      during shape evaluation. See discussion for `jax2tf.convert`.

      .. warning:: The shape-polymorphic lowering is an experimental feature.

  Returns: a function that takes `jax.ShapeDtypeStruct`s (or any values
    with `.shape` and `.dtype` attributes) corresponding to the inputs for
    `fun_jax`, and returns a tuple with:

    * the jax.ShapeDtypeStruct corresponding to the result, as for
     `jax.eval_shape`. The shape may contain symbolic dimension expressions.
    * the value that can be passed to `polymorphic_shapes` for a subsequent
     call to `jax2tf.eval_polymorphic_shape`, or `jax2tf.convert`.

  For example:

  >>> import jax
  >>> from jax.experimental import jax2tf
  >>> from jax import numpy as jnp
  >>>
  >>> f = lambda A, x: jnp.sin(jnp.dot(A, x))
  >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
  >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
  >>> out_spec, out_poly_shape = jax2tf.eval_polymorphic_shape(f, polymorphic_shapes=["a, b", "b, c"])(A, x)
  >>> print(out_spec.shape)
  ("a", "c")
  >>> print(out_poly_shape)
  (a, c)
  >>> res_spec, res_poly_shape = jax2tf.eval_polymorphic_shape(lambda x: x.T, polymorphic_shapes=[out_poly_shape])(out_spec)
  >>> print(res_poly_shape)
  (c, a)
  """
  def do_eval_polymorphic_shape(*args_specs) -> Any:
    # Attach symbolic dimensions to the input specs, then run JAX's
    # abstract evaluation only — no lowering, no execution.
    symbolic_specs = export.symbolic_args_specs(
        args_specs, polymorphic_shapes)
    result_specs = jax.eval_shape(fun_jax, *symbolic_specs)
    # TODO(necula): For now we export the polymorphic shapes using `str`.
    result_shape_strs = tree_util.tree_map(lambda r: str(r.shape), result_specs)
    return result_specs, result_shape_strs

  return do_eval_polymorphic_shape
# Internals

def preprocess_arg_tf(arg_idx: int,
                      arg_tf: TfVal) -> TfVal:
  """Pre-processes one flattened TF argument of a converted function.

  Args:
    arg_idx: index of the argument, used only to name the resulting tensor.
    arg_tf: the TF value supplied by the caller.

  Returns:
    the pre-processed TF arg: cast to its JAX-inferred dtype and named
    `jax2tf_arg_{arg_idx}`.

  Raises:
    TypeError: if `arg_tf` is not convertible to a tf.Tensor.
  """
  if not _is_tfval(arg_tf):
    msg = (f"Argument {arg_tf} of type {type(arg_tf)} of jax2tf.convert(f) should "
           "be NumPy array, scalar, tf.Variable, or tf.Tensor")
    raise TypeError(msg)

  # May cast the args_flat to JAX types, using JAX's interpretation
  # of types of constants.
  arg_tf, _ = _tfval_to_tensor_jax_dtype(arg_tf)
  # Name input tensors; do this after we have cast the arguments
  arg_tf = tf.identity(arg_tf, f"jax2tf_arg_{arg_idx}")
  return arg_tf
def _make_custom_gradient_fn_tf(fun_jax,
                                *,
                                impl: NativeSerializationImpl,
                                with_gradient: bool,
                                args_specs, kwargs_specs,
                                args_tf: Sequence[TfVal],
                                outs_avals: Sequence[core.ShapedArray],
                                outs_tf: Sequence[TfVal]):
  """Prepares the TF function to be used with tf.custom_gradient.

  Args:
    fun_jax: the primal JAX function (the VJP itself is obtained from `impl`).
    impl: the serialization implementation details
    with_gradient: whether to include a tf.custom_gradient
    args_specs, kwargs_specs: the jax.ShapeDtypeArrays for the args and kwargs
    args_tf: the flattened TF arguments of the primal function
    outs_avals: the flattened output JAX abstract values of the primal function
    outs_tf: the flattened TF outputs of the primal function

  Returns:
    a gradient function suitable as the second element of the pair returned
    to tf.custom_gradient.
  """
  def grad_fn_tf(*out_cts_flat_tf: TfVal,
                 variables=None):
    if variables:
      raise ValueError(
          "Unexpected variables used in forward pass. "
          "This should not happen for first-order differentiation. "
          f"{variables=}")

    # TODO: enable higher-order gradients
    with tf.name_scope("jax2tf_vjp"):
      def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal):
        # If the primal function has outputs of integer or bool types, and if we are
        # under a tf.function context, then TF will pass None in _out_cts_flat
        # in place of these values. We should change these to float0 or
        # else JAX gets unhappy. See issue #6975.
        if out_ct_tf is not None:
          return out_ct_tf
        assert core.primal_dtype_to_tangent_dtype(out_ct_aval.dtype) == dtypes.float0, f"{out_ct_tf=}"
        # Note that out_ct_aval.shape contains dimension variable from the
        # primal function scope. We use tf.zeros_like to make a 0 of the right shape.
        return tf.zeros_like(out_tf, dtype=_tf_np_dtype_for_float0)

      out_cts_fixed_flat_tf = tuple(map(fix_out_ct, out_cts_flat_tf, outs_avals, outs_tf))
      # The VJP takes the primal args followed by the output cotangents.
      vjp_args_flat_tf = tuple(args_tf) + out_cts_fixed_flat_tf

      fun_vjp_jax, vjp_in_avals = impl.get_vjp_fun()

      vjp_polymorphic_shapes = tuple(
          # pyrefly: ignore[missing-attribute]
          str(a.shape)  # Note: may be _DimExpr, not just DimVar
          for a in vjp_in_avals)
      # Convert the VJP with the same serialization options as the primal.
      in_cts_flat = convert(
          fun_vjp_jax,
          with_gradient=with_gradient,
          polymorphic_shapes=vjp_polymorphic_shapes,
          **impl.convert_kwargs)(*vjp_args_flat_tf)

    # We do not need to fix the in_cts because the TF gradient machinery
    # will adjust the unconnected gradients and those for integer types.
    return in_cts_flat

  return grad_fn_tf
def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
                        exported: export.Exported,
                        ) -> Sequence[TfVal]:
  """Runs the `exported` as an XlaCallModule TF op.

  Args:
    args_flat_tf: the flattened TF arguments (including those the exported
      module may have dropped; see `module_kept_var_idx`).
    exported: the result of JAX export for the primal function.

  Returns: the flattened tuple of results.
  """
  args_avals = exported.in_avals

  # TF values may be integer types for float0
  def _convert_value(val, aval):
    # Check the shape
    assert all(d_aval == d_val
               for d_aval, d_val in zip(aval.shape, val.shape)
               if core.is_constant_dim(d_aval)), (aval, val)
    conversion_dtype = _to_tf_dtype(aval.dtype)
    if conversion_dtype != aval.dtype:
      return tf.cast(val, conversion_dtype)
    else:
      return val

  args_flat_tf = tuple(map(_convert_value, args_flat_tf, args_avals))

  # Polymorphic dimensions become None in the TF shapes.
  out_shapes_tf = tuple(
      tuple(d if core.is_constant_dim(d) else None
            for d in out_aval.shape)
      for out_aval in exported.out_avals)

  out_types = tuple(_to_tf_dtype(out_aval.dtype) for out_aval in exported.out_avals)

  # Drop the arguments the exported module does not use.
  kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx]

  version = exported.calling_convention_version

  try:
    get_max_supported_version = tfxla.call_module_maximum_supported_version
  except AttributeError:
    # Older TF does not expose the maximum-supported-version query.
    get_max_supported_version = None

  if get_max_supported_version:
    max_supported_version = get_max_supported_version()
  else:
    max_supported_version = 6

  if version > max_supported_version:
    raise NotImplementedError(
      "XlaCallModule from your TensorFlow installation supports up to "
      f"serialization version {max_supported_version} but the serialized "
      f"module needs version {version}. "
      "You should upgrade TensorFlow, e.g., to tf_nightly."
    )

  call_module_attrs: dict[str, Any] = dict(
      version=version,
      Tout=out_types,
      Sout=out_shapes_tf,
      function_list=[
          concrete_fn.function_def.signature.name
          for concrete_fn in _thread_local_state.call_tf_concrete_function_list
      ] if _thread_local_state.call_tf_concrete_function_list is not None else [],
      # We always set has_token_input_output because it requires real tokens
      # for versions less than 9 and is not used starting with version 9.
      has_token_input_output=False
  )

  call_module_attrs["platforms"] = tuple(p.upper() for p in exported.platforms)
  if version >= 6:
    call_module_attrs["disabled_checks"] = tuple(
        str(dc)
        for dc in exported.disabled_safety_checks)
  else:
    if version >= 3:
      if DisabledSafetyCheck.platform() in exported.disabled_safety_checks:
        call_module_attrs["platforms"] = ()  # No platform checking

  if version >= 10:
    call_module_attrs["use_shardy_partitioner"] = (
        config.use_shardy_partitioner.value
    )

  if logging.vlog_is_on(3):
    # We already logged the MLIR module when we exported it.
    logging.vlog(3, "XlaCallModule %s", str(call_module_attrs))

  call_module_attrs["module"] = exported.mlir_module_serialized

  # Apply the shardings on arguments and results for pjit. This is redundant
  # because the mlir_module_text will already contain the shardings, but it
  # makes it easier for tools like the TPU inference converter to see the
  # sharding without digging into the `module` attribute of the `XlaCallModule`
  # op, in the same way as it is done for the legacy jax2tf conversion.
  # Do not apply XlaSharding for REPLICATED, on inputs and outputs.
  # This is an agreed convention, and also improves usability under TF eager.
  # See b/255511660.
  kept_in_shardings = []
  for i in exported.module_kept_var_idx:
    if exported._has_named_shardings:
      in_sharding_hlo = _export.named_to_hlo_sharding(
          exported._in_named_shardings[i],
          exported.in_avals[i])
    else:
      in_sharding_hlo = exported.in_shardings_hlo[i]
    kept_in_shardings.append(in_sharding_hlo)

  args_flat_tf = tuple(
    map(partial(_shard_value,
                skip_replicated_sharding=tf.executing_eagerly()),
        kept_args_flat_tf, kept_in_shardings))

  res = tfxla.call_module(args_flat_tf, **call_module_attrs)
  # TODO(b/278940799): Replace the TF v1 API with public TF2 API.
  # Add the custom call tf.function into the default graph, so those functions
  # will be available during tf.SavedModel.save.
  if _thread_local_state.call_tf_concrete_function_list is not None:
    for concrete_fn in _thread_local_state.call_tf_concrete_function_list:
      tf.compat.v1.get_default_graph()._add_function_recursive(
          concrete_fn._inference_function
      )

  if exported._has_named_shardings:
    out_shardings_hlo = tuple(
        _export.named_to_hlo_sharding(s, a)
        for s, a in zip(exported._out_named_shardings, exported.out_avals))
  else:
    out_shardings_hlo = exported.out_shardings_hlo
  res = list(map(partial(_shard_value,
                         skip_replicated_sharding=tf.executing_eagerly()),
                 res, out_shardings_hlo))
  res = tuple(map(_convert_value, res, exported.out_avals))
  return res
def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
"""Converts JAX avals from logical to physical, if relevant.
JAX might have avals whose logical vs physical shape/dtype may
differ, and only the physical view is expected to possibly
relate to TF. TF impl rules should operate on the physical form.
A JAX logical aval might even correspond, in principle, to several
physical avals, but we don't support those here. Instead we assert
there is only one and return it.
"""
physical_aval = core.physical_aval(aval)
assert (len(physical_aval.shape) >= len(aval.shape) and
physical_aval.shape[:len(aval.shape)] == aval.shape), (physical_aval, aval)
return physical_aval
def _jax_physical_dtype(dtype):
  # Returns the physical dtype for a JAX dtype.
  # assuming () is a fine stand-in shape
  # Callers rely on a TypeError for non-JAX dtypes (see `_to_tf_dtype`).
  return _jax_physical_aval(core.ShapedArray((), dtype)).dtype
def _aval_to_tf_shape(aval: core.ShapedArray) -> tuple[int | None, ...]:
  """Generate a TF shape, possibly containing None for polymorphic dimensions."""
  physical = _jax_physical_aval(aval)
  return tuple(None if export.is_symbolic_dim(d) else d
               for d in physical.shape)
# In the TF world, we represent float0 as zeros of this type.
# We pick bool because this is what JAX uses when it lowers float0 to HLO.
_tf_np_dtype_for_float0 = np.bool_
def _to_tf_dtype(jax_dtype):
  """Maps a JAX dtype to the TF dtype used to represent it.

  Note: `_to_tf_dtype` and `_to_jax_dtype` are not inverses of each other,
  due to float0 and 64-bit behavior.
  """
  with contextlib.suppress(TypeError):
    # `jax_dtype` may not actually be a valid jax dtype (e.g. it is
    # tf.float32), in which case there is no physical dtype anyway.
    jax_dtype = _jax_physical_dtype(jax_dtype)
  if jax_dtype == dtypes.float0:
    # TF has no float0; represent it with the agreed stand-in type.
    jax_dtype = _tf_np_dtype_for_float0
  return tf.dtypes.as_dtype(jax_dtype)
def _to_jax_dtype(tf_dtype):
# Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
# due to float0 and 64-bit behavior.
dt = dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype)
if dt not in dtypes._jax_dtype_set:
raise TypeError(f"dtype {dt} is not a valid JAX array "
"type. Only arrays of numeric types are supported by JAX.")
return dt
def _tfval_to_tensor_jax_dtype(val: TfVal,
                               jax_dtype: DType | None = None,
                               memoize_constants=False) -> tuple[TfVal, DType]:
  """Converts a scalar, ndarray, or tf.Tensor to a tf.Tensor with proper type.

  If `jax_dtype` is missing, uses JAX typing rules.
  See README.md for details regarding 64-bit values.

  Args:
    val: a scalar, ndarray, tf.Tensor, or tf.Variable
    jax_dtype: an optional dtype to use. If missing, uses JAX type inference
      rules for constants.
    memoize_constants: whether to memoize TF constants. We can't do this
      everywhere, we may be outside of a conversion scope.

  Returns:
    a tuple with a tf.Tensor with the type as needed by JAX, and the JAX type.
  """
  if isinstance(val, (tf.Tensor, tf.Variable)):
    jax_dtype = jax_dtype or _to_jax_dtype(val.dtype)  # Give JAX a chance to pick the type
    conversion_dtype = _to_tf_dtype(jax_dtype)
    if conversion_dtype != val.dtype:  # May need to cast for 64-bit values
      return tf.cast(val, conversion_dtype), jax_dtype
    else:
      return val, jax_dtype
  else:  # A constant
    jax_dtype = jax_dtype or core.typeof(val).dtype
    # TODO(document): We assume that the value of a constant does not
    # change through the scope of the function. But it may be an ndarray, ...
    # JAX has the same problem when generating HLO.
    const_key = (id(val), jax_dtype)
    # Since we use id(val) as a cache key, we have to make sure that we keep
    # the previous `val` alive. Otherwise, for a ndarray, it can get garbage
    # collected and reused for a different value, which would create correctness
    # issues. We keep the `val` alive by storing in the cache the pair
    # `(val, tf_val)`.
    # Only memoize non-scalars. JAX will lift all non-scalar constants as
    # Jaxpr consts, to the top level of the Jaxpr. This ensures that we see them
    # early, when entering the Jaxpr, so we create the tf.const early and its
    # scope is the entire Jaxpr.
    # NOTE(review): `constant_cache` is not initialized by the
    # `_ThreadLocalState.__init__` visible in this file; presumably it is set
    # up elsewhere before `memoize_constants=True` is used -- verify.
    do_memoize = (memoize_constants and np.size(val) > 1 and _thread_local_state.constant_cache is not None)
    if do_memoize:
      _, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None))
    else:
      tf_val = None
    if tf_val is None:
      conversion_dtype = _to_tf_dtype(jax_dtype)
      # The float0 type is not known to TF.
      if jax_dtype == dtypes.float0:
        val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
      if hasattr(val, 'dtype') and dtypes.issubdtype(val.dtype, dtypes.extended):
        # Extended dtypes are lowered via their physical representation.
        val = val.dtype._rules.physical_const(val)
      tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype)
      if do_memoize:
        _thread_local_state.constant_cache[const_key] = (val, tf_val)
    return tf_val, jax_dtype
# Per-dimension split counts, or None for a replicated tensor.
PartitionsOrReplicated = Union[tuple[int, ...], None]

def split_to_logical_devices(tensor: TfVal,
                             partition_dimensions: PartitionsOrReplicated):
  """Like TPUMPStrategy.experimental_split_to_logical_devices.

  For jax2tf purposes we want to avoid needing to thread the `strategy` object
  through the generated computation. It seems that the original function needs
  the strategy object only for error checking, which we assume is done upstream
  by JAX.

  Args:
    tensor: Input tensor to annotate.
    partition_dimensions: A list of integers, with one integer per tensor
      dimension, specifying in how many parts the dimension should be split. The
      product of integers must equal the number of devices per replica.
      `None` means the tensor is replicated rather than split.

  Returns:
    an annotated tensor.
  """
  # TODO: this is only for sharded_jit. Either remove, or implement in terms
  # of _shard_values.
  if partition_dimensions is None:
    return xla_sharding.replicate(tensor, use_sharding_op=True)
  num_partition_splits = math.prod(partition_dimensions)
  tile_assignment = np.arange(num_partition_splits).reshape(
      partition_dimensions)
  return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
def _shard_value(val: TfVal,
                 sd: xla_client.HloSharding | None, *,
                 skip_replicated_sharding: bool) -> TfVal:
  """Apply sharding to a TfVal.

  Args:
    val: the TF value to annotate.
    sd: the HloSharding to apply, or None for no sharding.
    skip_replicated_sharding: if True, replicated shardings are returned
      unannotated (the agreed convention for inputs/outputs; see b/255511660).

  Returns:
    `val`, possibly wrapped in an XlaSharding annotation op.
  """
  if sd is None:
    return val

  sharding_proto = sd.to_proto()
  if (skip_replicated_sharding and
      op_shardings.is_hlo_sharding_replicated(sd)):
    return val

  # Tensorflow heavily relies on tile_assignment_devices proto fields specific
  # to V1 sharding format, falling back to this format.
  if (
      not sharding_proto.tile_assignment_devices
      and sharding_proto.iota_reshape_dims
  ):
    # Materialize the explicit device list from the iota description.
    tad = list(
        np.arange(math.prod(sharding_proto.tile_assignment_dimensions))
        .reshape(sharding_proto.iota_reshape_dims)
        .transpose(sharding_proto.iota_transpose_perm)
        .flat
    )
  else:
    tad = sharding_proto.tile_assignment_devices

  # To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
  xla_sharding_v1_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding(
      type=int(sharding_proto.type),
      tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
      tile_assignment_devices=tad,
      replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim,
      last_tile_dims=sharding_proto.last_tile_dims,
  )
  # Shardy requires V2 sharding format (same OpSharding message, but with the
  # iota_* fields preserved instead of a materialized device list).
  if config.use_shardy_partitioner.value:
    xla_sharding_v2_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding(
        type=int(sharding_proto.type),
        tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
        tile_assignment_devices=sharding_proto.tile_assignment_devices,
        iota_reshape_dims=sharding_proto.iota_reshape_dims,
        iota_transpose_perm=sharding_proto.iota_transpose_perm,
        replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim,
        last_tile_dims=sharding_proto.last_tile_dims,
    )
  else:
    xla_sharding_v2_proto = None
  if tf_context.executing_eagerly():
    raise ValueError(
        "A jit function with sharded arguments or results must be used under a `tf.function` context. "
        "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion")

  tf_version = tuple(int(v) for v in tf.__version__.split(".")[:2])
  # apply_to_tensor comes from a tensorflow package, check the tensorflow
  # version to make sure that it has the sharding_v2_proto parameter.
  if tf_version < (2, 20):
    return xla_sharding.Sharding(proto=xla_sharding_v1_proto).apply_to_tensor(
        val, use_sharding_op=True
    )
  return xla_sharding.Sharding(proto=xla_sharding_v1_proto).apply_to_tensor(
      val, use_sharding_op=True, sharding_v2_proto=xla_sharding_v2_proto
  )
def _register_checkpoint_pytrees():
  """Registers TF custom container types as pytrees."""
  m = tf.Module()
  # The types here are automagically changed by TensorFlow's checkpointing
  # infrastructure.
  m.a = (tf.Module(), tf.Module())
  m.b = [tf.Module(), tf.Module()]
  m.c = {"a": tf.Module()}
  tuple_wrapper = type(m.a)
  list_wrapper = type(m.b)
  dict_wrapper = type(m.c)

  # TF AutoTrackable swaps container types out for wrappers.
  assert tuple_wrapper is not tuple
  assert list_wrapper is not list
  assert dict_wrapper is not dict

  # Register the wrapper types so jax.tree_util can flatten/unflatten
  # containers that passed through TF's checkpointing machinery.
  jax.tree_util.register_pytree_node(tuple_wrapper, lambda xs:
                                     (tuple(xs), None), lambda _, xs: tuple(xs))

  jax.tree_util.register_pytree_node(list_wrapper, lambda xs: (tuple(xs), None),
                                     lambda _, xs: list(xs))

  jax.tree_util.register_pytree_node(
      dict_wrapper,
      lambda s: (tuple(s.values()), tuple(s.keys())),
      lambda k, xs: dict_wrapper(zip(k, xs)))

# Run at import time: the wrapper types must be pytree-registered before any
# converted function receives checkpoint-restored containers.
_register_checkpoint_pytrees()