786 lines
32 KiB
Python
786 lines
32 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Sequence
|
|
import dataclasses
|
|
from functools import partial
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from jax._src import config
|
|
from jax._src import core
|
|
from jax._src.core import typeof
|
|
from jax._src import source_info_util
|
|
from jax._src import linear_util as lu
|
|
from jax._src.partition_spec import PartitionSpec as P
|
|
from jax._src import mesh as mesh_lib
|
|
from jax._src.ad_util import Zero, SymbolicZero, add_jaxvals, add_jaxvals_p
|
|
from jax._src.core import Trace, Tracer, TraceTag
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax._src.tree_util import (tree_unflatten, tree_flatten, PyTreeDef)
|
|
from jax._src.typing import Array
|
|
from jax._src.util import (unzip2, safe_map, safe_zip, split_list,
|
|
canonicalize_axis, moveaxis, memoize,
|
|
weakref_lru_cache, tuple_insert)
|
|
|
|
# Within this module `map`/`zip` refer to the `safe_*` variants from
# jax._src.util; the builtins stay reachable as `unsafe_map`/`unsafe_zip`.
unsafe_map, unsafe_zip = map, zip
map, zip = safe_map, safe_zip
|
|
|
|
### vmappable typeclass
|
|
|
|
# Loosely-typed aliases that document the vmappable protocol.
Vmappable = Elt = MapSpec = AxisSize = MeshAxis = Any
GetIdx = Callable[[], Tracer]  # TODO(mattjj): revise this laziness
# Signatures of the handlers stored by register_vmappable.
ToEltHandler = Callable[[Callable, GetIdx, Vmappable, MapSpec], Elt]
FromEltHandler = Callable[[Callable, AxisSize, Elt, MapSpec], Vmappable]
MakeIotaHandler = Callable[[AxisSize], Array]
|
|
|
|
def to_elt(trace: BatchTrace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
  """Wrap input leaf ``x`` for use inside the batched function.

  A handler registered in ``to_elt_handlers`` for ``type(x)`` takes precedence
  and receives a recursive ``to_elt`` continuation. Otherwise:
    * ``spec is None``: ``x`` is returned unchanged (unbatched);
    * int ``spec``: canonicalized against ``x``'s rank, and ``x`` is wrapped
      in a BatchTracer batched along that axis;
    * hijax-typed ``x``: wrapped in a BatchTracer with ``spec`` as given.
  """
  from jax._src import hijax  # pytype: disable=import-error

  custom = to_elt_handlers.get(type(x))
  if custom:
    recurse = partial(to_elt, trace, get_idx)
    return custom(recurse, get_idx, x, spec)
  if spec is None or isinstance(spec, int):
    if spec is None:
      return x
    axis = canonicalize_axis(spec, len(np.shape(x)))
    return BatchTracer(trace, x, axis, source_info_util.current())
  if isinstance(typeof(x), hijax.HiType):
    # TODO check possible errors
    return BatchTracer(trace, x, spec, source_info_util.current())
  assert False, f'Unexpected type in ELT? {type(x)}'
|
|
|
|
|
|
# Per-data-type overrides consulted first by `to_elt`; filled by register_vmappable.
to_elt_handlers: dict[type, ToEltHandler] = dict()
|
|
|
|
def from_elt(trace: BatchTrace, axis_size: AxisSize, mesh_axis: MeshAxis,
             sum_match: bool, i: int, x: Elt, spec: MapSpec) -> tuple[Vmappable, MapSpec]:
  """Pull output leaf ``x`` out of ``trace``, moving its batch dim to ``spec``.

  Args:
    trace: the BatchTrace the batched function ran under.
    axis_size: size of the mapped axis (forwarded to registered handlers).
    mesh_axis: explicit mesh axis (forwarded through the recursive continuation).
    sum_match: passed to ``matchaxis``; allows a mapped value to be summed when
      the destination is unmapped.
    i: index of this output leaf, used for error reporting.
    x: the output leaf.
    spec: destination batch dim / spec for this output.

  Returns:
    ``(value, out_spec)``. ``out_spec`` is the value's source batch dim when
    ``spec`` is ``infer``, and ``spec`` otherwise. Registered handlers return
    ``spec`` unchanged.

  Raises:
    SpecMatchError: carrying ``i``, the source batch dim and ``spec`` when
      ``matchaxis`` cannot reconcile them.
  """
  handler = from_elt_handlers.get(type(x))
  if handler:
    def _cont(axis_size, elt, axis):
      return from_elt(trace, axis_size, mesh_axis, sum_match, i, elt, axis)[0]
    return handler(_cont, axis_size, x, spec), spec
  val, bdim = trace.to_batch_info(x)
  bdim_inferred = bdim if spec is infer else spec
  try:
    return matchaxis(trace.axis_data, bdim, spec, val,
                     sum_match=sum_match), bdim_inferred
  except SpecMatchError:
    # Report the batch dim matchaxis actually saw. `x` need not be a
    # BatchTracer of this trace: to_batch_info maps such values to
    # not_mapped. Reading `x.batch_dim` would then raise AttributeError
    # (plain values) or report another trace's dim.
    raise SpecMatchError(i, bdim, spec) from None
|
|
# Per-data-type overrides consulted first by `from_elt`; filled by register_vmappable.
from_elt_handlers: dict[type, FromEltHandler] = dict()
|
|
|
|
def make_iota(axis_size: AxisSize) -> Array:
  """Return an iota over ``axis_size``.

  A handler registered for ``type(axis_size)`` is used if present. Otherwise
  the result is ``lax.iota('int32', int(axis_size))``.
  """
  # Callers of this utility, via batch() or vtile(), must be in a context
  # where lax is importable.
  from jax import lax  # pytype: disable=import-error
  custom = make_iota_handlers.get(type(axis_size))
  if not custom:
    return lax.iota('int32', int(axis_size))
  return custom(axis_size)
|
|
# Iota constructors keyed by axis-size type; filled by register_vmappable.
make_iota_handlers: dict[type, MakeIotaHandler] = dict()
|
|
|
|
def register_vmappable(data_type: type, spec_type: type, axis_size_type: type,
                       to_elt: Callable, from_elt: Callable,
                       make_iota: Callable | None):
  """Register a custom vmappable ``data_type``.

  Records ``(spec_type, axis_size_type)`` in ``vmappables``, adds ``spec_type``
  to ``spec_types``, and installs the ``to_elt``/``from_elt`` handlers for
  ``data_type``. ``make_iota`` (if truthy) is installed for ``axis_size_type``.
  """
  vmappables[data_type] = spec_type, axis_size_type
  spec_types.add(spec_type)
  to_elt_handlers[data_type] = to_elt
  from_elt_handlers[data_type] = from_elt
  if not make_iota:
    return
  make_iota_handlers[axis_size_type] = make_iota
|
|
# Registered data type -> (spec type, axis-size type).
vmappables: dict[type, tuple[type, type]] = dict()
# Every spec type that has been registered.
spec_types: set[type] = set()
|
|
|
|
def unregister_vmappable(data_type: type) -> None:
  """No-op: registrations made by register_vmappable are left in place."""
  # this used to do cleanup, but it was dumb
  del data_type
|
|
|
|
def is_vmappable(x: Any) -> bool:
  """Whether ``type(x)`` itself (not a base class) was registered as vmappable."""
  return type(x) in vmappables.keys()
|
|
|
|
@lu.transformation_with_aux2
def flatten_fun_for_vmap(f: Callable,
                         store: lu.Store, in_tree: PyTreeDef, *args_flat):
  """Call ``f`` on the pytree ``in_tree`` rebuilt from ``args_flat``.

  The result is flattened with registered vmappables treated as leaves. The
  output treedef is stored as the aux value.
  """
  call_args, call_kwargs = tree_unflatten(in_tree, args_flat)
  result = f(*call_args, **call_kwargs)
  leaves, out_tree = tree_flatten(result, is_leaf=is_vmappable)
  store.store(out_tree)
  return leaves
|
|
|
|
|
|
### tracer
|
|
|
|
# TODO(mattjj): use a special sentinel type rather than None
|
|
# Sentinel batch dim meaning "this value has no batch axis".
not_mapped = None
NotMapped = type(not_mapped)
|
|
|
|
|
|
class BatchTracer(Tracer['BatchTrace']):
  """Tracer for a value batched along ``batch_dim`` under a BatchTrace.

  ``val`` is the underlying value, and ``batch_dim`` is the position of its
  batch axis. ``batch_dim`` is ``not_mapped`` (None) when the value is not
  batched. The tracer's aval is derived from ``core.typeof(val)`` in
  ``__init__``.
  """
  __slots__ = ['val', 'batch_dim', 'source_info']

  _trace: BatchTrace

  def __init__(self, trace: BatchTrace, val, batch_dim: NotMapped | int,
               source_info: source_info_util.SourceInfo | None = None):
    """Create the tracer and compute its aval from ``val`` and ``batch_dim``.

    Raises:
      Exception: if ``batch_dim`` is not ``not_mapped``, not an int, and
        ``val``'s type is not a hijax ``HiType``.
    """
    from jax._src import hijax  # pytype: disable=import-error

    aval = core.typeof(val)
    if config.enable_checks.value:
      # assert type(batch_dim) in (NotMapped, int)
      if type(batch_dim) is int:
        assert 0 <= batch_dim < len(aval.shape)

    # With an SPMD axis name and VMA checking enabled, the SPMD axis names are
    # removed from the aval's `varying` set.
    if trace.axis_data.spmd_name is not None:
      if config._check_vma.value:
        mat = aval.mat.update(
            varying=aval.mat.varying - frozenset(
                trace.axis_data.spmd_name))
        aval = aval.update(manual_axis_type=mat)
    if batch_dim is not_mapped:
      aval = aval
    elif type(batch_dim) is int:
      # NOTE(review): presumably core.mapped_aval drops axis `batch_dim` to give
      # the per-example type — confirm against core.
      aval = core.mapped_aval(aval.shape[batch_dim], batch_dim, aval)
    elif isinstance(aval, hijax.HiType):
      # hijax types compute their own type via `dec_rank`.
      # pyrefly: ignore[bad-argument-type] # pyrefly#2499
      aval = aval.dec_rank(trace.axis_data.size, batch_dim)
    else:
      raise Exception("batch dim should be int or `not_mapped`")

    super().__init__(trace, aval)
    self.val = val
    self.batch_dim = batch_dim
    self.source_info = source_info

  def _short_repr(self):
    return f"VmapTracer(aval={self.aval}, batched={core.typeof(self.val)})"

  def full_lower(self):
    # An unbatched tracer can be replaced by its (lowered) underlying value.
    if self.batch_dim is not_mapped:
      return core.full_lower(self.val)
    else:
      return self

  def _origin_msg(self):
    # Extra debugging text naming where the tracer was created, if known.
    if self.source_info is None:
      return ""
    return (f"\nThis BatchTracer with object id {id(self)} was created on line:"
            f"\n  {source_info_util.summarize(self.source_info)}")

  def _contents(self):
    return [('val', self.val), ('batch_dim', self.batch_dim)]

  def get_referent(self):
    # For ordinary (None/int) batch dims, defer to the underlying value;
    # other batch-dim kinds (e.g. hijax specs) refer to the tracer itself.
    if self.batch_dim is None or type(self.batch_dim) is int:
      return core.get_referent(self.val)
    else:
      return self
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class AxisData:
|
|
name : Any
|
|
size : Any
|
|
# Only one of spmd_axis_name and explicit_mesh_axis is set.
|
|
spmd_name : Any
|
|
# short for private `_explicit_mesh_axis`. The public property is called
|
|
# `.explicit_mesh_axis`
|
|
_ema: tuple[Any, ...] | None
|
|
|
|
@property
|
|
def explicit_mesh_axis(self):
|
|
assert self._ema is None or isinstance(self._ema, tuple)
|
|
if self._ema is None:
|
|
return None
|
|
cur_mesh = mesh_lib.get_abstract_mesh()
|
|
if cur_mesh.empty:
|
|
return self._ema
|
|
ema0_type = cur_mesh._name_to_type[self._ema[0]]
|
|
assert all(cur_mesh._name_to_type[e] == ema0_type for e in self._ema)
|
|
if ema0_type != mesh_lib.AxisType.Explicit:
|
|
return None
|
|
return self._ema
|
|
|
|
def __repr__(self):
|
|
return (f'AxisData(name={self.name}, size={self.size},'
|
|
f' spmd_name={self.spmd_name},'
|
|
f' explicit_mesh_axis={self.explicit_mesh_axis})')
|
|
|
|
__str__ = __repr__
|
|
|
|
|
|
def get_sharding_for_vmap(axis_data, orig_sharding, axis):
  """Return ``orig_sharding`` with the explicit mesh axis inserted at ``axis`` of its spec."""
  mesh_axis = axis_data.explicit_mesh_axis
  spec = orig_sharding.spec
  partitions = tuple_insert(spec, axis, mesh_axis)
  return orig_sharding.update(spec=spec.update(partitions=partitions))
|
|
|
|
|
|
class BatchTrace(Trace):
  """Trace that applies batching rules to operations on BatchTracers.

  Only tracers whose trace carries this trace's ``tag`` are treated as batched
  (see ``to_batch_info``). Everything else is treated as ``not_mapped``. Batched
  work is re-bound on ``parent_trace``.
  """

  def __init__(self, parent_trace, tag, axis_data):
    super().__init__()
    # Trace that batched operations are bound on.
    self.parent_trace = parent_trace
    assert isinstance(axis_data, AxisData)
    self.axis_data = axis_data
    # Compared by identity in to_batch_info to recognise this trace's tracers.
    self.tag = tag
    self.requires_low = False

  def to_batch_info(self, val):
    """Return ``(underlying value, batch dim)``; foreign values are ``not_mapped``."""
    if isinstance(val, BatchTracer) and val._trace.tag is self.tag:
      return val.val, val.batch_dim
    else:
      return val, not_mapped

  def cur_qdd(self, x):
    # Evaluate core.cur_qdd on the underlying value under the parent trace.
    val, _ = self.to_batch_info(x)
    with core.set_current_trace(self.parent_trace):
      return core.cur_qdd(val)

  def process_primitive(self, p, tracers, params, /):
    """Bind primitive ``p`` on (possibly) batched inputs.

    Uses the rule in ``fancy_primitive_batchers`` when one exists. Without a
    rule, ``p`` is bound directly on the parent trace if no input is batched.

    Raises:
      NotImplementedError: if some input is batched and ``p`` has no rule.
    """
    vals_in, dims_in = unzip2(map(self.to_batch_info, tracers))
    args_not_mapped = all(bdim is not_mapped for bdim in dims_in)
    if p in fancy_primitive_batchers:
      # TODO(yashkatariya): Remove remove_explicit_mesh_axis_names when vmap
      # mesh ctx is correctly set.
      with (core.set_current_trace(self.parent_trace),
            core.remove_explicit_mesh_axis_names(self.axis_data.explicit_mesh_axis)):
        val_out, dim_out = fancy_primitive_batchers[p](
            self.axis_data, vals_in, dims_in, **params)
      src = source_info_util.current()
      # Only outputs that carry a batch dim are re-wrapped as tracers.
      if p.multiple_results:
        return [BatchTracer(self, x, d, src) if d is not not_mapped else x
                for x, d in zip(val_out, dim_out)]
      else:
        return (BatchTracer(self, val_out, dim_out, src)
                if dim_out is not not_mapped else val_out)
    elif args_not_mapped:  # Not all primitives have batching rules defined
      avals = tuple(core.typeof(x) for x in vals_in)
      return p.bind_with_trace(self.parent_trace, tuple(vals_in), avals,
                               dict(params))
    else:
      raise NotImplementedError(f"Batching rule for '{p}' not implemented")

  def process_call(self, call_primitive, f, tracers, params, /):
    """Bind ``call_primitive`` on the parent trace with ``f`` batched inside it."""
    assert call_primitive.multiple_results
    params = dict(params, name=params.get('name', f.__name__))
    vals, dims = unzip2(map(self.to_batch_info, tracers))
    f_, dims_out = batch_subtrace(f, self.tag, self.axis_data, tuple(dims))

    with core.set_current_trace(self.parent_trace):
      vals_out = call_primitive.bind(*vals, subfuns=(f_,), **params)
    src = source_info_util.current()
    # dims_out is the aux output of batch_subtrace, available after the bind.
    return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out())]

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, /, *, symbolic_zeros):
    """Batch both ``fun`` and its custom ``jvp`` and bind ``prim`` on the parent trace."""
    in_vals, in_dims = unzip2(map(self.to_batch_info, tracers))
    fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims)
    jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.tag, self.axis_data, in_dims)
    avals = tuple(core.typeof(x) for x in in_vals)
    out_vals = prim.bind_with_trace(self.parent_trace, tuple(in_vals), avals,
                                    dict(subfuns=(fun, jvp), symbolic_zeros=symbolic_zeros))
    # Whichever of fun/jvp actually ran populated its aux; merge_linear_aux
    # returns that one.
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, /, *, out_trees,
                              symbolic_zeros):
    """Batch ``fun``, ``fwd`` and ``bwd`` of a custom_vjp and bind ``prim`` on the parent trace."""
    in_vals, in_dims = unzip2(map(self.to_batch_info, tracers))
    # fwd takes two flat args per input; the second of each pair is unbatched.
    fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]]

    fun, out_dims1 = batch_subtrace(fun, self.tag, self.axis_data, in_dims)
    fwd, out_dims2 = batch_subtrace(fwd, self.tag, self.axis_data, fwd_in_dims)

    def bwd_in_dims():
      # For each residual: input_fwds[k] is None when the residual was produced
      # by fwd (take the next dim from fwd's out dims). Otherwise it is input
      # number input_fwds[k] forwarded as-is, so reuse that input's dim.
      _, _, input_fwds = out_trees()
      pruned_dims = iter(out_dims2())
      full_dims = [next(pruned_dims) if f is None else in_dims[f] for f in input_fwds]
      return [*full_dims, *pruned_dims]

    bwd = batch_custom_vjp_bwd(bwd, self.tag, self.axis_data, bwd_in_dims, in_dims)
    avals = tuple(core.typeof(x) for x in in_vals)
    out_vals = prim.bind_with_trace(self.parent_trace,
                                    tuple(in_vals), avals,
                                    dict(subfuns=(fun, fwd, bwd), out_trees=out_trees, symbolic_zeros=symbolic_zeros))
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      # The dims came from fwd, whose outputs start with the non-forwarded
      # residuals; drop those to keep only the primal output dims.
      _, res_tree, input_fwds = out_trees()
      num_res = res_tree.num_leaves - sum(f is not None for f in input_fwds)
      _, out_dims = split_list(out_dims, [num_res])
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]
|
|
|
|
### API for batching callables with vmappable inputs and outputs
|
|
|
|
def batch(fun: lu.WrappedFun, axis_data,
          in_dims, out_dim_dests, sum_match=False) -> lu.WrappedFun:
  """Return ``fun`` wrapped to run batched over ``axis_data``.

  Inputs are batched along ``in_dims``, and outputs are moved to
  ``out_dim_dests``. Either may be given as a thunk.
  """
  # _batch_inner and _batch_outer are separate wrappers for the leak checker.
  inner = _batch_inner(fun, axis_data, out_dim_dests, sum_match)
  return _batch_outer(inner, axis_data, in_dims)
