Files
2026-05-06 19:47:31 +07:00

198 lines
5.9 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src import dtypes
from jax._src.core import ShapedArray
from jax._src.lib import _jax
import numpy as np
# TypedInt, TypedFloat, and TypedComplex are subclasses of int, float, and
# complex that carry a JAX dtype. Canonicalization converts plain int, float,
# and complex values into these types; canonicalizing them again, including
# under a different jax_enable_x64 mode, preserves the recorded dtype.
# Shared weakly-typed scalar avals for the 32- and 64-bit int, float, and
# complex dtypes, built once at import so the Typed* scalar constructors can
# hand them out instead of allocating a new ShapedArray per value.
(_weak_int32_aval, _weak_int64_aval,
 _weak_float32_aval, _weak_float64_aval,
 _weak_complex64_aval, _weak_complex128_aval) = (
    ShapedArray((), np.dtype(scalar_type), weak_type=True)
    for scalar_type in (np.int32, np.int64, np.float32, np.float64,
                        np.complex64, np.complex128)
)
class TypedInt(int):
  """A Python ``int`` that also carries a JAX dtype.

  ``aval`` is a weakly-typed scalar ShapedArray for ``dtype``; int32 and int64
  share the module's precomputed avals, other dtypes get a fresh one.
  """

  dtype: np.dtype
  aval: ShapedArray

  def __new__(cls, value: int, dtype: np.dtype):
    v = super().__new__(cls, value)
    v.dtype = dtype
    if dtype == np.dtype(np.int32):
      v.aval = _weak_int32_aval
    elif dtype == np.dtype(np.int64):
      v.aval = _weak_int64_aval
    else:
      v.aval = ShapedArray((), dtype, weak_type=True)
    return v

  def __repr__(self):
    return f'TypedInt({int(self)}, dtype={self.dtype.name})'

  def __str__(self):
    # int does not define its own __str__, so without this override str() and
    # f-string formatting would fall back to the __repr__ above and print
    # 'TypedInt(...)' instead of the number. Matches TypedFloat.__str__.
    return str(int(self))

  def __getnewargs__(self):
    # Lets pickle/copy rebuild the instance through __new__ with its dtype.
    return (int(self), self.dtype)
class TypedFloat(float):
  """A ``float`` subclass tagged with a JAX dtype and a weak scalar aval.

  float32 and float64 reuse the module's shared avals; any other dtype gets
  its own weakly-typed scalar ShapedArray.
  """

  dtype: np.dtype
  aval: ShapedArray

  def __new__(cls, value: float, dtype: np.dtype):
    instance = super().__new__(cls, value)
    instance.dtype = dtype
    # Only the selected branch is evaluated, exactly like an if/elif chain.
    instance.aval = (
        _weak_float32_aval if dtype == np.dtype(np.float32)
        else _weak_float64_aval if dtype == np.dtype(np.float64)
        else ShapedArray((), dtype, weak_type=True)
    )
    return instance

  def __repr__(self):
    return f'TypedFloat({float(self)}, dtype={self.dtype.name})'

  def __str__(self):
    # Show just the number rather than the TypedFloat(...) repr.
    plain = float(self)
    return str(plain)

  def __getnewargs__(self):
    # Arguments pickle/copy pass back to __new__ when rebuilding.
    return float(self), self.dtype
