hand
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax.experimental.jax2tf.jax2tf import (
|
||||
convert as convert,
|
||||
eval_polymorphic_shape as eval_polymorphic_shape,
|
||||
dtype_of_val as dtype_of_val,
|
||||
split_to_logical_devices as split_to_logical_devices,
|
||||
DisabledSafetyCheck as DisabledSafetyCheck,
|
||||
PolyShape as PolyShape # TODO: deprecate
|
||||
)
|
||||
from jax.experimental.jax2tf.call_tf import call_tf as call_tf
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,727 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Allows JAX to call TensorFlow functions with support for autodiff.
|
||||
|
||||
**Experimental: please give feedback, and expect changes.**
|
||||
|
||||
This module introduces the function :func:`call_tf` that allows JAX to call
|
||||
TensorFlow functions.
|
||||
|
||||
For examples and details, see
|
||||
https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import cast, Any
|
||||
|
||||
from absl import logging
|
||||
import jax
|
||||
from jax import dlpack
|
||||
from jax import dtypes
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import effects
|
||||
from jax._src import literals
|
||||
from jax._src import util
|
||||
from jax._src.lib import _jax
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax.experimental import roofline
|
||||
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
|
||||
from jax._src.interpreters import mlir
|
||||
import ml_dtypes
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
# Length-checking variants of the builtins: they raise if the iterables have
# different lengths, which is the convention throughout the JAX codebase.
map = util.safe_map
zip = util.safe_zip

# Type aliases used for documentation only (no runtime checking).
TfConcreteFunction = Any
TfVal = jax2tf_internal.TfVal

# The platforms for which to use DLPack to avoid copying (only works on GPU
# and CPU at the moment, and only for Array). For CPU we don't need
# DLPack, if we are careful.
_DLPACK_PLATFORMS = ("gpu",)
|
||||
|
||||
class UnspecifiedOutputShapeDtype:
  # Sentinel class: a fresh instance is the default for the
  # `output_shape_dtype` parameter of `call_tf`, so that "not given" can be
  # distinguished from any value (including None) a caller might pass.
  pass
|
||||
|
||||
def call_tf(
    callable_tf: Callable,
    has_side_effects=True,
    ordered=False,
    output_shape_dtype=UnspecifiedOutputShapeDtype(),
    call_tf_graph=False,
) -> Callable:
  """Calls a TensorFlow function from JAX, with support for reverse autodiff.

  The ``callable_tf`` will be called with TensorFlow-compatible arguments (
  numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``) or pytrees thereof. The
  function must return the same type of results.

  If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
  or :func:`jax.pmap`, or a control-flow primitive) then
  ``callable_tf`` will be compiled with ``tf.function(callable_tf,
  jit_compile=True)``
  and the resulting XLA computation will be embedded in JAX's XLA computation.

  If ``call_tf`` appears outside a JAX staging context, it will be called inline
  using TensorFlow eager mode.

  The ``call_tf`` supports JAX's reverse-mode autodiff, in which case the
  ``callable_tf`` will be differentiated using ``tf.GradientTape``. This means
  that the gradient will be TensorFlow-accurate, e.g., will respect the
  custom gradients that may be defined for the code in ``callable_tf``.

  For an example and more details see the
  `README
  <https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.

  Args:
    callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
      arguments.
    has_side_effects: if True then it ensures that instances of this primitive
      are not removed or replicated by JAX optimizations such as dead-code
      elimination.
    ordered: If true, calls are modeled as having ordered effects.
    output_shape_dtype: An optional declaration of the expected shape and dtype
      of the result of the called TensorFlow function. If given it will be used
      during JAX tracing to form the abstract values of the results of the
      `call_tf`. If not given then we form a `tf.Graph` for the called
      TensorFlow function and we use the TensorFlow-inferred shapes and types.
      Must be a pytree matching the structure of the nested structure returned
      from the TensorFlow function, containing objects with `.shape` and
      `.dtype` attributes, e.g., `jax.ShapeDtypeStruct` or `jax.Array`.
    call_tf_graph: EXPERIMENTAL, DO NOT USE. We may change the name in the
      future.

  Returns: a JAX callable that can be invoked with JAX pytree arguments, in
    op-by-op mode or in a staged context. This callable can be used with JAX's
    reverse-mode autodiff (:func:`jax.grad`).
  """
  @jax.custom_vjp
  def make_call(*args_jax):
    """We wrap it all in `make_call` so that we can attach custom VJP."""

    args_flat_jax, args_treedef = tree_util.tree_flatten(args_jax)
    # Canonicalize the arguments; e.g., makes them x32 if JAX is in 32-bit mode
    def canonical_arg(v):
      # Values without a dtype (Python scalars, lists) are first converted to
      # ndarray so that canonicalize_dtype can be applied.
      v = v if getattr(v, "dtype", None) else np.asarray(v)
      dtype = dtypes.canonicalize_dtype(v.dtype)
      if dtype != v.dtype:
        v = v.astype(dtype)
      return v

    args_flat_jax = tuple(map(canonical_arg, args_flat_jax))
    def make_tensorspec(a_jax):
      # Dimensions that are not compile-time constants (shape polymorphism)
      # become None in the TF TensorSpec.
      a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
      a_tf_shape = [d if core.is_constant_dim(d) else None for d in getattr(a_jax, "shape", ())]
      return tf.TensorSpec(a_tf_shape, a_tf_dtype)
    args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))

    if not isinstance(output_shape_dtype, UnspecifiedOutputShapeDtype):
      output_shape_dtype_flat, output_shape_dtype_tree = tree_util.tree_flatten(output_shape_dtype)
      output_avals = tuple(core.ShapedArray(st.shape, st.dtype) for st in output_shape_dtype_flat)
    else:
      output_avals, output_shape_dtype_tree = None, None

    res_treedef = None  # We'll store here the result treedef
    res_tf_flat = None  # For error reporting
    # The function below will be called at least once, either in eager
    # mode during jax2tf_call_tf or in graph mode during _get_concrete_function_tf()
    def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
      args_tf = args_treedef.unflatten(args_tf_flat)
      res_tf = callable_tf(*args_tf)

      # b/279454591: When `callable_tf` is a tf function with zero outputs, it
      # returns a `StatefulPartitionedCall` (if the function is stateful) or
      # `PartitionedCall` (if the function is stateless) op instead of
      # tf.Tensors. We work around this issue by replacing the output `res_tf`
      # with an empty list.
      if isinstance(res_tf, tf.Operation):
        assert (
            res_tf.type == "StatefulPartitionedCall"
            or res_tf.type == "PartitionedCall"
        )
        t_out = res_tf.get_attr("Tout")
        # t_out should be an empty list.
        assert not t_out, (
            "The TF function returned an unexpected result, please check its"
            f" function body. res_tf = {res_tf}"
        )
        res_tf = t_out

      nonlocal res_treedef, res_tf_flat
      res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
      assert res_treedef is None or res_treedef == res_treedef_now, (
          f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}")
      res_treedef = res_treedef_now
      if output_avals is not None:
        if res_treedef != output_shape_dtype_tree:
          raise ValueError(
              "The pytree of the TensorFlow function results does not match the "
              "pytree of the declared output_shape_dtype:\n"
              f"results pytree: {res_treedef}\noutput_shape_dtype tree: {output_shape_dtype_tree}")
        assert len(output_avals) == len(res_tf_flat)

      # Validate each flat result against the declared aval (or None when no
      # declaration was given).
      checked_res_tf_flat = [
          check_tf_result(i, r_tf, r_aval)
          for i, (r_tf, r_aval) in enumerate(
              zip(res_tf_flat,
                  (output_avals
                   if output_avals is not None
                   else (None,) * len(res_tf_flat))))]
      return checked_res_tf_flat

    # Prepare a tf.function ahead of time, to cache the concrete functions. This
    # won't be used in op-by-op execution mode.
    function_flat_tf = tf.function(
        callable_flat_tf, autograph=False, jit_compile=not call_tf_graph)

    res_jax_flat = call_tf_p.bind(
        *args_flat_jax,
        # Carry the actual function such that op-by-op call can call in TF eager mode.
        callable_flat_tf=callable_flat_tf,
        function_flat_tf=function_flat_tf,
        args_flat_sig_tf=args_flat_sig_tf,
        output_avals=output_avals,
        has_side_effects=has_side_effects,
        ordered=ordered,
        call_tf_graph=call_tf_graph,
    )

    # We must have called callable_flat_tf by now
    assert res_treedef is not None
    return res_treedef.unflatten(res_jax_flat)

  # Define the fwd and bwd custom_vjp functions
  def make_call_vjp_fwd(*args_jax):
    # Return the primal arguments as the residual
    return make_call(*args_jax), args_jax

  def make_call_vjp_bwd(residual_jax, ct_res_jax):
    args_jax = residual_jax  # residual is the primal argument

    def tf_vjp_fun(args_tf, ct_res_tf):
      """Invoke TF gradient."""

      # TF does not like us to watch non-float vars or Nones.
      def replace_non_float_or_none(arg_tf):
        if arg_tf is not None and (
            arg_tf.dtype.is_floating or arg_tf.dtype.is_complex
        ):
          return arg_tf
        else:
          # When watched, this will be ignored. When used in results it will
          # result in a floating 0. gradient, which JAX will ignore (and
          # replace it with a float0)
          return tf.zeros((), dtype=tf.float32)

      watched_args_tf = tf.nest.map_structure(
          replace_non_float_or_none, args_tf
      )
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(watched_args_tf)
        res = callable_tf(*args_tf)

      tf.nest.assert_same_structure(res, ct_res_tf)
      dres_darg = tape.gradient(
          tf.nest.map_structure(replace_non_float_or_none, res),
          sources=watched_args_tf,
          output_gradients=ct_res_tf,
          unconnected_gradients=tf.UnconnectedGradients.ZERO,
      )

      # tape.gradient may return IndexedSlices or None; normalize to tensors
      # (keeping None as None).
      dres_darg = tree_util.tree_map(
          lambda x: x if x is None else tf.convert_to_tensor(x),
          dres_darg,
      )

      # callable_tf may mutate (the structure of) args_tf, thus we check against
      # watched_args_tf which should be structurally the same as the original
      # args_tf.
      tf.nest.assert_same_structure(dres_darg, watched_args_tf)
      return dres_darg

    # Use call_tf to call the VJP function
    ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)
    # We must make the float0s that JAX expects
    def fix_float0(arg_jax, ct_arg_jax):
      if arg_jax is None:
        return None
      arg_dtype = dtypes.result_type(arg_jax)  # May be scalar
      ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
      if ct_arg_dtype != ct_arg_jax.dtype:
        return ad_util.zeros_like_aval(core.ShapedArray(np.shape(arg_jax),
                                                        ct_arg_dtype))
      return ct_arg_jax

    ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax,
                                           is_leaf=lambda x: x is None)
    return ct_args_jax_fixed

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return util.wraps(callable_tf)(make_call)
|
||||
|
||||
|
||||
def check_tf_result(idx: int, r_tf: TfVal, r_aval: core.ShapedArray | None) -> TfVal:
  """Checks one flat TF result against its (optional) declared abstract value.

  Args:
    idx: the index of the result, used only in error messages.
    r_tf: the value returned by the TF function.
    r_aval: the declared `output_shape_dtype` aval for this result, or None
      when no declaration was given.

  Returns:
    `r_tf`, possibly with its shape refined via `tf.ensure_shape`.

  Raises:
    ValueError: if `r_tf` is not convertible to JAX, or does not match the
      declared shape/dtype.
  """
  # Check that the TF function returns values of expected types. This
  # improves error reporting, preventing hard-to-diagnose errors downstream
  try:
    jax2tf_internal._tfval_to_tensor_jax_dtype(r_tf)
  except Exception as e:
    msg = ("The called TF function returns a result that is not "
           f"convertible to JAX: {r_tf}.")
    raise ValueError(msg) from e

  if r_aval is None:
    return r_tf
  # We convert to TF type, and canonicalize to 32-bit if necessary
  r_aval_dtype_tf = jax2tf_internal._to_tf_dtype(r_aval.dtype)
  # Checking shapes is trickier in presence of dynamic shapes. I wish we could
  # check at runtime that the returned shape matches the declared shape. I wish
  # that tf.ensure_shape did this, but it can only take shapes that contain None
  # not computed shapes. However, in eager mode we should be able to resolve
  # the declared shapes to constants and we get better checking.
  r_aval_shape_tf = jax2tf_internal._aval_to_tf_shape(r_aval)
  # We do as much checking as we can here, instead of relying on tf.ensure_shape
  # because the latter gives different errors in eager vs. compiled mode.
  # TODO(b/279454591): This strange error is from TF. Eager function suppose
  # return tf Val with concrete shape but not. Here we change exception to warn
  # and bypass it. This case need revisit on TF side.
  try:
    _ = len(r_tf.shape)
  except ValueError as e:
    # Fixed: the adjacent string literals previously concatenated without
    # separators, producing "of the`r_tf`" and "obtained.r_tf =".
    msg = (
        "The shape check test cannot be performed because the shape of the"
        " `r_tf` tensor cannot be obtained."
        f" r_tf = {r_tf}, r_aval = {r_aval}"
    )
    msg += str(e)
    logging.warning(msg)
    return r_tf
  if (r_tf.dtype != r_aval_dtype_tf or
      len(r_tf.shape) != len(r_aval_shape_tf) or
      any(r_aval_d is not None and r_tf_d is not None and r_aval_d != r_tf_d
          for r_tf_d, r_aval_d in zip(r_tf.shape, r_aval_shape_tf))):
    msg = ("The shapes or dtypes returned by the TensorFlow function "
           "do not match the declared output_shape_dtype:\n"
           f"Result[{idx}] is {r_tf.dtype}[{r_tf.shape}] vs. expected {r_aval_dtype_tf}[{r_aval_shape_tf}]")
    raise ValueError(msg)
  # At this point tf.ensure_shape does not do much, it should never throw an
  # error, albeit it may refine the shape a bit.
  return tf.ensure_shape(r_tf, r_aval_shape_tf)