|
|
|
|
@lu.transformation2
def _batch_outer(f, axis_data, in_dims, *in_vals):
  """Create a fresh trace tag, run ``f`` under the 'vmap' name stack, and leak-check its trace."""
  tag = TraceTag()
  with source_info_util.transform_name_stack('vmap'):
    out_vals, out_dim_srcs, batch_trace = f(tag, in_dims, *in_vals)
  # Drop this frame's reference to the trace inside ensure_no_leaks.
  # NOTE(review): presumably any surviving reference is reported as a leak.
  with core.ensure_no_leaks(batch_trace):
    del batch_trace
  return out_vals, out_dim_srcs
|
|
|
|
@lu.transformation2
def _batch_inner(f: Callable, axis_data, out_dim_dests, sum_match, tag, in_dims, *in_vals):
  """Run ``f`` under a new BatchTrace and match its outputs to ``out_dim_dests``.

  Returns ``(out_vals, out_dim_srcs, trace)``. The trace is returned so the
  caller (``_batch_outer``) can leak-check it.
  """
  in_dims = in_dims() if callable(in_dims) else in_dims
  with core.take_current_trace() as parent_trace:
    trace = BatchTrace(parent_trace, tag, axis_data)
    # Memoized thunk for a batched iota; handed to to_elt handlers as get_idx.
    idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0,
                                      source_info_util.current()))
    # Inputs are wrapped while the parent trace is current.
    with core.set_current_trace(parent_trace):
      in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)  # pyrefly: ignore[bad-argument-type] # pyrefly#2385
    # TODO(yashkatariya): Instead of `add_explicit_mesh_axis_names`, we should
    # create a new mesh by removing the axis_data.explicit_mesh_axis from it.
    with (core.set_current_trace(trace),
          core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
          core.add_spmd_axis_names(axis_data.spmd_name),
          core.add_explicit_mesh_axis_names(axis_data.explicit_mesh_axis)):
      outs = f(*in_tracers)
      # Destinations may be a thunk that is only valid after `f` has run.
      out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
      out_vals, out_dim_srcs = unzip2(
          map(partial(from_elt, trace, axis_data.size, axis_data.explicit_mesh_axis, sum_match),
              range(len(outs)), outs, out_dim_dests))  # pyrefly: ignore[bad-argument-type] # pyrefly#2385
  return out_vals, out_dim_srcs, trace