class TypedComplex(complex):
  """A Python ``complex`` that also carries a JAX dtype.

  ``aval`` is a weakly-typed scalar ShapedArray for ``dtype``; complex64 and
  complex128 share the module's precomputed avals, other dtypes get a fresh
  one.
  """

  dtype: np.dtype
  aval: ShapedArray

  def __new__(cls, value: complex, dtype: np.dtype):
    v = super().__new__(cls, value)
    v.dtype = dtype
    if dtype == np.dtype(np.complex64):
      v.aval = _weak_complex64_aval
    elif dtype == np.dtype(np.complex128):
      v.aval = _weak_complex128_aval
    else:
      v.aval = ShapedArray((), dtype, weak_type=True)
    return v

  def __repr__(self):
    return f'TypedComplex({complex(self)}, dtype={self.dtype.name})'

  def __str__(self):
    # complex does not define its own __str__, so without this override str()
    # and f-string formatting would fall back to the __repr__ above and print
    # 'TypedComplex(...)' instead of the number. Matches TypedFloat.__str__.
    return str(complex(self))

  def __getnewargs__(self):
    # Lets pickle/copy rebuild the instance through __new__ with its dtype.
    return (complex(self), self.dtype)
# The typed scalar wrapper classes (int, float and complex flavours).
typed_scalar_types: set[type] = set((TypedInt, TypedFloat, TypedComplex))
class TypedNdArray(np.ndarray):
"""A TypedNdArray is a host-side array used by JAX during tracing.
TypedNdArray is a subclass of np.ndarray that carries additional JAX type
information:
* its type is not canonicalized by JAX, irrespective of the jax_enable_x64
mode
* it can be weakly typed.
"""
__slots__ = ('_aval', '_weak_type')
def __new__(cls, val: np.ndarray, aval: ShapedArray | None = None):
obj = np.asarray(val).view(cls)
if aval is not None:
obj._aval = aval
return obj
def __array_finalize__(self, obj):
self._aval = None
self._weak_type = (obj.aval.weak_type
if isinstance(obj, TypedNdArray) else False)
@property
def aval(self) -> ShapedArray:
result = self._aval
if result is None:
# It is possible that multiple threads might race to reach here. However
# this seems safe since they will all set the same value.
result = ShapedArray(self.shape, self.dtype, weak_type=self._weak_type)
self._aval = result
return result
@property
def weak_type(self) -> bool:
return self.aval.weak_type
@property
def val(self) -> np.ndarray:
return np.asarray(self)
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
inputs = tuple(
np.asarray(x) if isinstance(x, TypedNdArray) else x for x in inputs
)
if 'out' in kwargs:
kwargs['out'] = tuple(
np.asarray(x) if isinstance(x, TypedNdArray) else x
for x in kwargs['out']
)
return getattr(ufunc, method)(*inputs, **kwargs)
def __repr__(self):
prefix = 'TypedNdArray('
if self.aval.weak_type:
dtype_str = f'dtype={self.dtype.name}, weak_type=True)'
else:
dtype_str = f'dtype={self.dtype.name})'
line_width = np.get_printoptions()['linewidth']
if self.size == 0:
s = f'[], shape={self.shape}'
else:
s = np.array2string(
np.asarray(self),
prefix=prefix,
suffix=',',
separator=', ',
max_line_width=line_width,
)
last_line_len = len(s) - s.rfind('\n') + 1
sep = ' '
if last_line_len + len(dtype_str) + 1 > line_width:
sep = ' ' * len(prefix)
return f'{prefix}{s},{sep}{dtype_str}'
def __reduce__(self):
return (TypedNdArray, (np.asarray(self), self.aval.weak_type))
def __getnewargs__(self):
return (np.asarray(self), self.aval.weak_type)
# Import-time registration of the typed wrapper classes with the `_jax`
# extension module and with `dtypes`. Kept in this fixed order.
_jax.set_typed_ndarray_type(TypedNdArray)
dtypes.register_type_whose_dtype_should_not_be_canonicalized(TypedNdArray)
_jax.set_typed_int_type(TypedInt)
_jax.set_typed_float_type(TypedFloat)
_jax.set_typed_complex_type(TypedComplex)
# Each typed scalar class is registered as a weak scalar type whose dtype is
# exempt from canonicalization. `_typ` is left bound at module level.
for _typ in typed_scalar_types:
  dtypes.register_weak_scalar_type(_typ)
  dtypes.register_type_whose_dtype_should_not_be_canonicalized(_typ)