|
||||
|
||||
|
||||
# The JAX primitive that carries a call to a TensorFlow function. It is bound
# inside `call_tf` above; its impl, abstract-eval and lowering rules are
# registered below.
call_tf_p = core.Primitive("call_tf")
call_tf_p.multiple_results = True
|
||||
|
||||
# The impl will be used in op-by-op mode and calls callable_tf in TF eager mode.
|
||||
# The impl will be used in op-by-op mode and calls callable_tf in TF eager mode.
def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
  """Eager (op-by-op) implementation of the call_tf primitive.

  Converts the flat JAX arguments to TF values, invokes the flat TF callable
  eagerly, and converts the flat TF results back to JAX values.
  """
  # On GPU we use dlpack to avoid copies of data to the host.
  def _arg_jax_to_tf(arg_jax):
    if (isinstance(arg_jax, jax.Array) and
        list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
        dlpack.is_supported_dtype(arg_jax.dtype)):
      return tf.experimental.dlpack.from_dlpack(arg_jax.__dlpack__())
    # The following avoids copies to the host on CPU, always for Array
    # and even for ndarray if they are sufficiently aligned.
    # TODO(necula): on TPU this copies to the host!
    if getattr(arg_jax, 'dtype', None) == dtypes.float0:
      # float0 has no TF equivalent; substitute zeros of the stand-in dtype.
      return tf.zeros(shape=arg_jax.shape,
                      dtype=jax2tf_internal._tf_np_dtype_for_float0)
    if isinstance(arg_jax, tuple(literals.typed_scalar_types)):
      # Make sure to preserve the JAX dtype for TypedInt, etc.
      return tf.constant(np.asarray(arg_jax, dtype=arg_jax.dtype))
    return tf.constant(np.asarray(arg_jax))

  args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
  with jax2tf_internal.inside_call_tf():
    # Call in TF eager mode
    res_tf_flat = callable_flat_tf(*args_tf_flat)

  def _res_tf_to_jax(res_tf: TfVal):
    res_tf, jax_dtype = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
    if isinstance(res_tf, tf.Tensor) and dlpack.is_supported_dtype(jax_dtype):
      # Only use DLPack when the result already lives on a supported platform.
      res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
      res_jax_platform = res_tf_platform.lower()
      if res_jax_platform in _DLPACK_PLATFORMS:
        return jax.dlpack.from_dlpack(res_tf)

    # When working with a bfloat16 scalar tf.Tensor,np.asarray() can fail.
    # To handle this special case, we create a numpy copy.
    if res_tf.shape == tf.TensorShape([]) and res_tf.dtype == tf.bfloat16:
      return jax.device_put(jnp.array(res_tf.numpy()))
    else:
      return jax.device_put(np.asarray(res_tf))

  return list(map(_res_tf_to_jax, res_tf_flat))
|
||||
|
||||
|
||||
# Register the TF-eager implementation for op-by-op execution.
call_tf_p.def_impl(_call_tf_impl)
|
||||
|
||||
@functools.lru_cache(maxsize=128)
def _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf):  # -> tf.ConcreteFunction
  # Traces the tf.function once per (function, signature) pair and caches the
  # resulting ConcreteFunction; inside_call_tf() marks the re-entry into TF.
  with jax2tf_internal.inside_call_tf():
    return function_flat_tf.get_concrete_function(*args_flat_sig_tf)
|
||||
|
||||
|
||||
# Mark the effectful instances of call_tf
|
||||
@dataclasses.dataclass(frozen=True)
class CallTfEffect(effects.Effect):
  """Unordered effect attached to side-effecting call_tf instances."""

  def __str__(self):
    return "CallTfEffect"
|
||||
|
||||
call_tf_effect = CallTfEffect()

# The unordered effect is lowerable and allowed under control flow, remat,
# and custom derivatives; it only prevents DCE/replication of the primitive.
effects.lowerable_effects.add_type(CallTfEffect)
effects.control_flow_allowed_effects.add_type(CallTfEffect)
effects.remat_allowed_effects.add_type(CallTfEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfEffect)
|
||||
|
||||
|
||||
class CallTfOrderedEffect(effects.Effect):
  """Ordered effect attached to call_tf instances with ordered=True."""

  def __str__(self):
    return "CallTfOrderedEffect"