|
|
|
|
### API for batching functions with jaxpr type inputs and outputs
|
|
|
|
# Returns `out_dims` as a second tuple component.
|
|
# The result of `f` should be a `FlatTree`.
|
|
def batch_subtrace_2(f, tag, axis_data, in_dims, in_vals):
  """Run ``f`` on ``in_vals`` batched along ``in_dims`` under a BatchTrace with ``tag``.

  Unlike ``batch_subtrace`` this is not a ``lu`` transformation. ``f``'s result
  must support ``.map(...).unzip2()`` (a FlatTree).

  Returns:
    ``(out_vals, out_dims)`` with ``out_dims`` as a list.
  """
  with core.take_current_trace() as parent_trace:
    trace = BatchTrace(parent_trace, tag, axis_data)
    with core.set_current_trace(trace):
      in_dims = in_dims() if callable(in_dims) else in_dims
      # Unbatched inputs (dim None) are passed through unwrapped.
      in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
                    if dim is not None else x for x, dim in zip(in_vals, in_dims)]  # pyrefly: ignore[bad-argument-type] # pyrefly#2385
      outs = f(*in_tracers)
    out_vals, out_dims = outs.map(trace.to_batch_info).unzip2()
  return out_vals, list(out_dims)
|
|
|
|
@lu.transformation_with_aux2
def batch_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
  """Run ``f`` on ``in_vals`` batched along ``in_dims`` under a BatchTrace with ``tag``.

  Returns the unwrapped output values and stores their batch dims as aux.
  ``in_dims`` may be a thunk.
  """
  with core.take_current_trace() as parent_trace:
    trace = BatchTrace(parent_trace, tag, axis_data)
    with core.set_current_trace(trace):
      in_dims = in_dims() if callable(in_dims) else in_dims
      # Unbatched inputs (dim None) are passed through unwrapped.
      in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
                    if dim is not None else x for x, dim in zip(in_vals, in_dims)]  # pyrefly: ignore[bad-argument-type] # pyrefly#2385
      outs = f(*in_tracers)
    out_vals, out_dims = unzip2(map(trace.to_batch_info, outs))
  store.store(out_dims)
  return out_vals
