hand
This commit is contained in:
@@ -0,0 +1,302 @@
|
||||
# Copyright 2026 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Stateful, implicitly-updated PRNG implementation based on mutable refs.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import operator
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src import random
|
||||
from jax._src import ref
|
||||
from jax._src import tree_util
|
||||
from jax._src import typing
|
||||
from jax._src.state import primitives as ref_primitives
|
||||
from jax._src.state import types as state_types
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _canonicalize_size(size: int | Sequence[int] | None, *args: ArrayLike) -> tuple[int, ...]:
|
||||
if size is None:
|
||||
return np.broadcast_shapes(*(np.shape(arg) for arg in args))
|
||||
elif isinstance(size, (int, np.number)):
|
||||
return (operator.index(size),)
|
||||
else:
|
||||
return tuple(map(operator.index, size))
|
||||
|
||||
|
||||
@tree_util.register_dataclass
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class StatefulPRNG:
|
||||
"""Stateful JAX random generator.
|
||||
|
||||
This should be instantiated using the :func:`jax.experimental.random.stateful_rng` function.
|
||||
|
||||
Attributes:
|
||||
_base_key: a typed JAX PRNG key object (see :func:`jax.random.key`).
|
||||
_counter: a scalar integer wrapped in a :class:`jax.Ref`.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> from jax.experimental import random
|
||||
>>> rng = random.stateful_rng(42)
|
||||
>>> rng
|
||||
StatefulPRNG(_base_key=Array((), dtype=key<fry>) overlaying:
|
||||
[ 0 42], _counter=Ref(0, dtype=int32, weak_type=True))
|
||||
"""
|
||||
_base_key: Array
|
||||
_counter: core.Ref
|
||||
|
||||
def __post_init__(self):
|
||||
if self._base_key is api_util.SENTINEL:
|
||||
return
|
||||
if not (isinstance(self._base_key, Array)
|
||||
and dtypes.issubdtype(self._base_key.dtype, dtypes.prng_key)):
|
||||
raise ValueError(f"Expected base_key to be a typed PRNG key; got {self._base_key}")
|
||||
|
||||
# TODO(jakevdp): how to validate a traced mutable array?
|
||||
if not (isinstance(self._counter, core.Ref) or
|
||||
(isinstance(self._counter, core.Tracer)
|
||||
and isinstance(self._counter.aval, state_types.AbstractRef))):
|
||||
raise ValueError(f"Expected counter to be a scalar integer ref; got {self._counter}")
|
||||
|
||||
def key(self, shape: int | Sequence[int] = ()) -> Array:
|
||||
"""Generate a new JAX PRNGKey, updating the internal state.
|
||||
|
||||
Args:
|
||||
shape: an optional shape if returning multiple keys.
|
||||
|
||||
Returns:
|
||||
A new, independent PRNG key with the same impl/dtype as
|
||||
``self._base_key``.
|
||||
|
||||
Examples:
|
||||
>>> from jax.experimental import random
|
||||
>>> rng = random.stateful_rng(0)
|
||||
>>> rng.key()
|
||||
Array((), dtype=key<fry>) overlaying:
|
||||
[1797259609 2579123966]
|
||||
>>> rng.key()
|
||||
Array((), dtype=key<fry>) overlaying:
|
||||
[ 928981903 3453687069]
|
||||
"""
|
||||
if self._base_key.shape:
|
||||
# TODO(jakevdp): better error message.
|
||||
raise ValueError("cannot operate on split stateful generator")
|
||||
|
||||
key = random.fold_in(self._base_key, ref_primitives.ref_get(self._counter))
|
||||
ref_primitives.ref_addupdate(self._counter, ..., 1) # pytype bug?
|
||||
shape_tuple = _canonicalize_size(shape)
|
||||
return random.split(key, shape_tuple) if shape_tuple else key
|
||||
|
||||
def random(
|
||||
self,
|
||||
size: int | Sequence[int] | None = None,
|
||||
dtype: DTypeLike = float,
|
||||
):
|
||||
"""Return random floats in the half-open interval [0.0, 1.0)."""
|
||||
# TODO(jakevdp): write docstring
|
||||
return random.uniform(self.key(), shape=_canonicalize_size(size), dtype=dtype)
|
||||
|
||||
|
||||
def uniform(
|
||||
self,
|
||||
low: ArrayLike = 0,
|
||||
high: ArrayLike = 1,
|
||||
size: int | Sequence[int] | None = None,
|
||||
*,
|
||||
dtype: DTypeLike = float,
|
||||
) -> Array:
|
||||
"""Draw uniformly distributed pseudorandom values."""
|
||||
# TODO(jakevdp): write docstring
|
||||
return random.uniform(self.key(), _canonicalize_size(size, low, high),
|
||||
minval=low, maxval=high, dtype=dtype)
|
||||
|
||||
def normal(
|
||||
self,
|
||||
loc: ArrayLike = 0,
|
||||
scale: ArrayLike = 1,
|
||||
size: int | Sequence[int] | None = None,
|
||||
*,
|
||||
dtype: DTypeLike = float,
|
||||
) -> Array:
|
||||
"""Draw normally-distributed pseudorandom values."""
|
||||
# TODO(jakevdp): write docstring
|
||||
norm = random.normal(self.key(), _canonicalize_size(size, loc, scale), dtype)
|
||||
return (jnp.asarray(loc) + jnp.asarray(scale) * norm).astype(dtype)
|
||||
|
||||
def integers(
|
||||
self,
|
||||
low: ArrayLike,
|
||||
high: ArrayLike | None = None,
|
||||
size: int | Sequence[int] | None = None,
|
||||
*,
|
||||
dtype: DTypeLike = int,
|
||||
) -> Array:
|
||||
"""Draw pseudorandom integers."""
|
||||
# TODO(jakevdp): write docstring
|
||||
if high is None:
|
||||
low, high = 0, low
|
||||
return random.randint(self.key(), _canonicalize_size(size, low, high),
|
||||
minval=low, maxval=high, dtype=dtype)
|
||||
|
||||
def split(self, num: int | Sequence[int]) -> StatefulPRNG:
|
||||
"""Create independent child generators suitable for use in :func:`jax.vmap`.
|
||||
|
||||
Args:
|
||||
num: integer or sequence of integers specifying the split shape
|
||||
|
||||
Returns:
|
||||
a single StatefulPRNG object with split contents, suitable for use
|
||||
with :func:`jax.vmap`
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> from jax.experimental import random
|
||||
>>> rng = random.stateful_rng(123)
|
||||
>>> x = jax.numpy.zeros(3)
|
||||
>>> def f(rng, x):
|
||||
... return x + rng.uniform()
|
||||
>>> jax.vmap(f)(rng.split(3), x)
|
||||
Array([0.35525954, 0.21937883, 0.5336956 ], dtype=float32)
|
||||
|
||||
See also:
|
||||
- :meth:`jax.experimental.random.StatefulPRNG.spawn`: This is similar to ``split``, but
|
||||
returns a Python list of :class:`StatefulPRNG`` objects.
|
||||
"""
|
||||
return StatefulPRNG(
|
||||
_base_key=self.key(num),
|
||||
_counter=ref.new_ref(jnp.zeros(num, dtype=int))
|
||||
)
|
||||
|
||||
def spawn(self, n_children: int) -> list['StatefulPRNG']:
|
||||
"""Create a list of independent child generators.
|
||||
|
||||
Args:
|
||||
n_children: non-negative integer.
|
||||
|
||||
Returns:
|
||||
A list of length ``n_children`` containing new independent ``StatefulPRNG`` instances
|
||||
spawned from the original instance.
|
||||
|
||||
Examples:
|
||||
>>> from jax.experimental import random
|
||||
>>> rng = random.stateful_rng(123)
|
||||
>>> child_rngs = rng.spawn(2)
|
||||
>>> [r.integers(0, 10, 2) for r in child_rngs]
|
||||
[Array([4, 5], dtype=int32), Array([2, 1], dtype=int32)]
|
||||
|
||||
See also:
|
||||
- :meth:`jax.experimental.random.StatefulPRNG.split`: this is similar to spawn, but returns
|
||||
a single mapped :class:`jax.experimental.random.StatefulPRNG`` which can be passed to
|
||||
:func:`jax.vmap`.
|
||||
"""
|
||||
return [self.__class__(key, ref.new_ref(0)) for key in self.key(n_children)]
|
||||
|
||||
|
||||
def stateful_rng(seed: typing.ArrayLike | None = None, *,
|
||||
impl: random.PRNGSpecDesc | None = None) -> StatefulPRNG:
|
||||
"""
|
||||
Experimental stateful RNG with implicitly-updated state.
|
||||
|
||||
This implements a stateful PRNG API similar to :func:`numpy.random.default_rng`.
|
||||
It is compatible with JAX transformations like :func:`~jax.jit` and others,
|
||||
with a few exceptions mentioned in the Notes below.
|
||||
|
||||
.. note::
|
||||
|
||||
This stateful PRNG API is a convenience wrapper around JAX's classic
|
||||
stateless, explicitly updated PRNG, described in :mod:`jax.random`.
|
||||
For performance-critical applications, it is recommended to use
|
||||
:func:`jax.random.key` with explicit random state semantics.
|
||||
|
||||
For a discussion of design considerations for this API, refer to
|
||||
:ref:`stateful-randomness-jep`.
|
||||
|
||||
Args:
|
||||
seed: an optional 64- or 32-bit integer used as the value of the key.
|
||||
This must be specified if the generator is instantiated within transformed
|
||||
code; when used at the top level of the program, it may be omitted in
|
||||
which case the RNG will be seeded using the default NumPy seeding.
|
||||
impl: optional string specifying the PRNG implementation (e.g.
|
||||
``'threefry2x32'``)
|
||||
|
||||
Returns:
|
||||
A :class:`~jax.experimental.random.StatefulPRNG` object, with methods for generating
|
||||
random values.
|
||||
|
||||
Notes:
|
||||
The :class:`~jax.experimental.random.StatefulPRNG` object created by this method uses
|
||||
:func:`~jax.Ref` objects to allow implicit updates of state, and thus
|
||||
inherits some of its limitiations. For example:
|
||||
|
||||
- :class:`StatefulPRNG` objects cannot be among the return values of functions
|
||||
wrapped in JIT or other JAX transformations. This means in particular
|
||||
they cannot be used as `carry` values for :func:`jax.lax.scan`,
|
||||
:func:`jax.lax.while_loop`, and other JAX control flow.
|
||||
- :class:`StatefulPRNG` objects cannot be used together with
|
||||
:func:`jax.checkpoint` or :func:`jax.remat`; in these cases it's best to
|
||||
use the :meth:`StatefulPRNG.key` method to produce a standard JAX PRNG key.
|
||||
|
||||
Examples:
|
||||
>>> from jax.experimental import random
|
||||
>>> rng = random.stateful_rng(42)
|
||||
|
||||
Repeated draws implicitly update the key:
|
||||
|
||||
>>> rng.uniform()
|
||||
Array(0.5302608, dtype=float32)
|
||||
>>> rng.uniform()
|
||||
Array(0.72766423, dtype=float32)
|
||||
|
||||
This also works under transformations like :func:`jax.jit`:
|
||||
|
||||
>>> import jax
|
||||
>>> jit_uniform = jax.jit(rng.uniform)
|
||||
>>> jit_uniform()
|
||||
Array(0.6672406, dtype=float32)
|
||||
>>> jit_uniform()
|
||||
Array(0.3890121, dtype=float32)
|
||||
|
||||
Keys can be generated directly if desired:
|
||||
|
||||
>>> rng.key()
|
||||
Array((), dtype=key<fry>) overlaying:
|
||||
[2954079971 3276725750]
|
||||
>>> rng.key()
|
||||
Array((), dtype=key<fry>) overlaying:
|
||||
[2765691542 824333390]
|
||||
"""
|
||||
if seed is None:
|
||||
if not core.trace_ctx.is_top_level():
|
||||
raise TypeError(
|
||||
"When used within transformed code, jax.experimental.random.stateful_rng()"
|
||||
" requires an explicit seed to be set.")
|
||||
entropy = np.random.SeedSequence().entropy
|
||||
assert isinstance(entropy, int)
|
||||
seed = np.int64(entropy & np.iinfo(np.int64).max)
|
||||
assert seed is not None
|
||||
return StatefulPRNG(
|
||||
_base_key=random.key(seed, impl=impl),
|
||||
_counter=ref.new_ref(0)
|
||||
)
|
||||
Reference in New Issue
Block a user