|
||||
|
||||
|
||||
call_tf_ordered_effect = CallTfOrderedEffect()

# In addition to the permissions granted to the unordered effect, the ordered
# variant is registered as an ordered (and shardable-ordered) effect, so the
# runtime threads a token through successive calls to preserve their order.
effects.lowerable_effects.add_type(CallTfOrderedEffect)
effects.control_flow_allowed_effects.add_type(CallTfOrderedEffect)
effects.remat_allowed_effects.add_type(CallTfOrderedEffect)
effects.custom_derivatives_allowed_effects.add_type(CallTfOrderedEffect)
effects.ordered_effects.add_type(CallTfOrderedEffect)
effects.shardable_ordered_effects.add_type(CallTfOrderedEffect)
|
||||
|
||||
|
||||
def _call_tf_abstract_eval(
    *args_flat_avals,
    function_flat_tf,
    args_flat_sig_tf,
    has_side_effects,
    ordered,
    output_avals,
    call_tf_graph,
    **__,
):
  """Abstract evaluation rule for the call_tf primitive.

  Returns a pair (output avals, effects). Output avals come either from the
  user-declared `output_avals` or from TF's inferred output shapes/dtypes.
  """
  # Called only when we form a Jaxpr, i.e., under jit, scan, etc.
  effs: set[effects.Effect] = set()
  if ordered:
    effs.add(call_tf_ordered_effect)
  elif has_side_effects:
    effs.add(call_tf_effect)

  # If no output_avals is given, then we ask TF to infer the output shapes.
  # We call this even if output_avals is given because it will ensure that
  # callable_flat_tf is called. Since _get_concrete_function_tf is cached
  # there is a small cost of calling it more often than needed.
  concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf,
                                                        args_flat_sig_tf)

  # In the case that the tf.function has no return value
  if len(concrete_function_flat_tf.outputs) == 0:
    return (), effs

  if output_avals is not None:
    return output_avals, effs

  def is_fully_known_shape(s):
    # s is a tf.TensorShape; rank can be None, and dims can be None.
    return s.rank is not None and all(d is not None for d in s)

  if all(is_fully_known_shape(s)
         for s in concrete_function_flat_tf.output_shapes):
    avals_from_tf = tuple(
        # We convert to JAX type, and canonicalize to 32-bit if necessary
        core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
        for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
                                concrete_function_flat_tf.output_shapes))
    return avals_from_tf, effs

  msg = ("call_tf cannot call functions whose output has dynamic shape. "
         f"Found output shapes: {concrete_function_flat_tf.output_shapes}. "
         "Consider using the `output_shape_dtype` argument to call_tf. "
         "\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
         " for a discussion.")
  raise ValueError(msg)
|
||||
|
||||
|
||||
# Register the effectful abstract-eval rule (returns avals AND effects).
call_tf_p.def_effectful_abstract_eval(_call_tf_abstract_eval)
|
||||
|
||||
|
||||
def _mlir_type_to_numpy_dtype(type: ir.Type) -> np.dtype:
|
||||
"""Converts an MLIR scalar type to a NumPy dtype."""
|
||||
|
||||
if isinstance(type, ir.IntegerType):
|
||||
type = ir.IntegerType(type)
|
||||
width = type.width
|
||||
if width == 1:
|
||||
return np.dtype(np.bool_)
|
||||
elif width == 8:
|
||||
return np.dtype(np.uint8 if type.is_unsigned else np.int8)
|
||||
elif width == 16:
|
||||
return np.dtype(np.uint16 if type.is_unsigned else np.int16)
|
||||
elif width == 32:
|
||||
return np.dtype(np.uint32 if type.is_unsigned else np.int32)
|
||||
elif width == 64:
|
||||
return np.dtype(np.uint64 if type.is_unsigned else np.int64)
|
||||
else:
|
||||
raise ValueError(f"Unsupported integer width: {width}")
|
||||
|
||||
elif isinstance(type, ir.F16Type):
|
||||
return np.dtype(np.float16)
|
||||
elif isinstance(type, ir.F32Type):
|
||||
return np.dtype(np.float32)
|
||||
elif isinstance(type, ir.F64Type):
|
||||
return np.dtype(np.float64)
|
||||
elif isinstance(type, ir.BF16Type):
|
||||
return np.dtype(ml_dtypes.bfloat16)
|
||||
|
||||
elif isinstance(type, ir.ComplexType):
|
||||
element_type = ir.ComplexType(type).element_type
|
||||
if isinstance(element_type, ir.F32Type):
|
||||
return np.dtype(np.complex64)
|
||||
elif isinstance(element_type, ir.F64Type):
|
||||
return np.dtype(np.complex128)
|
||||
else:
|
||||
raise ValueError(f"Unsupported complex element type: {element_type}")
|
||||
|
||||
else:
|
||||
raise TypeError(f"Unsupported MLIR type for NumPy conversion: {type}")
|
||||
|
||||
|
||||
def _call_tf_lowering(
    ctx: mlir.LoweringRuleContext,
    *args_op,
    platform,
    function_flat_tf,
    args_flat_sig_tf,
    has_side_effects,
    ordered,
    call_tf_graph,
    output_avals,
    **_,
):
  """Lowering rule for the call_tf primitive.

  Compiles the TF function to HLO for `platform`, converts it to StableHLO,
  merges it into the current module and emits a func.call to it. When
  `call_tf_graph` is set, instead emits a custom call that references the
  embedded tf.Graph (see emit_tf_embedded_graph_custom_call).

  Raises:
    ValueError: for unsupported platforms, TF compilation failures, or
      dynamic output shapes.
    NotImplementedError: for ordered=True without call_tf_graph=True.
  """
  # We use the same TF lowering device as for the embedding JAX computation.
  # One example when this is needed is when the code refers to variables on one
  # device. Or, for sharding annotations (only supported on TPU).
  if platform in ["cpu", "tpu"]:
    tf_platform = platform.upper()
  elif platform == "cuda":
    tf_platform = "GPU"
  else:
    # Fixed: this message previously lacked the f-prefix, so the literal text
    # "{platform}" was shown instead of the platform name.
    raise ValueError(f"platform {platform} not supported")

  concrete_function_flat_tf = _get_concrete_function_tf(function_flat_tf, args_flat_sig_tf)

  captured_inputs = []
  if concrete_function_flat_tf.captured_inputs:
    # The function uses either captured variables or tensors.
    msg = (
        "call_tf works best with a TensorFlow function that does not capture "
        "variables or tensors from the context. "
        "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. "
        f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
    logging.warning(msg)
    for inp in concrete_function_flat_tf.captured_inputs:
      if inp.dtype == tf.resource:  # A variable; lookup by handle
        inp_vars = [v for v in concrete_function_flat_tf.variables if inp is v.handle]
        assert len(inp_vars) == 1, f"Found {inp_vars}"
        captured_inputs.append(inp_vars[0])
      else:
        captured_inputs.append(inp)

  # The following use case happens when we call_tf a restored saved model that
  # includes parameters (hence functions closing over tf.Variable), and then
  # we jax2tf.convert it with native serialization, under tf.function (or
  # for saving to saved model). The `np.asarray(inp)` fails because it thinks
  # it is in TF graph mode. The `tf.init_scope()` lifts out of function-building
  # graph scopes, and allows us to read the values of the variables
  with tf.init_scope():
    captured_ops = tuple(
        mlir.flatten_ir_values(
            mlir.ir_constant(np.asarray(inp)) for inp in captured_inputs
        )
    )

  if call_tf_graph:
    with jax2tf_internal.inside_call_tf():
      return emit_tf_embedded_graph_custom_call(
          ctx,
          concrete_function_flat_tf,
          tuple(args_op) + captured_ops,
          has_side_effects,
          ordered,
          output_avals,
      )

  def convert_to_spec(x):
    if isinstance(x, tf.TensorSpec):
      return x
    else:
      return tf.TensorSpec.from_tensor(x)

  args_tf_flat = [convert_to_spec(a) for a in args_flat_sig_tf]

  with jax2tf_internal.inside_call_tf():
    try:
      func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(
          *args_tf_flat
      )(stage="hlo_serialized", platform_name=tf_platform)
    except Exception as e:
      # Fixed grammar: "can used" -> "can be used".
      msg = ("Error compiling TensorFlow function (see below for the caught exception)." +
             "\ncall_tf can be used " +
             "in a staged context (under jax.jit, lax.scan, etc.) only with " +
             "compilable functions with static output shapes.\n" +
             "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." +
             "\n\nCaught TensorFlow exception: " + str(e))
      raise ValueError(msg) from e

  # Convert the TF-produced HLO to StableHLO and splice it into our module.
  stablehlo = _jax.mlir.hlo_to_stablehlo(func_tf_hlo)
  submodule = ir.Module.parse(stablehlo)
  symtab = ir.SymbolTable(submodule.operation)
  main = cast(func_dialect.FuncOp, symtab["main"])
  callee_result_types = main.type.results
  fn = mlir.merge_mlir_modules(ctx.module_context.module,
                               f"call_tf_{function_flat_tf.name}",
                               submodule,
                               dst_symtab=ctx.module_context.symbol_table)
  call = func_dialect.CallOp(callee_result_types,
                             ir.FlatSymbolRefAttr.get(fn),
                             [*args_op, *captured_ops])
  flat_results = call.results

  if ordered:
    raise NotImplementedError(
        "ordered=True is not supported in the jitted context without"
        " `call_tf_graph=True`"
    )

  outputs = []
  for op, res_type in zip(flat_results, callee_result_types):
    if not res_type.has_static_shape:
      # Fixed grammar: "can used" -> "can be used".
      msg = (
          "Compiled TensorFlow function has dynamic output shape "
          + f"{res_type}. call_tf can be used in a staged context (under jax.jit,"
          " lax.scan, etc.) only with compilable functions with static"
          " output shapes. See"
          " https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf"
          " for a discussion."
      )
      raise ValueError(msg)

    res_dtype = _mlir_type_to_numpy_dtype(res_type.element_type)
    # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
    jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
    if res_dtype != jax_res_dtype:
      op = hlo.ConvertOp(
          mlir.aval_to_ir_type(core.ShapedArray(res_type.shape, jax_res_dtype)),
          op,
      ).result
    outputs.append(op)
  return outputs
|
||||
|
||||
|
||||
def _register_call_lowering(platform):
  # Bind the platform name into the generic lowering rule and register the
  # resulting rule for that platform.
  rule = functools.partial(_call_tf_lowering, platform=platform)
  mlir.register_lowering(call_tf_p, rule, platform=platform)


for platform in ("cpu", "cuda", "tpu"):
  _register_call_lowering(platform)
|
||||
|
||||
|
||||
def emit_tf_embedded_graph_custom_call(
    ctx: mlir.LoweringRuleContext,
    concrete_function_flat_tf,
    operands: Sequence[ir.Value],
    has_side_effects,
    ordered,
    output_avals,
):
  """Emits a custom call referencing a tf.Graph embedding of the TF function.

  All call_tf called function information is stored in tf.metadata.
  This includes:
  (1) The called function name: This name will be used by the runtime to execute
  the callback.
  (2) The called function index in the XLACallModule `function_list` attribute.
  """
  call_tf_concrete_function_list = jax2tf_internal.get_thread_local_state_call_tf_concrete_function_list()
  if call_tf_concrete_function_list is None:
    raise ValueError(
        "call_tf_graph=True only support exporting by jax2tf.convert currently."
    )
  # TODO(necula): It is dangerous to modify global state when lowering because
  # there are a number of lowering caches that only cache the StableHLO.
  # See call_tf_test.py:test_multi_platform_call_tf_graph.
  called_index = add_to_call_tf_concrete_function_list(
      concrete_function_flat_tf, call_tf_concrete_function_list)
  tf_backend_config = {
      "has_token_input_output": ir.BoolAttr.get(ordered),
      "called_index": mlir.i64_attr(called_index),
  }
  result_avals = ctx.avals_out if ctx.avals_out is not None else ()

  operands = list(operands)
  result_types = list(
      mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals])
  )
  if ordered:
    # Ordered calls thread an ordering token as the first operand and the
    # first result.
    operands.insert(0, ctx.tokens_in.get(call_tf_ordered_effect))
    result_types.insert(0, mlir.token_type())

  custom_call = hlo.CustomCallOp(
      result_types,
      operands,
      call_target_name=ir.StringAttr.get("tf.call_tf_function"),
      has_side_effect=ir.BoolAttr.get(has_side_effects),
      api_version=mlir.i32_attr(2),
      called_computations=ir.ArrayAttr.get([]),
      backend_config=ir.StringAttr.get(""),
  )
  # Store TF metadata in unregistered attribute
  custom_call.attributes["tf.backend_config"] = ir.DictAttr.get(
      tf_backend_config
  )

  results = list(custom_call.results)
  if ordered:
    # Strip the token result and hand it back to the lowering context.
    token = results.pop(0)
    ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: token}))

  return results