|
|
|
|
### API for batching jaxprs
|
|
|
|
def batch_jaxpr2(
    closed_jaxpr: core.ClosedJaxpr,
    axis_data,
    in_axes: tuple[int | NotMapped, ...],
  ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
  """Batch ``closed_jaxpr`` along ``in_axes``; returns the jaxpr and its output axes."""
  # The worker is weakref_lru_cache'd, so normalise in_axes to a tuple key.
  axes_key = tuple(in_axes)
  return _batch_jaxpr2(closed_jaxpr, axis_data, axes_key)
|
|
|
|
@weakref_lru_cache
def _batch_jaxpr2(
    closed_jaxpr: core.ClosedJaxpr,
    axis_data,
    in_axes: tuple[int | NotMapped, ...],
  ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]:
  """Trace a batched version of ``closed_jaxpr``; output axes are whatever arises."""
  fun = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
                     debug_info=closed_jaxpr.jaxpr.debug_info)
  fun, out_axes = _batch_jaxpr_inner(fun, axis_data)
  fun = _batch_jaxpr_outer(fun, axis_data, in_axes)

  def batched_in_aval(aval, b):
    # Unmapped inputs keep their aval. Mapped inputs get the batch axis added
    # back at `b`, plus the SPMD names in `varying` when VMA checking is on.
    if b is not_mapped:
      return aval
    aval = core.unmapped_aval(
        axis_data.size, b, aval, axis_data.explicit_mesh_axis)
    if axis_data.spmd_name is not None and config._check_vma.value:
      mat = aval.mat.update(  # pyrefly: ignore[missing-attribute]
          varying=aval.mat.varying | frozenset(axis_data.spmd_name))  # pyrefly: ignore[missing-attribute]
      aval = aval.update(manual_axis_type=mat)
    return aval

  avals_in2 = [batched_in_aval(aval, b)
               for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(fun, avals_in2)
  return core.ClosedJaxpr(jaxpr_out, consts), out_axes()
|
|
|
|
def batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate):
  """Batch ``closed_jaxpr`` given per-input batched flags and an ``instantiate`` spec."""
  if isinstance(instantiate, list):
    instantiate = tuple(instantiate)
  return _batch_jaxpr(closed_jaxpr, axis_data, tuple(in_batched), instantiate)