|
||||
|
||||
|
||||
def add_to_call_tf_concrete_function_list(concrete_tf_fn: Any, call_tf_concrete_function_list: list[Any]) -> int:
  """Returns the index of `concrete_tf_fn` in the list, appending it if absent.

  The list is mutated in place when the function is not already present.
  """
  # Scan for an existing entry (identity-or-equality, matching list.index).
  for idx, existing_fn in enumerate(call_tf_concrete_function_list):
    if existing_fn is concrete_tf_fn or existing_fn == concrete_tf_fn:
      return idx
  call_tf_concrete_function_list.append(concrete_tf_fn)
  return len(call_tf_concrete_function_list) - 1
|
||||
|
||||
# Register a roofline call so that users can use roofline on functions that
# contain call_tf. We register roofline in this file (instead of within the
# roofline module) to avoid having to import jax2tf in roofline.
roofline.register_standard_roofline(call_tf_p)
|
||||
@@ -0,0 +1,915 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Provides JAX and TensorFlow interoperation APIs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
import contextlib
|
||||
import math
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Union
|
||||
import warnings
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax import export
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import api_util
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import op_shardings
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.export import _export
|
||||
from jax._src.export import shape_poly
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
# These don't have public equivalents.
|
||||
from tensorflow.compiler.tf2xla.python import xla as tfxla
|
||||
from tensorflow.compiler.xla import xla_data_pb2
|
||||
try:
|
||||
from tensorflow.python.compiler.xla.experimental import xla_sharding
|
||||
except ModuleNotFoundError:
|
||||
# This can be removed when TF 2.10 support is no longer needed.
|
||||
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
|
||||
from tensorflow.python.eager import context as tf_context
|
||||
|
||||
# Alias for the name-stack type that source_info_util uses for source locations.
NameStack = source_info_util.NameStack
PolyShape = shape_poly.PolyShape  # TODO: deprecate
# Loose alias: stands for a JAX or TF dtype-like object.
DType = Any

# Re-exported so users can refer to jax2tf.DisabledSafetyCheck.
DisabledSafetyCheck = export.DisabledSafetyCheck


# Deliberately shadow the builtins with jax's safe_map/safe_zip variants
# (the convention used throughout jax internals).
map = util.safe_map
zip = util.safe_zip

# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
PrecisionType = int  # Enum xla_data.PrecisionConfig.Precision
|
||||
|
||||
def _is_tfval(v: TfVal) -> bool:
  """Returns True if `v` can be used as a value in a TF tracing context.

  tf.Tensor and tf.Variable qualify directly; for anything else we probe
  whether TF can convert it to a constant (covers Python scalars and
  numpy.ndarray).
  """
  if isinstance(v, (tf.Tensor, tf.Variable)):
    return True
  try:
    # Include all convertible types, even if not supported on accelerators.
    with tf.device("CPU"):
      tf.constant(v)
    return True
  except Exception:
    # Bug fix: the original used a bare `except:`, which also swallowed
    # BaseExceptions such as KeyboardInterrupt/SystemExit raised during the
    # conversion probe. Catch only Exception so those still propagate.
    return False
|
||||
|
||||
class _DefaultNativeSerialization:
  # Sentinel type: an instance of this class marks that the caller did not
  # pass a value for the deprecated `native_serialization`/`enable_xla`
  # parameters of `convert` (checked there with `is not`).
  pass
# The unique sentinel instance used as the default for those parameters.
DEFAULT_NATIVE_SERIALIZATION = _DefaultNativeSerialization()
|
||||
|
||||
|
||||
# In order to ensure that JAX picks up the proper user-frame for source
# locations we will register the TensorFlow source path as an internal
# path with source_info_util. The typical stack when a JAX primitive
# conversion happens is:
#    jax2tf.process_primitive  (top of stack)
#    jax tracing machinery ...
#    tf.custom_gradient machinery ...
#    jax2tf.converted_fun
#    tf function machinery ...
#    user code invokes the converted function on TF tensors
#
# We need to skip over not only JAX internal frames, but TF internal frames
# also.
# We register the TensorFlow source path lazily (see converted_fun_tf below),
# so importing jax2tf does not pay the cost when convert() is never called.
_has_registered_tf_source_path = False
|
||||
|
||||
class _ThreadLocalState(threading.local):
  """Per-thread mutable state used by the jax2tf conversion machinery."""

  def __init__(self):
    # Keep track if we are inside a call_tf. In that context we disable the
    # safety check that we are not inside JAX transformations.
    self.inside_call_tf = False

    # Maps dimension variables to TF expressions, for non-native lowering
    self.shape_env: Sequence[tuple[str, TfVal]] = ()

    # A list collecting all tf concrete_functions called by stablehlo.custom_call
    # This is used only by native serialization (unlike all the other
    # thread-local state). None means "not currently converting".
    self.call_tf_concrete_function_list: list[Any] | None = None

# Singleton: each thread observes its own independent copy of the fields above.
_thread_local_state = _ThreadLocalState()
|
||||
|
||||
@contextlib.contextmanager
def inside_call_tf():
  """Context manager that marks the current thread as being inside call_tf.

  Saves and restores the previous flag value so nesting behaves correctly.
  """
  saved_flag = _thread_local_state.inside_call_tf
  _thread_local_state.inside_call_tf = True
  try:
    yield
  finally:
    _thread_local_state.inside_call_tf = saved_flag
|
||||
|
||||
|
||||
def get_thread_local_state_call_tf_concrete_function_list() -> list[Any] | None:
  """Returns this thread's list of TF concrete functions seen by call_tf.

  The result is None when no native-serialization conversion is in progress.
  """
  return _thread_local_state.call_tf_concrete_function_list
|
||||
|
||||
|
||||
@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun_jax: Callable,
            *,
            polymorphic_shapes: str | PolyShape | None | Sequence[str | PolyShape | None] = None,
            polymorphic_constraints: Sequence[str] = (),
            with_gradient: bool = True,
            enable_xla: bool | _DefaultNativeSerialization = DEFAULT_NATIVE_SERIALIZATION,
            native_serialization: bool | _DefaultNativeSerialization = DEFAULT_NATIVE_SERIALIZATION,
            native_serialization_platforms: Sequence[str] | None = None,
            native_serialization_disabled_checks: Sequence[DisabledSafetyCheck] = (),
            ) -> Callable:
  """Allows calling a JAX function from a TensorFlow program.

  See
  [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md)
  for more details about usage and common problems.

  Args:
    fun_jax: target JAX function to be called. Its arguments and return value
      should be JAX arrays, or nested standard Python containers
      (tuple/list/dict) thereof (pytrees).
    polymorphic_shapes: Specifies input shapes to be treated polymorphically
      during lowering.

      .. warning:: The shape-polymorphic lowering is an experimental feature.
        It is meant to be sound, but it is known to reject some JAX programs
        that are shape polymorphic. The details of this feature can change.

      It should be `None` (all arguments are monomorphic), a single PolyShape
      or string (applies to all arguments), or a tuple/list of the same length
      as the function arguments. For each argument the shape specification
      should be `None` (monomorphic argument), or a Python object with the
      same pytree structure as the argument.
      See [how optional parameters are matched to
      arguments](https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).

      A shape specification for an array argument should be an object
      `PolyShape(dim0, dim1, ..., dimn)`
      where each `dim` is a dimension specification: a positive integer denoting
      a monomorphic dimension of the given size, or a string denoting a
      dimension variable assumed to range over non-zero dimension sizes, or
      the special placeholder string "_" denoting a monomorphic dimension
      whose size is given by the actual argument. As a shortcut, an Ellipsis
      suffix in the list of dimension specifications stands for a list of "_"
      placeholders.

      For convenience, a shape specification can also be given as a string
      representation, e.g.: "batch, ...", "batch, height, width, _", possibly
      with surrounding parentheses: "(batch, ...)".

      The lowering fails if it cannot ensure that it would produce the same
      sequence of TF ops for any non-zero values of the dimension variables.

      polymorphic_shapes are only supported for positional arguments; shape
      polymorphism is not supported for keyword arguments.

      See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
      for more details.

    polymorphic_constraints: a sequence of constraints on symbolic dimension
      expressions, of the form `e1 >= e2` or `e1 <= e2`.
      See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
    with_gradient: if set (default), add a tf.custom_gradient to the lowered
      function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
      TensorFlow AD is supported for the output TensorFlow function, and the
      value of the gradient will be JAX-accurate.
    native_serialization_platforms: Specifies the platform(s)
      for which to lower the code. Must be a tuple of
      strings, including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'.
      The default (`None`), specifies the JAX default
      backend on the machine where the lowering is done.
    native_serialization_disabled_checks: Disables the specified safety checks.
      See docstring of `DisabledSafetyCheck`.

  Returns:
    A version of `fun_jax` that expects TfVals as arguments (or
    tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
    only TensorFlow ops and thus can be called from a TensorFlow program.
  """
  # Both deprecated parameters are accepted (and ignored beyond a warning)
  # for backwards compatibility.
  if native_serialization is not DEFAULT_NATIVE_SERIALIZATION:
    warnings.warn(
        "The `native_serialization` parameter is deprecated and "
        "will be removed in a future version of JAX.",
        DeprecationWarning, stacklevel=2)
    del native_serialization
  if enable_xla is not DEFAULT_NATIVE_SERIALIZATION:
    warnings.warn(
        "The `enable_xla` parameter is deprecated and "
        "will be removed in a future version of JAX.",
        DeprecationWarning, stacklevel=2)
    del enable_xla

  if native_serialization_platforms:
    if (not isinstance(native_serialization_platforms, (list, tuple)) or
        not all(p in ["cpu", "cuda", "rocm", "tpu"]
                for p in native_serialization_platforms)):
      raise ValueError(
          "native_serialization_platforms must be a sequence "
          "containing a subset of {'cpu', 'cuda', 'rocm', 'tpu'}. "
          f"Got: {native_serialization_platforms}")
    native_serialization_platforms = tuple(native_serialization_platforms)

  api.check_callable(fun_jax)

  def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:

    # TODO: is there a better way to check if we are inside a transformation?
    if not core.trace_state_clean() and not _thread_local_state.inside_call_tf:
      # It is Ok to nest convert when we are inside a call_tf
      raise ValueError(
          "convert must be used outside all JAX transformations." +
          f"Trace state: {core.trace_ctx}")

    global _has_registered_tf_source_path
    if not _has_registered_tf_source_path:
      source_info_util.register_exclusion(os.path.dirname(tf.__file__))
      _has_registered_tf_source_path = True

    def jax_arg_spec_from_tf(a: TfVal) -> jax.ShapeDtypeStruct:
      # The shape and JAX dtype for a TF argument
      tf_arg_shape = np.shape(a)
      # Fix the shape for TF1
      tf_arg_shape = tuple(d.value
                           if isinstance(d, tf.compat.v1.Dimension) else d
                           for d in tf_arg_shape)
      _, a_jax_dtype = _tfval_to_tensor_jax_dtype(a)
      # We count on the fact that jax.ShapeDtypeStruct allows shapes that
      # contain None.
      return jax.ShapeDtypeStruct(tf_arg_shape, a_jax_dtype)

    args_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, args_tf)
    args_specs = export.symbolic_args_specs(
        args_jax_specs, polymorphic_shapes,
        constraints=polymorphic_constraints)
    # The polymorphic_shapes argument refers to positional arguments only.
    # We assume None for the kwargs.
    kwargs_jax_specs = tree_util.tree_map(jax_arg_spec_from_tf, kwargs_tf)
    kwargs_specs = export.symbolic_args_specs(
        kwargs_jax_specs, None)
    combined_args_tf = (args_tf, kwargs_tf)
    args_flat_tf: Sequence[TfVal]
    args_flat_tf, args_kwargs_tree = tree_util.tree_flatten(combined_args_tf)

    args_flat_tf = tuple(
        map(preprocess_arg_tf, range(len(args_flat_tf)), args_flat_tf))

    impl = NativeSerializationImpl(
        fun_jax,
        args_specs=args_specs, kwargs_specs=kwargs_specs,
        native_serialization_platforms=native_serialization_platforms,
        native_serialization_disabled_checks=native_serialization_disabled_checks)
    try:
      impl.before_conversion()

      outs_tree: tree_util.PyTreeDef | None = None
      if with_gradient:
        @tf.custom_gradient
        def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
          nonlocal outs_tree
          outs_tf, outs_avals, outs_tree = impl.run_fun_tf(args_flat_tf)
          return (tuple(outs_tf),
                  _make_custom_gradient_fn_tf(
                      fun_jax,
                      impl=impl,
                      with_gradient=with_gradient,
                      args_specs=args_specs, kwargs_specs=kwargs_specs,
                      args_tf=args_flat_tf,
                      outs_avals=outs_avals,
                      outs_tf=outs_tf))

        outs_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
      else:
        outs_tf, _, outs_tree = impl.run_fun_tf(args_flat_tf)
        message = ("The jax2tf-converted function does not support gradients. "
                   "Use `with_gradient` parameter to enable gradients")
        # We use PreventGradient, which is propagated through a SavedModel.
        outs_flat_tf = [
            tf.raw_ops.PreventGradient(input=o, message=message)
            for o in outs_tf
        ]
    finally:
      impl.after_conversion()

    assert outs_tree is not None
    outs_flat_tf = [tf.identity(x, "jax2tf_out") for x in outs_flat_tf]
    out_tf = tree_util.tree_unflatten(outs_tree, outs_flat_tf)
    return out_tf

  return converted_fun_tf