|
|
|
|
def _batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate):
  """Translate batched flags into axes and delegate to ``batch_jaxpr_axes``.

  Batched inputs use axis 0. Outputs with ``instantiate`` True are forced to
  axis 0; the others use ``zero_if_mapped`` (axis 0 only if they come out
  batched).
  """
  if isinstance(instantiate, bool):
    instantiate = [instantiate] * len(closed_jaxpr.out_avals)
  else:
    assert (isinstance(instantiate, (list, tuple)) and
            all(isinstance(flag, bool) for flag in instantiate))
  in_axes = [not_mapped if not batched else 0 for batched in in_batched]
  out_axes_dest = [zero_if_mapped if not inst else 0 for inst in instantiate]
  return batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest)
|
|
|
|
def batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
  """Batch ``closed_jaxpr`` with explicit input axes and output axis destinations."""
  # Tuples so the arguments can key the weakref_lru_cache'd worker.
  in_key, out_key = tuple(in_axes), tuple(out_axes_dest)
  return _batch_jaxpr_axes(closed_jaxpr, axis_data, in_key, out_key)
|
|
|
|
@weakref_lru_cache
def _batch_jaxpr_axes(closed_jaxpr: core.ClosedJaxpr,
                      axis_data: AxisData,
                      in_axes: Sequence[int], out_axes_dest: Sequence[int]):
  """Trace the batched jaxpr; returns it with a per-output "is batched" list."""
  fun = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr),
                     debug_info=closed_jaxpr.jaxpr.debug_info)
  fun, out_axes = _batch_jaxpr_inner(fun, axis_data)
  fun, out_batched = _match_axes_jaxpr(fun, axis_data, out_axes_dest, out_axes)
  fun = _batch_jaxpr_outer(fun, axis_data, in_axes)
  avals_in = []
  for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes):
    if b is not_mapped:
      avals_in.append(aval)
    else:
      # Re-insert the batch axis at position b of the per-example aval.
      avals_in.append(core.unmapped_aval(axis_data.size, b, aval,
                                         axis_data.explicit_mesh_axis))
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(fun, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
|
|
|
|
@lu.transformation_with_aux2
def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals):
  """Run ``f`` under a BatchTrace with ``tag``, inputs batched along ``in_axes``.

  Returns the unwrapped outputs and stores their batch axes as aux.
  """
  with core.take_current_trace() as parent_trace:
    trace = BatchTrace(parent_trace, tag, axis_data)
    # Unbatched inputs (dim None) are passed through unwrapped.
    in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                  for val, dim in zip(in_vals, in_axes)]
    # TODO(yashkatariya): Instead of `add_explicit_mesh_axis_names`, we should
    # create a new mesh by removing the axis_data.explicit_mesh_axis from it.
    with (core.set_current_trace(trace),
          core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
          core.add_spmd_axis_names(axis_data.spmd_name),
          core.add_explicit_mesh_axis_names(axis_data.explicit_mesh_axis)):
      outs = f(*in_tracers)
    out_vals, out_axes = unzip2(map(trace.to_batch_info, outs))
  store.store(out_axes)
  return out_vals
|
|
|
|
@lu.transformation_with_aux2
def _match_axes_jaxpr(f, store, axis_data, out_axes_dest, out_axes, trace, in_axes,
                      *in_vals):
  """Move each output of ``f`` from its batch axis to the requested destination.

  Args:
    out_axes_dest: one destination per output, or a single destination applied
      to every output. ``zero_if_mapped`` means axis 0 if that output is
      batched, else unbatched.
    out_axes: thunk for the source batch axes of ``f``'s outputs.
    trace, in_axes: forwarded to ``f`` (``_batch_jaxpr_outer`` passes the trace
      tag here).

  Stores a per-output list of "is batched" flags as aux.
  """
  out_vals = f(trace, in_axes, *in_vals)
  out_axes = out_axes()
  if len(out_axes_dest) != len(out_axes):
    # A single destination is broadcast to every output. Broadcast *before*
    # resolving `zero_if_mapped`: resolving first zipped against a length-1
    # list, so every output inherited output 0's mapped/unmapped decision.
    out_axis_dest, = out_axes_dest
    out_axes_dest = [out_axis_dest] * len(out_axes)
  out_axes_dest = [(None if src is not_mapped else 0)
                   if dst is zero_if_mapped else dst
                   for src, dst in unsafe_zip(out_axes, out_axes_dest)]
  out_vals = map(partial(matchaxis, axis_data), out_axes, out_axes_dest, out_vals)
  out_batched = [dst is not None for dst in out_axes_dest]
  store.store(out_batched)
  return out_vals
|
|
|
|
@lu.transformation2
def _batch_jaxpr_outer(f, axis_data, in_dims, *in_vals):
  """Canonicalise int input axes against each input's rank, then call ``f`` with a new tag."""
  if callable(in_dims):
    in_dims = in_dims()
  canonical_dims = []
  for val, dim in unsafe_zip(in_vals, in_dims):  # pyrefly: ignore[bad-argument-type] # pyrefly#2385
    canonical_dims.append(
        canonicalize_axis(dim, np.ndim(val)) if isinstance(dim, int) else dim)
  tag = TraceTag()
  return f(tag, canonical_dims, *in_vals)