|
||||
|
||||
|
||||
class NativeSerializationImpl:
  """Lowers `fun_jax` via `jax.export` and runs it as an XlaCallModule TF op.

  `before_conversion` must be called before `run_fun_tf`, and
  `after_conversion` afterwards (see the try/finally in `convert`).
  """

  def __init__(self, fun_jax, *,
               args_specs, kwargs_specs,
               native_serialization_platforms: Sequence[str] | None,
               native_serialization_disabled_checks: Sequence[DisabledSafetyCheck]):
    # Saved so the VJP conversion in _make_custom_gradient_fn_tf can call
    # `convert` again with the same serialization options.
    self.convert_kwargs = dict(native_serialization_platforms=native_serialization_platforms,
                               native_serialization_disabled_checks=native_serialization_disabled_checks)
    if hasattr(fun_jax, "trace"):
      # If we have a pjit or pmap already we do not wrap with another, and we
      # allow shardings.
      fun_jit = fun_jax
    else:
      # We support convert(pjit(f_jax)) and convert(jit(f_jax)) but also
      # convert(f_jax), in which case a "jit" is implied. In that case we raise
      # an error if the lowered function contains non-replicated sharding annotations.
      fun_jit = jax.jit(fun_jax)
    self.fun_jax = fun_jit
    self.args_specs = args_specs
    self.kwargs_specs = kwargs_specs
    self.native_serialization_disabled_checks = native_serialization_disabled_checks
    self.native_serialization_platforms = native_serialization_platforms

  def before_conversion(self):
    # Install a fresh per-thread list for collecting call_tf concrete
    # functions; the previous list is restored in after_conversion.
    _prev_func_list = _thread_local_state.call_tf_concrete_function_list
    _thread_local_state.call_tf_concrete_function_list = []

    def _restore_context():
      _thread_local_state.call_tf_concrete_function_list = _prev_func_list

    self._restore_context = _restore_context
    # The export internals fill in the device assignment through this
    # one-element list (out-parameter pattern).
    _exported_device_assignment = [None]
    self.exported = _export._export_internal(
        self.fun_jax,
        platforms=self.native_serialization_platforms,
        disabled_checks=self.native_serialization_disabled_checks,
        _device_assignment_for_internal_jax2tf_use_only=_exported_device_assignment,
    )(*self.args_specs, **self.kwargs_specs)
    assert(_exported_device_assignment[0] is not None)
    self.device_assignment = _exported_device_assignment[0]

  def after_conversion(self):
    # Restores the thread-local concrete-function list saved above.
    self._restore_context()

  def run_fun_tf(self,
                 args_flat_tf: Sequence[TfVal]
                 ) -> tuple[Sequence[TfVal], Sequence[core.ShapedArray], tree_util.PyTreeDef]:
    # Returns (flat TF results, output avals, output pytree definition).
    results = _run_exported_as_tf(args_flat_tf, self.exported)
    return results, tuple(self.exported.out_avals), self.exported.out_tree

  def get_vjp_fun(self) -> tuple[Callable,
                                 Sequence[core.AbstractValue]]:
    # Builds the VJP function (and its input avals) from the exported
    # artifact; used by the tf.custom_gradient wrapper.
    return _export._get_vjp_fun(self.fun_jax,
                                in_tree=self.exported.in_tree,
                                in_avals=self.exported.in_avals,
                                has_named_shardings=self.exported._has_named_shardings,
                                in_named_shardings=self.exported._in_named_shardings,
                                out_named_shardings=self.exported._out_named_shardings,
                                in_shardings_hlo=self.exported.in_shardings_hlo,
                                out_avals=self.exported.out_avals,
                                out_shardings_hlo=self.exported.out_shardings_hlo,
                                device_assignment=self.device_assignment,
                                apply_jit=True)
|
||||
|
||||
|
||||
def dtype_of_val(val: TfVal) -> DType:
  """Computes the TensorFlow dtype using JAX's typing rules.

  If the value is a tf.Tensor, it starts with its dtype. If the value is a
  constant it uses JAX to infer its dtype. The resulting dtype follows the
  JAX type inference rules, and depends on the value of the
  JAX_ENABLE_X64 flag.

  See README.md for how 64-bit values are treated.
  """
  tensor_val, _ = _tfval_to_tensor_jax_dtype(val)
  return tensor_val.dtype
|
||||
|
||||
@partial(api_util.api_hook, tag="jax2tf_eval_polymorphic_shapes")
def eval_polymorphic_shape(fun_jax: Callable,
                           *,
                           polymorphic_shapes=None) -> Callable:
  """Evaluates the output shape in presence of shape polymorphism.

  This is done without lowering or executing the function, same as for
  `jax.eval_shape`.

  Args:
    fun_jax: target JAX function to be called. Its arguments and return value
      should be JAX arrays, or nested standard Python containers
      (tuple/list/dict) thereof (pytrees).
    polymorphic_shapes: Specifies input shapes to be treated polymorphically
      during shape evaluation. See discussion for `jax2tf.convert`.

      .. warning:: The shape-polymorphic lowering is an experimental feature.

  Returns: a function that takes `jax.ShapeDtypeStruct`s (or any values
    with `.shape` and `.dtype` attributes) corresponding to the inputs for
    `fun_jax`, and returns a tuple with:

      * the jax.ShapeDtypeStruct corresponding to the result, as for
       `jax.eval_shape`. The shape may contain symbolic dimension expressions.
      * the value that can be passed to `polymorphic_shapes` for a subsequent
       call to `jax2tf.eval_polymorphic_shape`, or `jax2tf.convert`.

  For example:

  >>> import jax
  >>> from jax.experimental import jax2tf
  >>> from jax import numpy as jnp
  >>>
  >>> f = lambda A, x: jnp.sin(jnp.dot(A, x))
  >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
  >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
  >>> out_spec, out_poly_shape = jax2tf.eval_polymorphic_shape(f, polymorphic_shapes=["a, b", "b, c"])(A, x)
  >>> print(out_spec.shape)
  ("a", "c")
  >>> print(out_poly_shape)
  (a, c)
  >>> res_spec, res_poly_shape = jax2tf.eval_polymorphic_shape(lambda x: x.T, polymorphic_shapes=[out_poly_shape])(out_spec)
  >>> print(res_poly_shape)
  (c, a)
  """
  def do_eval_polymorphic_shape(*args_specs) -> Any:
    # Make the input specs symbolic, run abstract evaluation only, then
    # render the (possibly symbolic) result shapes as strings.
    args_poly_specs = export.symbolic_args_specs(
        args_specs, polymorphic_shapes)
    res_poly_spec = jax.eval_shape(fun_jax, *args_poly_specs)
    # TODO(necula): For now we export the polymorphic shapes using `str`.
    res_polymorphic_shape = tree_util.tree_map(lambda r: str(r.shape), res_poly_spec)
    return res_poly_spec, res_polymorphic_shape

  return do_eval_polymorphic_shape
|
||||
|
||||
|
||||
# Internals
|
||||
|
||||
def preprocess_arg_tf(arg_idx: int,
                      arg_tf: TfVal) -> TfVal:
  """Pre-processes one flattened TF argument of the converted function.

  Args:
    arg_idx: position of the argument in the flattened argument list; used
      only to name the resulting tensor.
    arg_tf: the TF argument value.

  Returns:
    The pre-processed TF arg: cast to the JAX interpretation of its type and
    named `jax2tf_arg_{arg_idx}`. (The previous docstring claimed a tuple
    with the shape and dtype was returned; only the arg is.)

  Raises:
    TypeError: if `arg_tf` is not a NumPy array, scalar, tf.Variable, or
      tf.Tensor.
  """
  if not _is_tfval(arg_tf):
    msg = (f"Argument {arg_tf} of type {type(arg_tf)} of jax2tf.convert(f) should "
           "be NumPy array, scalar, tf.Variable, or tf.Tensor")
    raise TypeError(msg)

  # May cast the args_flat to JAX types, using JAX's interpretation
  # of types of constants.
  arg_tf, _ = _tfval_to_tensor_jax_dtype(arg_tf)
  # Name input tensors; do this after we have cast the arguments
  arg_tf = tf.identity(arg_tf, f"jax2tf_arg_{arg_idx}")
  return arg_tf
|
||||
|
||||
|
||||
def _make_custom_gradient_fn_tf(fun_jax,
                                *,
                                impl: NativeSerializationImpl,
                                with_gradient: bool,
                                args_specs, kwargs_specs,
                                args_tf: Sequence[TfVal],
                                outs_avals: Sequence[core.ShapedArray],
                                outs_tf: Sequence[TfVal]):
  """Prepares the TF function to be used with tf.custom_gradient.

  Args:
    fun_jax: the primal JAX function.
    impl: the serialization implementation details
    with_gradient: whether to include a tf.custom_gradient
    args_specs, kwargs_specs: the jax.ShapeDtypeArrays for the args and kwargs
    args_tf: the flattened TF arguments of the primal function
    outs_avals: the flattened output JAX abstract values of the primal function
    outs_tf: the flattened TF outputs of the primal function

  Returns:
    A gradient function suitable as the second element of the pair returned
    to tf.custom_gradient: it takes the flat output cotangents and returns
    the flat input cotangents.
  """
  def grad_fn_tf(*out_cts_flat_tf: TfVal,
                 variables=None):
    if variables:
      raise ValueError(
          "Unexpected variables used in forward pass. "
          "This should not happen for first-order differentiation. "
          f"{variables=}")

    # TODO: enable higher-order gradients
    with tf.name_scope("jax2tf_vjp"):
      def fix_out_ct(out_ct_tf, out_ct_aval: core.ShapedArray, out_tf: TfVal):
        # If the primal function has outputs of integer or bool types, and if we are
        # under a tf.function context, then TF will pass None in _out_cts_flat
        # in place of these values. We should change these to float0 or
        # else JAX gets unhappy. See issue #6975.
        if out_ct_tf is not None:
          return out_ct_tf
        assert core.primal_dtype_to_tangent_dtype(out_ct_aval.dtype) == dtypes.float0, f"{out_ct_tf=}"
        # Note that out_ct_aval.shape contains dimension variable from the
        # primal function scope. We use tf.zeros_like to make a 0 of the right shape.
        return tf.zeros_like(out_tf, dtype=_tf_np_dtype_for_float0)

      out_cts_fixed_flat_tf = tuple(map(fix_out_ct, out_cts_flat_tf, outs_avals, outs_tf))
      # The VJP takes the primal arguments followed by the output cotangents.
      vjp_args_flat_tf = tuple(args_tf) + out_cts_fixed_flat_tf

      fun_vjp_jax, vjp_in_avals = impl.get_vjp_fun()

      vjp_polymorphic_shapes = tuple(
          # pyrefly: ignore[missing-attribute]
          str(a.shape)  # Note: may be _DimExpr, not just DimVar
          for a in vjp_in_avals)
      # Convert the VJP function itself with the same serialization options,
      # so the gradient is computed by JAX-lowered TF ops.
      in_cts_flat = convert(
          fun_vjp_jax,
          with_gradient=with_gradient,
          polymorphic_shapes=vjp_polymorphic_shapes,
          **impl.convert_kwargs)(*vjp_args_flat_tf)

    # We do not need to fix the in_cts because the TF gradient machinery
    # will adjust the unconnected gradients and those for integer types.
    return in_cts_flat

  return grad_fn_tf