|
|
|
|
def _merge_bdims(x, y):
|
|
if x == y:
|
|
return x
|
|
elif x is not_mapped:
|
|
return y
|
|
elif y is not_mapped:
|
|
return x
|
|
else:
|
|
return x # arbitrary
|
|
|
|
class ZeroIfMapped:
  """Out-axis destination: axis 0 if the output is batched, else unbatched."""

zero_if_mapped = ZeroIfMapped()
|
|
|
|
### functions for handling custom_vjp
|
|
|
|
@lu.transformation_with_aux2
def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
  """Batch a custom_jvp rule ``f`` mapping (primals, tangents) to (primals, tangents).

  ``in_vals`` is primals followed by tangents, so ``in_dims`` is repeated to
  cover both halves. Each output primal/tangent pair is moved to a common batch
  dim chosen by ``_merge_bdims``; those dims are stored as aux.
  """
  with core.take_current_trace() as parent_trace:
    trace = BatchTrace(parent_trace, tag, axis_data)
    # Symbolic-zero tangents are not wrapped in tracers; only their aval is
    # replaced via core.mapped_aval.
    in_tracers = [val if dim is None else
                  SymbolicZero(core.mapped_aval(axis_data.size, dim, val.aval))
                  if type(val) is SymbolicZero else BatchTracer(trace, val, dim)
                  for val, dim in zip(in_vals, in_dims * 2)]
    with core.set_current_trace(trace):
      out_tracers: list[BatchTracer | SymbolicZero] = f(*in_tracers)
      out_vals, out_dims = unzip2(map(trace.to_batch_info, out_tracers))
      out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
      out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
      out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
      out_primals = map(partial(matchaxis, axis_data), out_primal_bds, out_dims,
                        out_primals)
      out_tangents = map(partial(_matchaxis_symzeros, axis_data),
                         out_tangent_bds, out_dims, out_tangents)
  store.store(out_dims)
  return out_primals + out_tangents
|
|
|
|
def batch_custom_vjp_bwd(bwd: lu.WrappedFun, tag: core.TraceTag,
                         axis_data: AxisData,
                         in_dims: Callable[[], Sequence[int | None]],
                         out_dim_dests: Sequence[int | None]) -> lu.WrappedFun:
  """Wrap a custom_vjp ``bwd`` so that it runs batched.

  Symbolic-zero cotangents get their aval mapped and are treated as unbatched.
  ``bwd``'s outputs are moved (summing where needed) to ``out_dim_dests``.
  """
  def new_bwd(*args):
    dims = in_dims() if callable(in_dims) else in_dims
    fixed_args, fixed_dims = [], []
    for arg, dim in zip(args, dims):
      if type(arg) is SymbolicZero:
        fixed_args.append(SymbolicZero(core.mapped_aval(axis_data.size, dim, arg.aval)))
        fixed_dims.append(None)
      else:
        fixed_args.append(arg)
        fixed_dims.append(dim)
    batched_bwd, out_dims_thunk = batch_subtrace(bwd, tag, axis_data, fixed_dims)
    batched_bwd = _match_axes_and_sum(batched_bwd, axis_data, out_dims_thunk,
                                      out_dim_dests)
    return batched_bwd.call_wrapped(*fixed_args)
  return lu.wrap_init(new_bwd, debug_info=bwd.debug_info)
|
|
|
|
@lu.transformation2
def _match_axes_and_sum(f, axis_data, out_dims_thunk, out_dim_dests, *in_vals):
  """Move ``f``'s outputs to ``out_dim_dests``, reduce-summing mapped axes whose destination is unmapped."""
  out_vals = f(*in_vals)
  match_one = partial(_matchaxis_symzeros, axis_data, sum_match=True)
  return map(match_one, out_dims_thunk(), out_dim_dests, out_vals)
|
|
|
|
def _matchaxis_symzeros(axis_data, src, dst, x, sum_match=False):
  """``matchaxis`` that also handles symbolic zeros (``Zero``/``SymbolicZero``).

  For a symbolic zero only its aval is rewritten, and a new zero of the same
  type is returned. Other values are delegated to ``matchaxis``.

  Raises:
    ValueError: if a symbolic zero's ``src``/``dst`` combination is unsupported.
  """
  # TODO(mattjj): dedup with matchaxis
  if not isinstance(x, (Zero, SymbolicZero)):
    return matchaxis(axis_data, src, dst, x, sum_match=sum_match)
  if src == dst:
    return x
  zero_type = type(x)
  if type(src) is int and type(dst) is int:
    per_example = core.mapped_aval(axis_data.size, src, x.aval)
    return zero_type(core.unmapped_aval(axis_data.size, dst, per_example,
                                        axis_data.explicit_mesh_axis))
  if src is not_mapped and dst is not not_mapped:
    return zero_type(core.unmapped_aval(axis_data.size, dst, x.aval,
                                        axis_data.explicit_mesh_axis))
  if dst is not_mapped and sum_match:
    return zero_type(core.mapped_aval(axis_data.size, src, x.aval))
  raise ValueError((axis_data.name, x, src, dst))
|
|
|
|
|
|
### utilities for defining primitives' batching rules
|
|
|
|
# Batching rules keyed by primitive. BatchTrace.process_primitive calls
# rule(axis_data, vals_in, dims_in, **params) and expects (out, out_dims).
fancy_primitive_batchers: dict[core.Primitive, Callable] = dict()
|
|
|
|
# backwards compat shim. TODO: delete
|
|
class AxisPrimitiveBatchersProxy:
  """Write-only mapping that adapts legacy axis-aware batching rules.

  Assigning ``proxy[prim] = rule`` registers ``rule`` in
  ``fancy_primitive_batchers``. The rule is called as
  ``rule(size, name, None, vals, dims, **params)``.
  """

  def __setitem__(self, prim, batcher):
    def adapted(axis_data, vals, dims, **params):
      size, name = axis_data.size, axis_data.name
      return batcher(size, name, None, vals, dims, **params)
    fancy_primitive_batchers[prim] = adapted

axis_primitive_batchers = AxisPrimitiveBatchersProxy()
|
|
|
|
# backwards compat shim. TODO: delete
|
|
class PrimitiveBatchersProxy:
  """Mapping facade that adapts legacy ``rule(vals, dims, **params)`` batching rules.

  The adapted rule binds the primitive directly when no input is batched.
  Otherwise it calls the legacy rule. Deleting an entry removes it from
  ``fancy_primitive_batchers``.
  """

  def __setitem__(self, prim, batcher):
    def adapted(axis_data, vals, dims, **params):
      del axis_data
      if any(d is not None for d in dims):
        return batcher(vals, dims, **params)
      out = prim.bind(*vals, **params)
      if prim.multiple_results:
        return out, [None] * len(out)
      return out, None
    fancy_primitive_batchers[prim] = adapted

  def __delitem__(self, prim):
    fancy_primitive_batchers.pop(prim)

primitive_batchers = PrimitiveBatchersProxy()
|
|
|
|
def defvectorized(prim):
  """Register ``vectorized_batcher`` as ``prim``'s batching rule."""
  rule = partial(vectorized_batcher, prim)
  fancy_primitive_batchers[prim] = rule
|
|
|
|
def vectorized_batcher(prim, axis_data, batched_args, batch_dims, **params):
  """Batching rule for single-output primitives applied elementwise.

  All batched inputs must share the same batch dim, which becomes the output
  batch dim (None if nothing is batched).
  """
  assert not prim.multiple_results
  if any(d is not None for d in batch_dims):
    out_dim = batch_dims[0]
    assert all(out_dim == bd for bd in batch_dims[1:]), batch_dims
  else:
    out_dim = None
  return prim.bind(*batched_args, **params), out_dim
|
|
|
|
def defbroadcasting(prim):
  """Register ``broadcast_batcher`` as ``prim``'s batching rule."""
  rule = partial(broadcast_batcher, prim)
  fancy_primitive_batchers[prim] = rule
|
|
|
|
def broadcast_batcher(prim, axis_data, args, dims, **params):
  """Batching rule for primitives that broadcast their operands.

  If every non-scalar arg has the same shape and batch dim, ``prim`` is bound
  directly and keeps that dim. Otherwise batch dims are moved to the front (0)
  and the primitive's own broadcasting handles the rest.
  """
  assert len(args) > 1
  if all(d is None for d in dims):
    out = prim.bind(*args, **params)
    if prim.multiple_results:
      return out, [None] * len(out)
    return out, None
  ref_shape, ref_dim = next((arg.shape, d) for arg, d in zip(args, dims)
                            if d is not not_mapped)
  aligned = all(core.definitely_equal_shape(ref_shape, arg.shape) and d == ref_dim
                for arg, d in zip(args, dims) if np.ndim(arg))
  if aligned:
    # if there's only agreeing batch dims and scalars, just call the primitive
    args = spmd_names_insert_pvary(*args)
    out = prim.bind(*args, **params)
    out_dim = ref_dim
  else:
    # We pass size of 1 here because (1) at least one argument has a real batch
    # dimension and (2) all unmapped axes can have a singleton axis inserted and
    # then rely on the primitive's built-in broadcasting.
    fronted = [bdim_at_front(arg, d, 1) if np.ndim(arg) else arg
               for arg, d in zip(args, dims)]
    rank = max(np.ndim(arg) for arg in fronted)  # special-case scalar broadcasting
    args = [_handle_scalar_broadcasting(rank, arg, d)
            for arg, d in zip(fronted, dims)]
    out = prim.bind(*args, **params)
    out_dim = 0
  if prim.multiple_results:
    return out, (out_dim,) * len(out)
  return out, out_dim
|
|
|
|
def _handle_scalar_broadcasting(nd, x, d):
  """Append trailing unit dims so a batched ``x`` has rank ``nd``; unbatched values pass through."""
  # Callers of this utility, via broadcast_batcher() or defbroadcasting(),
  # must be in a context where lax is importable.
  from jax import lax  # pytype: disable=import-error
  rank = np.ndim(x)
  if d is not_mapped or nd == rank:
    return x
  return lax.expand_dims(x, tuple(range(rank, nd)))
|
|
|
|
def defreducer(prim):
  """Register ``reducer_batcher`` as ``prim``'s batching rule."""
  rule = partial(reducer_batcher, prim)
  fancy_primitive_batchers[prim] = rule
|
|
|
|
def reducer_batcher(prim, axis_data, batched_args, batch_dims, axes,
                    **params):
  """Batching rule for single-operand reductions over ``axes``.

  The reduced ``axes`` are shifted past the batch dim. The output batch dim is
  the batch axis' position among the surviving axes. ``input_shape`` and
  ``out_sharding`` params, when present, are updated for the batched operand.

  Raises:
    NotImplementedError: if the operand's batch dim is neither None nor an int.
  """
  if all(d is None for d in batch_dims):
    return prim.bind(*batched_args, axes=axes, **params), None
  operand, = batched_args
  bdim, = batch_dims
  if not isinstance(bdim, int):
    # Previously `assert False`. Under `python -O` that is stripped, and the
    # rule silently returned None to its caller.
    raise NotImplementedError(
        f"reducer batching rule for '{prim}' does not support batch dim {bdim!r}")
  # The batched operand has an extra axis at `bdim`, so reduced axes at or
  # beyond it shift right by one.
  axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
  # The batch axis survives the reduction; locate it among the kept axes.
  kept_axes = list(np.delete(np.arange(operand.ndim), axes))
  bdim_out = int(kept_axes.index(bdim))
  if 'input_shape' in params:
    params = dict(params, input_shape=operand.shape)
  if 'out_sharding' in params:
    out_s = params['out_sharding']
    if out_s is not None:
      params = dict(params,
                    out_sharding=get_sharding_for_vmap(axis_data, out_s, bdim_out))
  return prim.bind(operand, axes=axes, **params), bdim_out
|
|
|
|
def expand_dims_batcher(prim, args, dims, **params):
  """Batching rule that moves every argument's batch dim to the front.

  Intended for primitives that accept matching leading batch dimensions on all
  arguments. Unbatched arguments are broadcast to the shared batch size, and
  outputs are batched along axis 0.
  """
  sizes = {arg.shape[d] for arg, d in zip(args, dims) if d is not not_mapped}
  size, = sizes
  fronted = [bdim_at_front(arg, d, size) for arg, d in zip(args, dims)]
  out = prim.bind(*fronted, **params)
  if prim.multiple_results:
    return out, (0,) * len(out)
  return out, 0