|
||||
|
||||
|
||||
def _run_exported_as_tf(args_flat_tf: Sequence[TfVal],
                        exported: export.Exported,
                        ) -> Sequence[TfVal]:
  """Runs the `exported` as an XlaCallModule TF op.

  Returns: the flattened tuple of results.
  """
  args_avals = exported.in_avals

  # TF values may be integer types for float0
  def _convert_value(val, aval):
    # Check the shape
    assert all(d_aval == d_val
               for d_aval, d_val in zip(aval.shape, val.shape)
               if core.is_constant_dim(d_aval)), (aval, val)
    conversion_dtype = _to_tf_dtype(aval.dtype)
    if conversion_dtype != aval.dtype:
      return tf.cast(val, conversion_dtype)
    else:
      return val

  args_flat_tf = tuple(map(_convert_value, args_flat_tf, args_avals))

  # Polymorphic dimensions become None in the TF output shapes.
  out_shapes_tf = tuple(
      tuple(d if core.is_constant_dim(d) else None
            for d in out_aval.shape)
      for out_aval in exported.out_avals)

  out_types = tuple(_to_tf_dtype(out_aval.dtype) for out_aval in exported.out_avals)

  # Drop arguments that the exported module does not actually use.
  kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx]

  version = exported.calling_convention_version

  try:
    get_max_supported_version = tfxla.call_module_maximum_supported_version
  except AttributeError:
    # Older TF releases do not expose this query; assume version 6.
    get_max_supported_version = None

  if get_max_supported_version:
    max_supported_version = get_max_supported_version()
  else:
    max_supported_version = 6

  if version > max_supported_version:
    raise NotImplementedError(
      "XlaCallModule from your TensorFlow installation supports up to "
      f"serialization version {max_supported_version} but the serialized "
      f"module needs version {version}. "
      "You should upgrade TensorFlow, e.g., to tf_nightly."
    )

  call_module_attrs: dict[str, Any] = dict(
      version=version,
      Tout=out_types,
      Sout=out_shapes_tf,
      function_list=[
          concrete_fn.function_def.signature.name
          for concrete_fn in _thread_local_state.call_tf_concrete_function_list
      ] if _thread_local_state.call_tf_concrete_function_list is not None else [],
      # We always set has_token_input_output because it requires real tokens
      # for versions less than 9 and is not used starting with version 9.
      has_token_input_output=False
  )

  call_module_attrs["platforms"] = tuple(p.upper() for p in exported.platforms)
  if version >= 6:
    call_module_attrs["disabled_checks"] = tuple(
        str(dc)
        for dc in exported.disabled_safety_checks)
  else:
    if version >= 3:
      if DisabledSafetyCheck.platform() in exported.disabled_safety_checks:
        call_module_attrs["platforms"] = ()  # No platform checking

  if version >= 10:
    call_module_attrs["use_shardy_partitioner"] = (
        config.use_shardy_partitioner.value
    )

  if logging.vlog_is_on(3):
    # We already logged the MLIR module when we exported it.
    logging.vlog(3, "XlaCallModule %s", str(call_module_attrs))

  call_module_attrs["module"] = exported.mlir_module_serialized

  # Apply the shardings on arguments and results for pjit. This is redundant
  # because the mlir_module_text will already contain the shardings, but it
  # makes it easier for tools like the TPU inference converter to see the
  # sharding without digging into the `module` attribute of the `XlaCallModule`
  # op, in the same way as it is done for the legacy jax2tf conversion.
  # Do not apply XlaSharding for REPLICATED, on inputs and outputs.
  # This is an agreed convention, and also improves usability under TF eager.
  # See b/255511660.
  kept_in_shardings = []
  for i in exported.module_kept_var_idx:
    if exported._has_named_shardings:
      in_sharding_hlo = _export.named_to_hlo_sharding(
          exported._in_named_shardings[i],
          exported.in_avals[i])
    else:
      in_sharding_hlo = exported.in_shardings_hlo[i]
    kept_in_shardings.append(in_sharding_hlo)
  args_flat_tf = tuple(
    map(partial(_shard_value,
                skip_replicated_sharding=tf.executing_eagerly()),
        kept_args_flat_tf, kept_in_shardings))
  res = tfxla.call_module(args_flat_tf, **call_module_attrs)
  # TODO(b/278940799): Replace the TF v1 API with public TF2 API.
  # Add the custom call tf.function into the default graph, so those functions
  # will be available during tf.SavedModel.save.
  if _thread_local_state.call_tf_concrete_function_list is not None:
    for concrete_fn in _thread_local_state.call_tf_concrete_function_list:
      tf.compat.v1.get_default_graph()._add_function_recursive(
          concrete_fn._inference_function
      )

  if exported._has_named_shardings:
    out_shardings_hlo = tuple(
        _export.named_to_hlo_sharding(s, a)
        for s, a in zip(exported._out_named_shardings, exported.out_avals))
  else:
    out_shardings_hlo = exported.out_shardings_hlo
  res = list(map(partial(_shard_value,
                         skip_replicated_sharding=tf.executing_eagerly()),
                 res, out_shardings_hlo))
  res = tuple(map(_convert_value, res, exported.out_avals))
  return res
|
||||
|
||||
|
||||
def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
  """Converts JAX avals from logical to physical, if relevant.

  JAX might have avals whose logical vs physical shape/dtype may
  differ, and only the physical view is expected to possibly
  relate to TF. TF impl rules should operate on the physical form.

  A JAX logical aval might even correspond, in principle, to several
  physical avals, but we don't support those here. Instead we assert
  there is only one and return it.
  """
  phys = core.physical_aval(aval)
  logical_rank = len(aval.shape)
  # The physical shape must extend the logical one (possibly with trailing
  # dimensions, e.g. for extended dtypes).
  assert (len(phys.shape) >= logical_rank and
          phys.shape[:logical_rank] == aval.shape), (phys, aval)
  return phys
|
||||
|
||||
def _jax_physical_dtype(dtype):
  """Returns the physical dtype corresponding to the (logical) `dtype`."""
  # A rank-0 aval is a fine stand-in: only the dtype matters here.
  scalar_aval = core.ShapedArray((), dtype)
  return _jax_physical_aval(scalar_aval).dtype
|
||||
|
||||
|
||||
def _aval_to_tf_shape(aval: core.ShapedArray) -> tuple[int | None, ...]:
  """Generate a TF shape, possibly containing None for polymorphic dimensions."""
  physical = _jax_physical_aval(aval)
  return tuple(None if export.is_symbolic_dim(d) else d
               for d in physical.shape)
|
||||
|
||||
# In the TF world, we represent float0 as zeros of this type.
# We pick bool because this is what JAX uses when it lowers float0 to HLO.
_tf_np_dtype_for_float0 = np.bool_
|
||||
|
||||
def _to_tf_dtype(jax_dtype):
  """Maps a JAX dtype to the TF dtype that represents it.

  Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
  due to float0 and 64-bit behavior.
  """
  try:
    physical = _jax_physical_dtype(jax_dtype)
  except TypeError:
    # `jax_dtype` isn't actually a valid jax dtype (e.g. it is
    # tf.float32), so there is no physical dtype anyway
    physical = jax_dtype
  if physical == dtypes.float0:
    physical = _tf_np_dtype_for_float0
  return tf.dtypes.as_dtype(physical)
|
||||
|
||||
|
||||
def _to_jax_dtype(tf_dtype):
  """Maps a TF dtype to the canonical JAX dtype it corresponds to.

  Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
  due to float0 and 64-bit behavior.
  """
  jax_dt = dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype)
  if jax_dt not in dtypes._jax_dtype_set:
    raise TypeError(f"dtype {jax_dt} is not a valid JAX array "
                    "type. Only arrays of numeric types are supported by JAX.")
  return jax_dt
|
||||
|
||||
|
||||
def _tfval_to_tensor_jax_dtype(val: TfVal,
                               jax_dtype: DType | None = None,
                               memoize_constants=False) -> tuple[TfVal, DType]:
  """Converts a scalar, ndarray, or tf.Tensor to a tf.Tensor with proper type.

  If `jax_dtype` is missing, uses JAX typing rules.
  See README.md for details regarding 64-bit values.

  Args:
    val: a scalar, ndarray, tf.Tensor, or tf.Variable
    jax_dtype: an optional dtype to use. If missing, uses JAX type inference
      rules for constants.
    memoize_constants: whether to memoize TF constants. We can't do this
      everywhere, we may be outside of a conversion scope.

  Returns:
    a tuple with a tf.Tensor with the type as needed by JAX, and the JAX type.
  """
  if isinstance(val, (tf.Tensor, tf.Variable)):
    jax_dtype = jax_dtype or _to_jax_dtype(val.dtype)  # Give JAX a chance to pick the type
    conversion_dtype = _to_tf_dtype(jax_dtype)
    if conversion_dtype != val.dtype:  # May need to cast for 64-bit values
      return tf.cast(val, conversion_dtype), jax_dtype
    else:
      return val, jax_dtype
  else:  # A constant
    jax_dtype = jax_dtype or core.typeof(val).dtype
    # TODO(document): We assume that the value of a constant does not
    # change through the scope of the function. But it may be an ndarray, ...
    # JAX has the same problem when generating HLO.
    const_key = (id(val), jax_dtype)
    # Since we use id(val) as a cache key, we have to make sure that we keep
    # the previous `val` alive. Otherwise, for a ndarray, it can get garbage
    # collected and reused for a different value, which would create correctness
    # issues. We keep the `val` alive by storing in the cache the pair
    # `(val, tf_val)`.
    # Only memoize non-scalars. JAX will lift all non-scalar constants as
    # Jaxpr consts, to the top level of the Jaxpr. This ensures that we see them
    # early, when entering the Jaxpr, so we create the tf.const early and its
    # scope is the entire Jaxpr.
    do_memoize = (memoize_constants and np.size(val) > 1 and _thread_local_state.constant_cache is not None)
    if do_memoize:
      # Cache hit returns the previously-built tf constant; the kept `val`
      # (first tuple element) exists only to pin id(val) — see note above.
      _, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None))
    else:
      tf_val = None
    if tf_val is None:
      conversion_dtype = _to_tf_dtype(jax_dtype)
      # The float0 type is not known to TF.
      if jax_dtype == dtypes.float0:
        val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
      # Extended (opaque) dtypes carry their own rule for producing a
      # physical constant representation.
      if hasattr(val, 'dtype') and dtypes.issubdtype(val.dtype, dtypes.extended):
        val = val.dtype._rules.physical_const(val)
      tf_val = tf.convert_to_tensor(val, dtype=conversion_dtype)
      if do_memoize:
        _thread_local_state.constant_cache[const_key] = (val, tf_val)
    return tf_val, jax_dtype
|
||||
|
||||
|
||||
# Per-dimension partition counts for a tensor, or None meaning "replicated".
PartitionsOrReplicated = Union[tuple[int, ...], None]
|
||||
|
||||
def split_to_logical_devices(tensor: TfVal,
                             partition_dimensions: PartitionsOrReplicated):
  """Like TPUMPStrategy.experimental_split_to_logical_devices.

  For jax2tf purposes we want to avoid needing to thread the `strategy` object
  through the generated computation. It seems that the original function needs
  the strategy object only for error checking, which we assume is done upstream
  by JAX.

  Args:
    tensor: Input tensor to annotate.
    partition_dimensions: A list of integers, with one integer per tensor
      dimension, specifying in how many parts the dimension should be split. The
      product of integers must equal the number of devices per replica.
      If None, the tensor is marked as replicated instead of split.

  Returns:
    an annotated tensor.
  """
  # TODO: this is only for sharded_jit. Either remove, or implement in terms
  # of _shard_values.
  if partition_dimensions is None:
    return xla_sharding.replicate(tensor, use_sharding_op=True)
  num_partition_splits = math.prod(partition_dimensions)
  # Devices 0..n-1 are tiled across the dimensions in row-major order.
  tile_assignment = np.arange(num_partition_splits).reshape(
      partition_dimensions)
  return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
|
||||
|
||||
|
||||
def _shard_value(val: TfVal,
                 sd: xla_client.HloSharding | None, *,
                 skip_replicated_sharding: bool) -> TfVal:
  """Apply sharding to a TfVal.

  Args:
    val: the TF value to annotate with a sharding op.
    sd: the HloSharding to apply, or None to leave `val` unsharded.
    skip_replicated_sharding: if True, fully-replicated shardings are not
      materialized as sharding ops and `val` is returned unchanged.

  Returns:
    `val`, possibly wrapped in an XLA sharding annotation op.

  Raises:
    ValueError: if called while TF is executing eagerly, since sharding
      annotations only make sense inside a `tf.function`.
  """
  if sd is None:
    return val

  sharding_proto = sd.to_proto()
  if (skip_replicated_sharding and
      op_shardings.is_hlo_sharding_replicated(sd)):
    return val

  # Tensorflow heavily relies on tile_assignment_devices proto fields specific
  # to V1 sharding format, falling back to this format.
  if (
      not sharding_proto.tile_assignment_devices
      and sharding_proto.iota_reshape_dims
  ):
    # Expand the compact V2 iota representation into the explicit V1 device
    # list: iota, reshaped and transposed per the proto, then flattened.
    tad = list(
        np.arange(math.prod(sharding_proto.tile_assignment_dimensions))
        .reshape(sharding_proto.iota_reshape_dims)
        .transpose(sharding_proto.iota_transpose_perm)
        .flat
    )
  else:
    tad = sharding_proto.tile_assignment_devices

  # To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
  xla_sharding_v1_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding(
      type=int(sharding_proto.type),
      tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
      tile_assignment_devices=tad,
      replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim,
      last_tile_dims=sharding_proto.last_tile_dims,
  )
  # Shardy requires V2 sharding format.
  # NOTE(review): unlike the V1 proto above, this copies the raw
  # tile_assignment_devices (not the iota-expanded `tad`) and keeps the iota
  # fields — presumably intentional for V2, but worth confirming.
  if config.use_shardy_partitioner.value:
    xla_sharding_v2_proto: xla_data_pb2.OpSharding = xla_data_pb2.OpSharding(
        type=int(sharding_proto.type),
        tile_assignment_dimensions=sharding_proto.tile_assignment_dimensions,
        tile_assignment_devices=sharding_proto.tile_assignment_devices,
        iota_reshape_dims=sharding_proto.iota_reshape_dims,
        iota_transpose_perm=sharding_proto.iota_transpose_perm,
        replicate_on_last_tile_dim=sharding_proto.replicate_on_last_tile_dim,
        last_tile_dims=sharding_proto.last_tile_dims,
    )
  else:
    xla_sharding_v2_proto = None
  if tf_context.executing_eagerly():
    raise ValueError(
        "A jit function with sharded arguments or results must be used under a `tf.function` context. "
        "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion")

  # Only the major.minor components of the TF version are relevant here.
  tf_version = tuple(int(v) for v in tf.__version__.split(".")[:2])
  # apply_to_tensor comes from a tensorflow package, check the tensorflow
  # version to make sure that it has the sharding_v2_proto parameter.
  if tf_version < (2, 20):
    return xla_sharding.Sharding(proto=xla_sharding_v1_proto).apply_to_tensor(
        val, use_sharding_op=True
    )
  return xla_sharding.Sharding(proto=xla_sharding_v1_proto).apply_to_tensor(
      val, use_sharding_op=True, sharding_v2_proto=xla_sharding_v2_proto
  )
|
||||
|
||||
|
||||
def _register_checkpoint_pytrees():
  """Registers TF custom container types as pytrees."""
  # Assigning containers to a tf.Module triggers TensorFlow's checkpointing
  # infrastructure, which automagically swaps in its own wrapper types;
  # we capture those wrapper types here.
  module = tf.Module()
  module.a = (tf.Module(), tf.Module())
  module.b = [tf.Module(), tf.Module()]
  module.c = {"a": tf.Module()}
  wrapped_tuple = type(module.a)
  wrapped_list = type(module.b)
  wrapped_dict = type(module.c)

  # TF AutoTrackable swaps container types out for wrappers.
  assert wrapped_tuple is not tuple
  assert wrapped_list is not list
  assert wrapped_dict is not dict

  register = jax.tree_util.register_pytree_node
  register(wrapped_tuple,
           lambda xs: (tuple(xs), None),
           lambda _, xs: tuple(xs))
  register(wrapped_list,
           lambda xs: (tuple(xs), None),
           lambda _, xs: list(xs))
  register(wrapped_dict,
           lambda s: (tuple(s.values()), tuple(s.keys())),
           lambda keys, vals: wrapped_dict(zip(keys, vals)))


_register_checkpoint_pytrees()
|
||||
Reference in New Issue
Block a user