|
|
|
|
### general utilities for manipulating axes on jaxpr types (not vmappables)
|
|
|
|
def broadcast(x, sz, axis, mesh_axis):
  """Broadcast ``x`` to gain a new axis of size ``sz`` at position ``axis``.

  The new axis is sharded over ``mesh_axis`` unless ``x``'s sharding mesh is
  empty, in which case it is left unsharded.
  """
  # Callers of this utility must be in a context where lax is importable.
  from jax import lax  # pytype: disable=import-error
  out_shape = list(np.shape(x))
  out_shape.insert(axis, sz)
  kept_dims = tuple(np.delete(np.arange(len(out_shape)), axis))
  in_aval = core.typeof(x)
  if in_aval.sharding.mesh.empty:
    mesh_axis = None
  out_spec = P(*tuple_insert(in_aval.sharding.spec, axis, mesh_axis))
  out_sharding = in_aval.sharding.update(spec=out_spec)
  # TODO(dougalm, yashkatariya): Delete this context manager once we figure
  # out how to ensure jaxpr arguments always have the context mesh.
  with mesh_lib.use_abstract_mesh(out_sharding.mesh):
    expanded = lax.broadcast_in_dim(
        x, out_shape, kept_dims, out_sharding=out_sharding)
    x, = spmd_names_insert_pvary(expanded)
  return x
|
|
|
|
def spmd_names_insert_pvary(*args):
  """Make ShapedArray args vary over the active SPMD axis names.

  Only applies when VMA checking is on and the axis env has SPMD axis names.
  In that case each ShapedArray arg is ``pvary``'d over the names missing from
  its ``varying`` set, and a list is returned. Otherwise ``args`` (a tuple) is
  returned unchanged.
  """
  if not config._check_vma.value:
    return args
  spmd_names = core.get_axis_env().spmd_axis_names
  if not spmd_names:
    return args
  out = []
  for a in args:
    aval = typeof(a)
    if isinstance(aval, core.ShapedArray):
      out.append(core.pvary(a, tuple(spmd_names - aval.mat.varying)))
    else:
      out.append(a)
  return out
|
|
|
|
def matchaxis(axis_data, src, dst, x, sum_match=False):
  """Move ``x``'s batch axis from ``src`` to ``dst``.

  Cases:
    * equal dims, or ``dst`` is ``infer``: ``x`` unchanged;
    * int -> int: ``moveaxis``;
    * unmapped -> int: broadcast a new batch axis;
    * unmapped -> ``sum_axis``: ``x`` unchanged;
    * mapped -> unmapped (with ``sum_match``) or -> ``sum_axis``: sum over ``src``.

  Raises:
    TypeError: if ``x`` is not a valid JAX type.
    ValueError: on mismatch when the vmap axis has a user-visible name.
    SpecMatchError: on mismatch otherwise.
  """
  try:
    core.typeof(x)
  except TypeError as e:
    raise TypeError(f"Output from batched function {x!r} with type "
                    f"{type(x)} is not a valid JAX type") from e
  if src == dst or dst is infer:
    return x
  if type(src) is int and type(dst) is int:
    return moveaxis(x, src, dst)
  if src is not_mapped and type(dst) is int:
    rank_out = np.ndim(x) + 1
    return broadcast(x, axis_data.size, canonicalize_axis(dst, rank_out),
                     axis_data.explicit_mesh_axis)
  if src is not_mapped and dst is sum_axis:
    return x
  if (dst is not_mapped and sum_match) or dst is sum_axis:
    return x.sum(src)
  user_named = (not isinstance(axis_data.name, core._TempAxisName) and
                axis_data.name is not core.no_axis_name)
  if user_named:
    raise ValueError(
        f'vmap has mapped output (axis_name={axis_data.name}) but out_axes is'
        f' {dst}')
  raise SpecMatchError(None, None, None)
|
|
|
|
class SpecMatchError(Exception):
  """Raised when an output's batch dim cannot be matched to its out spec.

  Attributes:
    leaf_idx: index of the offending output leaf, or None if unknown.
    src: the source batch dim, or None if unknown.
    dst: the requested destination spec, or None if unknown.
  """

  def __init__(self, leaf_idx, src, dst):
    # Forward the fields to Exception so that `args` is populated.
    # BaseException.__reduce__ rebuilds the exception as cls(*self.args); with
    # empty args, pickling/copying failed on the required constructor params.
    super().__init__(leaf_idx, src, dst)
    self.leaf_idx = leaf_idx
    self.src = src
    self.dst = dst
|
|
|
|
def bdim_at_front(x, bdim, size, mesh_axis=None):
  """Return ``x`` with its batch axis at position 0.

  An unbatched ``x`` gets a new leading axis of size ``size``.
  """
  if bdim is not not_mapped:
    return moveaxis(x, bdim, 0)
  return broadcast(x, size, 0, mesh_axis=mesh_axis)
|
|
|
|
|
|
def add_batched(axis_data, batched_args, batch_dims):
  """Batching rule for ``add_jaxvals_p``.

  An unbatched operand is broadcast to the other's batch dim. When both are
  batched along different dims, ``x``'s dim is moved to ``y``'s.
  """
  bdx, bdy = batch_dims
  x, y = batched_args
  if bdx is None and bdy is None:
    return add_jaxvals(x, y), None
  mesh_axis = axis_data.explicit_mesh_axis
  if bdx == bdy:
    out_dim = bdx
  elif bdx is not_mapped:
    x = broadcast(x, y.shape[bdy], bdy, mesh_axis=mesh_axis)
    out_dim = bdy
  elif bdy is not_mapped:
    y = broadcast(y, x.shape[bdx], bdx, mesh_axis=mesh_axis)
    out_dim = bdx
  else:
    x = moveaxis(x, bdx, bdy)
    out_dim = bdy
  return add_jaxvals(x, y), out_dim
|
|
|
|
# add_jaxvals_p is batched by add_batched.
fancy_primitive_batchers.update({add_jaxvals_p: add_batched})
|
|
|
|
### mutable arrays
|
|
|
|
# ref_p is batched elementwise: the batch dim passes straight through.
defvectorized(prim=core.ref_p)
|
|
|
|
### hijax
|
|
|
|
class Sum:
  """Out-spec marker type: ``matchaxis`` sums a mapped value over its batch axis."""

sum_axis = Sum()


class Infer:
  """Out-spec marker type: keep the batch dim the value already has."""

infer = Infer()

spec_types.add(Sum)
spec_types.add(Infer)
|