2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,14 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe GenAI Python API."""
@@ -0,0 +1,23 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Tasks GenAI Bundler API."""
from mediapipe.tasks.python.genai.bundler import llm_bundler
BundleConfig = llm_bundler.BundleConfig
create_bundle = llm_bundler.create_bundle
# Remove unnecessary modules to avoid duplication in API docs.
del llm_bundler
@@ -0,0 +1,84 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to perform llm packing."""
import dataclasses
import enum
from typing import List
from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
from mediapipe.tasks.cc.genai.inference.proto import llm_params_pb2
@dataclasses.dataclass(frozen=True)
class BundleConfig:
"""Config for LLM Bundler.
Attributes:
tflite_model: Path to the multi-signature tflite model with "prefill" and
"decode" signatures converted using ODML Transformers APIs.
    tokenizer_model: Path to the tokenizer model. Currently only SentencePiece
      tokenizer is supported. As such, the tokenizer.model proto is expected to
      be passed here.
start_token: Token that will be used to signify the beginning of a sequence.
stop_tokens: Tokens that will be used to signify the end of a sequence.
    output_filename: Name of the generated `.task` file containing the Bundle.
enable_bytes_to_unicode_mapping: Enables GPT-2 style bytes to unicode
mapping. For more details see:
https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
"""
tflite_model: str
tokenizer_model: str
start_token: str
stop_tokens: List[str]
output_filename: str
enable_bytes_to_unicode_mapping: bool = False
class _BundleTags(enum.Enum):
"""Bundle tags."""
TF_LITE_PREFILL_DECODE = 1
TOKENIZER_MODEL = 2
METADATA = 3
def create_bundle(config: BundleConfig):
"""Creates a bundle from the given config."""
artifacts = {}
with open(config.tflite_model, "rb") as f:
artifacts[_BundleTags.TF_LITE_PREFILL_DECODE.name] = f.read()
with open(config.tokenizer_model, "rb") as f:
artifacts[_BundleTags.TOKENIZER_MODEL.name] = f.read()
params = llm_params_pb2.LlmParameters()
params.start_token = config.start_token
params.stop_tokens.extend(config.stop_tokens)
if config.enable_bytes_to_unicode_mapping:
params.input_output_normalizations.append(
llm_params_pb2.LlmParameters.INPUT_OUTPUT_NORMALIZATION_BYTES_TO_UNICODE
)
artifacts[_BundleTags.METADATA.name] = params.SerializeToString()
output_filename = config.output_filename
if not output_filename.endswith(".task"):
output_filename = config.output_filename + ".task"
model_asset_bundle_utils.create_model_asset_bundle(
artifacts,
output_filename,
)
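
# Example usage (a minimal sketch; the paths and token strings below are
# hypothetical and depend on the model being bundled):
#
#   config = BundleConfig(
#       tflite_model="/tmp/prefill_decode.tflite",
#       tokenizer_model="/tmp/tokenizer.model",
#       start_token="<bos>",
#       stop_tokens=["<eos>"],
#       output_filename="/tmp/model.task",
#   )
#   create_bundle(config)  # Writes /tmp/model.task containing all artifacts.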
@@ -0,0 +1,64 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for llm_bundler."""
import os
import zipfile
from absl.testing import absltest
from mediapipe.tasks.python.genai.bundler import llm_bundler
class LlmBundlerTest(absltest.TestCase):
def _create_test_bundle(self, out_dir: str):
"""Helper function to create test bundle."""
tflite_file_path = os.path.join(out_dir, "test.tflite")
with open(tflite_file_path, "w") as f:
f.write("tflite_model")
sp_model_file_path = os.path.join(out_dir, "sp.model")
with open(sp_model_file_path, "w") as f:
f.write("sp_model")
output_file = os.path.join(out_dir, "test.task")
config = llm_bundler.BundleConfig(
tflite_model=tflite_file_path,
tokenizer_model=sp_model_file_path,
start_token="BOS",
stop_tokens=["EOS1", "EOS2"],
output_filename=output_file,
enable_bytes_to_unicode_mapping=True,
)
llm_bundler.create_bundle(config)
return output_file
def test_can_create_bundle(self):
tempdir = self.create_tempdir()
output_file = self._create_test_bundle(tempdir.full_path)
self.assertTrue(os.path.exists(output_file))
def test_verify_content(self):
tempdir = self.create_tempdir()
output_file = self._create_test_bundle(tempdir.full_path)
with zipfile.ZipFile(output_file) as zip_file:
self.assertLen(zip_file.filelist, 3)
self.assertEqual(zip_file.filelist[0].filename, "TF_LITE_PREFILL_DECODE")
self.assertEqual(zip_file.filelist[1].filename, "TOKENIZER_MODEL")
self.assertEqual(zip_file.filelist[2].filename, "METADATA")
if __name__ == "__main__":
absltest.main()
@@ -0,0 +1,24 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Tasks GenAI Converter API."""
from mediapipe.tasks.python.genai.converter import llm_converter
ConversionConfig = llm_converter.ConversionConfig
convert_checkpoint = llm_converter.convert_checkpoint
# Remove unnecessary modules to avoid duplication in API docs.
del llm_converter
@@ -0,0 +1,175 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a couple base classes for the conversion/quantization process."""
import os
from typing import Dict, Iterator, List, Optional, Tuple
import numpy as np
class QuantizationAction:
"""Container of the tensor values and its corresponding quantization settings.
  The container is responsible for hosting all of the information that is
required to execute the weight-only quantization.
Attributes:
tensor_name: A string that represents the input tensor name.
tensor_value: A numpy array that contains the unquantized tensor values.
target_name: A string that represents the updated tensor name.
quantize_axis: A list of integers representing the dimensions to be
quantized along. For example, if an input tensor has shape [128, 256] and
the quantize_axis==[0], it means the quantization happens along the 0-th
dimension, resulting in [256] scaling factors.
quantize_bits: An integer that specifies the target quantization bits. It
currently only supports either 8 or 4 bits.
pack_dim: An integer specifying which dimension to pack the quantized bits.
This is only applicable when the quantize_bits == 4.
"""
def __init__(
self,
tensor_name: str,
tensor_value: Optional[np.ndarray] = None,
target_name: Optional[str] = None,
quantize_axis: Optional[List[int]] = None,
quantize_bits: Optional[int] = None,
pack_dim: Optional[int] = 0,
):
"""Initializes the model attributes."""
self.tensor_name = tensor_name
self.tensor_value = tensor_value
self.target_name = target_name
self.quantize_axis = quantize_axis
self.quantize_bits = quantize_bits
self.pack_dim = pack_dim
def __str__(self) -> str:
output_string = "QuantizationAction(\n"
output_string += f" tensor_name: {self.tensor_name}\n"
output_string += f" target_name: {self.target_name}\n"
output_string += f" quantize_axis: {self.quantize_axis}\n"
output_string += f" quantize_bits: {self.quantize_bits}\n"
output_string += f" pack_dim: {self.pack_dim}\n"
if self.tensor_value is not None:
output_string += f" tensor_value: {self.tensor_value.shape}\n"
output_string += ")\n"
return output_string
class CkptLoaderBase:
"""Base class for loading the checkpoint.
This class is responsible for loading the checkpoint files into the layer
weight tensors (as numpy arrays) + quantization setting information (8/4
bits). The returned data should be a list of QuantizationAction that describes
how to quantize each layer weights.
"""
def __init__(
self,
ckpt_path: str,
is_symmetric: bool,
attention_quant_bits: int,
feedforward_quant_bits: int,
embedding_quant_bits: int,
):
"""Initializes the loader.
Args:
ckpt_path: The filepath to the checkpoint.
is_symmetric: Whether to apply symmetric or asymmetric quantization.
      attention_quant_bits: An integer that specifies the target quantization
        bits (supports 8 or 4) for the attention layers.
      feedforward_quant_bits: An integer that specifies the target quantization
        bits (supports 8 or 4) for the feedforward layers in each Transformer
        block.
      embedding_quant_bits: An integer that specifies the target quantization
        bits (supports 8 or 4) for the embedding (and the final projection)
        layers.
"""
self._ckpt_path = ckpt_path
self._is_symmetric = is_symmetric
self._attention_quant_bits = attention_quant_bits
self._feedforward_quant_bits = feedforward_quant_bits
self._embedding_quant_bits = embedding_quant_bits
def load_to_actions(
self,
) -> Iterator[Optional[List[QuantizationAction]]]:
"""Loads the checkpoint and returns the quantization actions."""
raise NotImplementedError("The load_to_actions method is not implemented.")
class LayerActionMapperBase:
"""Base class for mapping the layer weights to quantization actions.
This class is responsible for mapping from each layer to its corresponding
quantization information (e.g. target quantization bits / updated tensor
name...).
"""
def __init__(
self,
is_symmetric: bool,
attention_quant_bits: int,
feedforward_quant_bits: int,
embedding_quant_bits: int,
backend: str,
):
self._is_symmetric = is_symmetric
self._attention_quant_bits = attention_quant_bits
self._feedforward_quant_bits = feedforward_quant_bits
self._embedding_quant_bits = embedding_quant_bits
self._backend = backend
def map_to_actions(
self, layer_name: str
) -> Optional[List[QuantizationAction]]:
"""Maps the layer weights to quantization actions.
Args:
      layer_name: A string representing the name of the layer weight. Note that
        the name is expected to contain enough layer information to determine
        the target quantization settings. Any child class is expected to
        implement this function.
"""
raise NotImplementedError("The map_to_actions method is not implemented.")
class ModelWriterBase:
"""Base class for writing the quantized model.
This class is responsible for taking a dictionary of the quantized
tensors/names and writing them into the format that can be loaded by the
on-device inference engine.
"""
def __init__(self, output_dir: str, backend: str):
"""Initializes the class.
Args:
output_dir: A string that represents the output directory to write the
resulting file(s).
backend: A string that represents the target backend to run the output
file(s).
"""
self._output_dir = output_dir
if not os.path.exists(self._output_dir):
os.mkdir(self._output_dir)
self._backend = backend
  def write_variables(self, variables: Dict[str, Tuple[np.ndarray, bool]]):
    """Writes the given variables to file(s); implemented by child classes."""
    raise NotImplementedError("The write_variables method is not implemented.")
@@ -0,0 +1,79 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility library that helps create the converter instances."""
from mediapipe.tasks.python.genai.converter import converter_base
from mediapipe.tasks.python.genai.converter import pytorch_converter
from mediapipe.tasks.python.genai.converter import safetensors_converter
from mediapipe.tasks.python.genai.converter import weight_bins_writer
def create_ckpt_loader(
ckpt_format: str, *args, **kwargs
) -> converter_base.CkptLoaderBase:
"""Creates the checkpoint loader.
Args:
    ckpt_format: A string that indicates the format of the input checkpoint.
*args: Additional arguments to be passed into the loader.
**kwargs: Additional arguments to be passed into the loader.
Returns:
A created CkptLoader instance.
"""
del args
if ckpt_format == "pytorch":
return pytorch_converter.PytorchCkptLoader(
ckpt_path=kwargs["ckpt_path"],
is_symmetric=kwargs["is_symmetric"],
attention_quant_bits=kwargs["attention_quant_bits"],
feedforward_quant_bits=kwargs["feedforward_quant_bits"],
embedding_quant_bits=kwargs["embedding_quant_bits"],
special_model=kwargs["special_model"],
backend=kwargs["backend"],
)
elif ckpt_format == "safetensors":
return safetensors_converter.SafetensorsCkptLoader(
ckpt_path=kwargs["ckpt_path"],
is_symmetric=kwargs["is_symmetric"],
attention_quant_bits=kwargs["attention_quant_bits"],
feedforward_quant_bits=kwargs["feedforward_quant_bits"],
embedding_quant_bits=kwargs["embedding_quant_bits"],
special_model=kwargs["special_model"],
backend=kwargs["backend"],
)
else:
raise ValueError(f"Unknown checkpoint format: {ckpt_format}")
def create_writer(
writer_type: str, *args, **kwargs
) -> converter_base.ModelWriterBase:
"""Creates the model writer.
Args:
    writer_type: A string that indicates which model writer to create.
*args: Additional arguments to be passed into the loader.
**kwargs: Additional arguments to be passed into the loader.
Returns:
A created ModelWriter instance.
"""
del args
if writer_type == "weight_bins":
return weight_bins_writer.WeightBinsWriter(
output_dir=kwargs["output_dir"], backend=kwargs["backend"]
)
else:
raise ValueError(f"Unknown writer type: {writer_type}")
@@ -0,0 +1,316 @@
"""Functions to perform the checkpoint conversion."""
import contextlib
import os
from typing import List, Optional
from absl import logging
from mediapipe.python._framework_bindings import model_ckpt_util
from mediapipe.tasks.python.genai.converter import converter_base
from mediapipe.tasks.python.genai.converter import converter_factory
from mediapipe.tasks.python.genai.converter import quantization_util
class ConversionConfig(object):
"""Config for checkpoint conversion.
Attributes:
input_ckpt: Directory or path for the input checkpoint.
ckpt_format: Checkpoint format, e.g. 'safetensors', 'pytorch'.
model_type: Name of the model, e.g. GEMMA_2B.
backend: Target backend to run the model. Can be either "cpu" or "gpu".
    output_dir: Where the output file(s) are to be stored.
is_symmetric: Whether to quantize symmetrically.
attention_quant_bits: Target quantization bits for the attention layers.
feedforward_quant_bits: Target quantization bits for the feedforward layers.
embedding_quant_bits: Target quantization bits for the embedding layers.
    combine_file_only: Whether to combine the weight files only (assuming the
      weight files already exist).
    vocab_model_file: The path to either 1) the SentencePiece vocab model or
      2) the Hugging Face BPE tokenizer files; 1) is applicable for the Gemma
      model and 2) is applicable for other models. When 2) is used, the
      provided path is expected to point to a directory that contains both
      tokenizer.json and tokenizer_config.json files.
output_tflite_file: (optional) the output tflite filename. If not provided,
the output will be `model.tflite` stored in the output_dir.
    fp16_scale: A scalar value in [0, 1]. Some models can run into an
      activation overflow issue when running in 16-bit floating point mode. To
      solve this, we need to scale down the weights of certain layers. See
      go/llm-on-device-fp16 for a more detailed explanation.
lora_ckpt: The directory or path for the lora checkpoint. Required in order
to convert the lora weights.
    lora_rank: An integer representing the rank of LoRA. Required in order to
      convert the LoRA weights. If not provided, the converter assumes there
      are no LoRA weights. Note that only the GPU backend supports LoRA.
lora_output_tflite_file: A string indicating the name of the generated
tflite file for the LoRA weight. Only applicable when the lora_rank is not
zero.
"""
def __init__(
self,
input_ckpt: str,
ckpt_format: str,
model_type: str,
backend: str,
output_dir: str,
is_symmetric: bool = True,
attention_quant_bits: int = 8,
feedforward_quant_bits: int = 8,
embedding_quant_bits: int = 8,
combine_file_only: bool = False,
vocab_model_file: str = '',
output_tflite_file: Optional[str] = None,
fp16_scale: Optional[float] = None,
lora_ckpt: Optional[str] = None,
lora_rank: Optional[int] = None,
lora_output_tflite_file: Optional[str] = None,
):
self.input_ckpt = input_ckpt
self.ckpt_format = ckpt_format
self.model_type = model_type
self.backend = backend
if os.path.isfile(output_dir):
      raise ValueError('Output directory must not point to an existing file.')
if not os.path.isdir(output_dir):
logging.info('Creating output directory: %s', output_dir)
os.makedirs(output_dir, exist_ok=True)
self.output_dir = output_dir
self.is_symmetric = is_symmetric
self.attention_quant_bits = attention_quant_bits
self.feedforward_quant_bits = feedforward_quant_bits
self.embedding_quant_bits = embedding_quant_bits
self.combine_file_only = combine_file_only
self.vocab_model_file = vocab_model_file
if output_tflite_file:
parent_dir = os.path.dirname(output_tflite_file)
if not os.path.isdir(parent_dir):
logging.info('Creating tflite parent directory: %s', parent_dir)
os.makedirs(parent_dir, exist_ok=True)
self.output_tflite_file = output_tflite_file
else:
self.output_tflite_file = os.path.join(output_dir, 'model.tflite')
    self.fp16_scale = fp16_scale
self.lora_ckpt = lora_ckpt
self.lora_rank = lora_rank
self.lora_output_tflite_file = lora_output_tflite_file
if (self.lora_ckpt is None) ^ (self.lora_rank is None):
raise ValueError(
'lora_ckpt and lora_rank must be either both provided or both not'
' provided.'
)
if self.lora_rank is not None:
if backend == 'cpu':
raise ValueError('LoRA is not supported for CPU backend.')
lora_applicable_models = ['GEMMA_2B', 'PHI_2']
if model_type not in lora_applicable_models:
raise ValueError(
'LoRA is only applicable for the model_type:'
f' {", ".join(lora_applicable_models)}, but get model_type:'
f' {model_type}.'
)
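
# Example (a sketch; the paths are hypothetical):
#
#   config = ConversionConfig(
#       input_ckpt='/tmp/falcon_rw_1b/pytorch_model.bin',
#       ckpt_format='pytorch',
#       model_type='FALCON_RW_1B',
#       backend='gpu',
#       output_dir='/tmp/converted',
#       vocab_model_file='/tmp/falcon_rw_1b',
#   )
#   convert_checkpoint(config)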
def quantize_by_actions(
actions: List[converter_base.QuantizationAction],
backend: str,
is_symmetric: bool,
):
"""Quantizes the weights by actions.
Args:
actions: A list of QuantizationAction that contains the information and
tensor values to be quantized.
backend: Target backend to run the model. Can be either "cpu" or "gpu".
is_symmetric: Whether to quantize symmetrically.
Returns:
A dictionary that maps from the updated tensor names to the quantized
tensor values + a boolean that indicates whether the tensor values need to
be packed (only applicable for the 4-bit quantized weights).
"""
output_tensors = {}
for action in actions:
if action.quantize_axis:
pack = action.quantize_bits == 4
if is_symmetric:
target_var, scale = quantization_util.quantize_tensor(
var=action.tensor_value,
axis=action.quantize_axis,
sym=is_symmetric,
number_bits=action.quantize_bits,
)
zp = None
else:
target_var, scale, zp = quantization_util.quantize_tensor(
var=action.tensor_value,
axis=action.quantize_axis,
sym=is_symmetric,
number_bits=action.quantize_bits,
)
if backend == 'cpu' and pack:
target_var, scale, zp = quantization_util.update_to_uint4(
target_var, scale, zp
)
output_tensors[action.target_name] = (target_var, pack)
output_tensors[action.target_name + '_quantized_scale'] = (scale, False)
if zp is not None:
output_tensors[action.target_name + '_quantized_zp'] = (zp, False)
else:
output_tensors[action.target_name] = (action.tensor_value, False)
return output_tensors
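
# For instance (a sketch), a symmetric 8-bit action whose target_name is
# 'params.lm.softmax.logits_ffn.w' yields two entries:
#   output_tensors['params.lm.softmax.logits_ffn.w'] = (int8_weights, False)
#   output_tensors['params.lm.softmax.logits_ffn.w_quantized_scale'] = (
#       scales, False)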
def combined_weight_bins_to_tflite(
model_type: str,
backend: str,
weight_path: str,
output_tflite_file: str,
vocab_model_file: str,
lora_rank: Optional[int] = None,
lora_weight_path: Optional[str] = None,
lora_output_tflite_file: Optional[str] = None,
):
"""Combines weight files to tflite file."""
if backend == 'cpu':
if lora_rank is not None:
logging.fatal('LoRA is not supported for CPU backend.')
model_ckpt_util.GenerateCpuTfLite(
model_type,
weight_path,
vocab_model_file,
True,
output_tflite_file,
)
elif backend == 'gpu':
model_ckpt_util.GenerateGpuTfLite(
model_type,
weight_path,
vocab_model_file,
True,
output_tflite_file,
0 if lora_rank is None else lora_rank,
'' if lora_weight_path is None else lora_weight_path,
'' if lora_output_tflite_file is None else lora_output_tflite_file,
)
else:
raise ValueError('Unsupported backend: %s' % backend)
def convert_bpe_vocab(vocab_model_file: str, output_dir: str) -> str:
  """Converts the Hugging Face BPE tokenizer files to a vocab model file."""
  if not os.path.isdir(vocab_model_file):
    raise ValueError(
        'The input BPE vocab model file path is expected to be a directory'
        ' that contains both tokenizer.json and tokenizer_config.json files.'
    )
output_vocab_file = os.path.join(output_dir, 'spm.model')
model_ckpt_util.ConvertHfTokenizer(vocab_model_file, output_vocab_file)
return output_vocab_file
@contextlib.contextmanager
def filemanager(filename: str, mode: str):
  """Context manager that opens `filename` in `mode` and closes it on exit."""
  try:
with open(filename, mode) as f:
yield f
finally:
pass
def sort_layer_info(layer_info_file: str) -> None:
"""Loads and sorts the layer info file."""
layer_info = []
with filemanager(layer_info_file, 'r') as finfo:
for line in finfo:
line = line.strip()
if line:
layer_info.append(line)
layer_info = list(set(layer_info))
layer_info.sort()
with filemanager(layer_info_file, 'w') as finfo:
for line in layer_info:
finfo.write(line + '\n')
finfo.write('\n')
def maybe_quantize_and_write_tensors_to_bins(
ckpt_loader: converter_base.CkptLoaderBase,
config: ConversionConfig,
) -> None:
"""Quantizes the weight tensors according to the loader and writes them to bins."""
actions = ckpt_loader.load_to_actions()
for action in actions:
# Quantize the weight
quantized_tensors = quantize_by_actions(
action, config.backend, config.is_symmetric
)
del action
# Write the tensors into file(s).
writer = converter_factory.create_writer(
writer_type='weight_bins',
output_dir=config.output_dir,
backend=config.backend,
)
writer.write_variables(quantized_tensors)
del quantized_tensors
del writer
def convert_checkpoint(config: ConversionConfig) -> None:
"""Converts the checkpoint to tflite file."""
logging.info('input folder: %s', config.input_ckpt)
if os.path.isdir(config.vocab_model_file):
vocab_model_path = convert_bpe_vocab(
config.vocab_model_file, config.output_dir
)
else:
vocab_model_path = config.vocab_model_file
if not config.combine_file_only:
# Load the layer weights and prepare the quantization configurations.
loader = converter_factory.create_ckpt_loader(
config.ckpt_format,
ckpt_path=config.input_ckpt,
is_symmetric=config.is_symmetric,
backend=config.backend,
attention_quant_bits=config.attention_quant_bits,
feedforward_quant_bits=config.feedforward_quant_bits,
embedding_quant_bits=config.embedding_quant_bits,
special_model=config.model_type,
fp16_scale=config.fp16_scale,
)
maybe_quantize_and_write_tensors_to_bins(loader, config)
  if config.lora_ckpt is not None and config.lora_ckpt != config.input_ckpt:
    # If the LoRA ckpt and the input ckpt are the same, the LoRA conversion
    # has already been handled by the loader above.
lora_loader = converter_factory.create_ckpt_loader(
config.ckpt_format,
ckpt_path=config.lora_ckpt,
is_symmetric=config.is_symmetric,
backend=config.backend,
attention_quant_bits=None,
feedforward_quant_bits=None,
embedding_quant_bits=None,
special_model=config.model_type,
)
maybe_quantize_and_write_tensors_to_bins(lora_loader, config)
sort_layer_info(os.path.join(config.output_dir, 'layer_info.txt'))
combined_weight_bins_to_tflite(
config.model_type,
config.backend,
weight_path=config.output_dir,
output_tflite_file=config.output_tflite_file,
vocab_model_file=vocab_model_path,
lora_rank=config.lora_rank,
lora_weight_path=config.output_dir,
lora_output_tflite_file=config.lora_output_tflite_file,
)
@@ -0,0 +1,318 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CkptLoader implementation for loading the Pytorch file."""
import enum
import os
from typing import Iterator, List, Optional
import numpy as np
import torch
from mediapipe.tasks.python.genai.converter import converter_base
class _PytorchReader:
"""Pytorch reader."""
def __init__(self, model_path: str):
if not os.path.exists(model_path):
raise ValueError(f"{model_path} does not exists.")
self._model = torch.load(model_path, map_location=torch.device("cpu"))
def read_tensor_as_numpy(self, tensor_name) -> np.ndarray:
tensor = (
self._model[tensor_name]
.to(torch.float32)
.t()
.contiguous()
.detach()
.cpu()
.numpy()
)
return tensor
def get_tensor_names(self) -> List[str]:
names = list(self._model.keys())
return names
class LayerType(enum.Enum):
"""Enum for layer type."""
NONE = 0
ATTENTION = 1 # Layer is part of the attention module.
FEEDFORWARD = 2 # Layer is part of the feedforward module in the Transformer.
EMBEDDING = 3 # Layer is the embedding lookup or final projection layer.
  # Layer is layer normalization before or after the attention layer.
  LAYER_NORM = 4
@classmethod
def get_layer_type(cls, layer_name: str):
"""Gets the layer type of the given layer name."""
ffn_layers = [
"mlp",
]
attn_layers = [
"self_attention",
]
emb_layers = [
"word_embeddings",
"lm_head",
]
layer_norms = [
"input_layernorm",
"post_attention_layernorm",
"ln_f",
]
if any(sub_name in layer_name for sub_name in attn_layers):
return LayerType.ATTENTION
if any(sub_name in layer_name for sub_name in ffn_layers):
return LayerType.FEEDFORWARD
if any(sub_name in layer_name for sub_name in emb_layers):
return LayerType.EMBEDDING
if any(sub_name in layer_name for sub_name in layer_norms):
return LayerType.LAYER_NORM
else:
return LayerType.NONE
class FalconMapper(converter_base.LayerActionMapperBase):
"""LayerActionMapper for handling the Falcon-rw-1b model."""
def __init__(
self,
is_symmetric: bool,
attention_quant_bits: int,
feedforward_quant_bits: int,
embedding_quant_bits: int,
backend: str,
reader: _PytorchReader,
):
super().__init__(
is_symmetric=is_symmetric,
attention_quant_bits=attention_quant_bits,
feedforward_quant_bits=feedforward_quant_bits,
embedding_quant_bits=embedding_quant_bits,
backend=backend,
)
self._reader = reader
def map_to_actions(
self, layer_name: str
) -> Optional[List[converter_base.QuantizationAction]]:
"""Map the given layer name to actions."""
actions = []
tensor_value = self._reader.read_tensor_as_numpy(layer_name)
if "query_key_value" in layer_name:
qkv_tensors = self._decompose_falcon_qkv(tensor_value)
for tensor, name in zip(qkv_tensors, ["q", "k", "v"]):
decomposed_name = layer_name.replace("query_key_value", name)
action = self._map_to_action_helper(tensor, decomposed_name)
actions.append(action)
else:
actions.append(self._map_to_action_helper(tensor_value, layer_name))
return actions
def _map_to_action_helper(
self, tensor_value: np.ndarray, layer_name: str
) -> converter_base.QuantizationAction:
quantize_axis = None
quantize_bits = None
layer_type = LayerType.get_layer_type(layer_name)
    if layer_type != LayerType.LAYER_NORM and layer_name.endswith(".weight"):
      quantize_axis = [0]
if layer_type == LayerType.FEEDFORWARD:
quantize_bits = self._feedforward_quant_bits
elif layer_type == LayerType.ATTENTION:
quantize_bits = self._attention_quant_bits
if self._backend == "cpu" and ".dense." in layer_name:
tensor_value = np.transpose(tensor_value)
quantize_axis = [1]
elif layer_type == LayerType.EMBEDDING:
quantize_bits = self._embedding_quant_bits
if self._backend == "cpu" and "word_embeddings" in layer_name:
tensor_value = np.transpose(tensor_value)
quantize_axis = [1]
target_name = self.update_target_name(layer_name)
return converter_base.QuantizationAction(
tensor_name=layer_name,
tensor_value=tensor_value,
target_name=target_name,
quantize_axis=quantize_axis,
quantize_bits=quantize_bits,
pack_dim=0,
)
def _decompose_falcon_qkv(self, tensor_value: np.ndarray) -> List[np.ndarray]:
"""Decomposes combined qkv tensor used in falcon model into separate q, k and v tensors."""
chunk_size = 64
hidden_size = 2048
tensor_value = tensor_value.transpose()
q_tensor = np.zeros(
(hidden_size,)
+ ((hidden_size,) if len(tensor_value.shape) == 2 else ()),
dtype=tensor_value.dtype,
)
k_tensor = np.zeros_like(q_tensor, dtype=tensor_value.dtype)
v_tensor = np.zeros_like(k_tensor, dtype=tensor_value.dtype)
j = 0
for i in range(0 * chunk_size, hidden_size * 3, chunk_size * 3):
q_tensor[j : j + chunk_size] = tensor_value[i : i + chunk_size]
j += chunk_size
j = 0
for i in range(1 * chunk_size, hidden_size * 3, chunk_size * 3):
k_tensor[j : j + chunk_size] = tensor_value[i : i + chunk_size]
j += chunk_size
j = 0
for i in range(2 * chunk_size, hidden_size * 3, chunk_size * 3):
v_tensor[j : j + chunk_size] = tensor_value[i : i + chunk_size]
j += chunk_size
return [
np.ascontiguousarray(q_tensor.transpose()),
np.ascontiguousarray(k_tensor.transpose()),
np.ascontiguousarray(v_tensor.transpose()),
]
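
  # Illustration (hypothetical small sizes): with chunk_size=1 and
  # hidden_size=2, the fused rows are ordered [q0, k0, v0, q1, k1, v1], so the
  # three loops above gather rows {0, 3} into q, {1, 4} into k, and {2, 5}
  # into v before transposing back.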
def update_target_name(self, target_name: str) -> str:
"""Updates the target name to match the tensor name convention."""
layer_type = LayerType.get_layer_type(target_name)
target_name = target_name.replace(
"transformer.h.", "params.lm.transformer.x_layers_"
)
if layer_type == LayerType.FEEDFORWARD:
target_name = target_name.replace(".weight", ".linear.w")
target_name = target_name.replace(".bias", ".bias.b")
target_name = target_name.replace(
"mlp.dense_h_to_4h", "ff_layer.ffn_layer1"
)
target_name = target_name.replace(
"mlp.dense_4h_to_h", "ff_layer.ffn_layer2"
)
elif layer_type == LayerType.ATTENTION:
target_name = target_name.replace("dense", "post")
target_name = target_name.replace(".weight", ".linear.w")
target_name = target_name.replace(".bias", ".bias.b")
elif layer_type == LayerType.EMBEDDING:
target_name = target_name.replace(
"transformer.word_embeddings", "params.lm.token_embedding"
)
target_name = target_name.replace(
"lm_head", "params.lm.softmax.logits_ffn"
)
target_name = target_name.replace(".weight", ".w")
elif layer_type == LayerType.LAYER_NORM:
target_name = target_name.replace("input_layernorm", "pre_layer_norm")
target_name = target_name.replace(
"pre_layer_norm.weight", "pre_layer_norm.scale"
)
if self._backend == "cpu":
target_name = target_name.replace(
"post_attention_layernorm", "ff_layer.pre_layer_norm"
)
target_name = target_name.replace(
"ff_layer.pre_layer_norm.weight", "ff_layer.pre_layer_norm.scale"
)
else:
target_name = target_name.replace(
"post_attention_layernorm", "post_layer_norm"
)
target_name = target_name.replace(
"post_layer_norm.weight", "post_layer_norm.scale"
)
target_name = target_name.replace(
"transformer.ln_f.weight", "params.lm.final_ln.scale"
)
target_name = target_name.replace(
"transformer.ln_f.bias", "params.lm.final_ln.bias"
)
return target_name
class PytorchCkptLoader(converter_base.CkptLoaderBase):
"""CkptLoader implementation for loading the Pytorch model."""
def __init__(
self,
ckpt_path: str,
is_symmetric: bool,
attention_quant_bits: int,
feedforward_quant_bits: int,
embedding_quant_bits: int,
special_model: str,
backend: str,
):
"""Initializes the loader.
Args:
      ckpt_path: The filepath to the PyTorch checkpoint file.
      is_symmetric: Whether to apply symmetric or asymmetric quantization.
      attention_quant_bits: An integer that specifies the target quantization
        bits (supports 8 or 4) for the attention layers.
      feedforward_quant_bits: An integer that specifies the target quantization
        bits (supports 8 or 4) for the feedforward layers in each Transformer
        block.
      embedding_quant_bits: An integer that specifies the target quantization
        bits (supports 8 or 4) for the embedding (and the final projection)
        layers.
      special_model: A string that indicates which model this is and whether
        any special treatment is needed.
      backend: A string indicating the backend used when converting this model.
        Valid options are "cpu" and "gpu".
"""
super().__init__(
ckpt_path,
is_symmetric,
attention_quant_bits,
feedforward_quant_bits,
embedding_quant_bits,
)
self._special_model = special_model
self._reader = _PytorchReader(ckpt_path)
if special_model in ["FALCON_RW_1B"]:
self.mapper = FalconMapper(
is_symmetric,
attention_quant_bits,
feedforward_quant_bits,
embedding_quant_bits,
backend,
self._reader,
)
else:
raise ValueError(f"Unknown special model: {special_model}")
def load_to_actions(
self,
) -> Iterator[List[converter_base.QuantizationAction]]:
tensor_names = self._reader.get_tensor_names()
for tensor_name in tensor_names:
tensor_actions = self.mapper.map_to_actions(tensor_name)
if tensor_actions is None:
continue
yield tensor_actions
@@ -0,0 +1,86 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for pytorch_converter."""
import os
from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.tasks.python.genai.converter import pytorch_converter
from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
_PYTORCH_FILE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'falcon_rw_1b_test_weight.pt')
)
class PytorchConverterTest(parameterized.TestCase):
VARIABLE_NAMES = [
'transformer.word_embeddings.weight',
'transformer.h.0.input_layernorm.weight',
'transformer.h.0.input_layernorm.bias',
'transformer.h.0.self_attention.query_key_value.weight',
'transformer.h.0.self_attention.query_key_value.bias',
'transformer.h.0.self_attention.dense.weight',
'transformer.h.0.self_attention.dense.bias',
'transformer.h.0.post_attention_layernorm.weight',
'transformer.h.0.post_attention_layernorm.bias',
'transformer.h.0.mlp.dense_h_to_4h.weight',
'transformer.h.0.mlp.dense_h_to_4h.bias',
'transformer.h.0.mlp.dense_4h_to_h.weight',
'transformer.h.0.mlp.dense_4h_to_h.bias',
'transformer.ln_f.weight',
'transformer.ln_f.bias',
'lm_head.weight',
]
def test_init(self):
loader = pytorch_converter.PytorchCkptLoader(
ckpt_path=_PYTORCH_FILE,
is_symmetric=True,
attention_quant_bits=8,
feedforward_quant_bits=8,
embedding_quant_bits=8,
special_model='FALCON_RW_1B',
backend='cpu',
)
self.assertEqual(loader._ckpt_path, _PYTORCH_FILE)
self.assertEqual(loader._is_symmetric, True)
self.assertEqual(loader._attention_quant_bits, 8)
self.assertEqual(loader._feedforward_quant_bits, 8)
@parameterized.product(
quant_bits=(4, 8),
)
def test_load_to_actions(self, quant_bits):
loader = pytorch_converter.PytorchCkptLoader(
ckpt_path=_PYTORCH_FILE,
is_symmetric=True,
attention_quant_bits=8,
feedforward_quant_bits=quant_bits,
embedding_quant_bits=8,
special_model='FALCON_RW_1B',
backend='cpu',
)
actions = loader.load_to_actions()
    # There are 16 tensors in the model, but the qkv weight and bias are each
    # decomposed into q, k, v tensors, so there are 20 quantization actions.
self.assertEqual(sum(len(action) for action in actions), 20)
if __name__ == '__main__':
absltest.main()
@@ -0,0 +1,516 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for quantizing tensors.
Note that this is a reduced fork of the praxis library to provide a
self-contained library for packaging.
"""
from typing import Any, List, Optional, Sequence, Tuple, Union
import jax
from jax import lax
from jax import numpy as jnp
import numpy as np
JTensor = jax.Array
_UINT4_ZP = 8 # Default zero point for unsigned 4-bit.
def _get_scan_range() -> np.ndarray:
# Produce candidate scan values.
return np.linspace(1.0, 0.5, num=11)
def _get_mean_error(bound, t, min_value, max_value, p_value):
scale = bound / max_value
candidate = jnp.divide(t, scale)
candidate = jnp.clip(jnp.round(candidate), min_value, max_value)
candidate = jnp.multiply(candidate, scale)
pmean_error = jnp.mean(jnp.abs(jnp.subtract(candidate, t)) ** p_value)
return pmean_error
def _get_best_bound_per_tensor(
t: JTensor,
bound: JTensor,
min_value: float,
max_value: float,
p_value: float = 1.0,
) -> JTensor:
"""Scan around [0.5, 1] * hard max value to get bound value for whole tensor.
This does a scan to get bound value(s) that minimize mean absolute error (MAE)
between original tensor 't' and quantized tensor. It's (almost) equivalent to
maximizing entropy.
Args:
t: The input float tensor.
bound: The hard max value for tensor 't'. It has the same length as shape.
min_value: Minimal value for the quantization bound.
max_value: Maximal value for the quantization bound.
p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
Returns:
The best bound values for 't', that minimize p-mean error.
"""
def _quant(scaling_factors):
return _get_mean_error(
bound * scaling_factors, t, min_value, max_value, p_value
)
scaling_factors = _get_scan_range()
diffs = jax.vmap(_quant)(scaling_factors)
best_scaling = scaling_factors[jnp.argmin(diffs)].astype(bound.dtype)
return bound * best_scaling
def _quantrow(
vec: JTensor,
bound: JTensor,
min_value: float,
max_value: float,
p_value: float,
factors: np.ndarray,
) -> JTensor:
"""Get best rescaling factor from a list of factors applied a channel.
Args:
vec: The vector in a channel.
bound: The hard bound (max(abs(vec))) of the vector.
min_value: The target min value.
max_value: The target max value.
p_value: Exponent of the p-mean error metric.
factors: The values to be applied on top of bound.
Returns:
adjusted bound value out of the list of factors applied to bound.
"""
def _quant(bounds):
return _get_mean_error(bounds, vec, min_value, max_value, p_value)
diffs = jax.vmap(_quant)(bound * factors)
best_scaling = factors[jnp.argmin(diffs)]
return bound * best_scaling
def _get_best_bound_per_channel(
t: JTensor,
bound: JTensor,
min_value: float,
max_value: float,
p_value: float = 1.0,
) -> JTensor:
"""Scan around [0.5, 1] * hard max value to get bound value for each channel.
This does a scan to get bound value(s) that minimize mean absolute error (MAE)
between original tensor 't' and quantized tensor. It's (almost) equivalent to
maximizing entropy.
Args:
t: The input float tensor.
bound: The hard max value for tensor 't'. It has the same length as shape.
min_value: Minimal value for the quantization bound.
max_value: Maximal value for the quantization bound.
p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
Returns:
The best bound values for 't', that minimize p-mean error.
"""
assert len(t.shape) == 2
assert len(bound.shape) == 2
assert t.shape[1] == bound.shape[1]
assert bound.shape[0] == 1
scans = _get_scan_range()
def _quant(tensor, bound, min_value, max_value, p_value, factors):
ret = np.zeros(bound.shape)
for i in range(len(tensor)):
best = _quantrow(
tensor[i], bound[i], min_value, max_value, p_value, factors
)
ret[i] = best
return ret
t = t.transpose()
t_split = list(t)
res = _quant(t_split, bound[0, :], min_value, max_value, p_value, scans)
res = res.reshape(bound.shape)
return res
def get_best_bound(
t: JTensor,
bound: JTensor,
min_value: float,
max_value: float,
p_value: float = 1.0,
per_channel: bool = False,
) -> JTensor:
"""Scan mutliple factors on max value to get best bound value.
This does a scan to get bound value(s) that minimize mean absolute error (MAE)
between original tensor 't' and quantized tensor. It's (almost) equivalent to
maximizing entropy.
Args:
t: The input float tensor.
bound: The hard max value for tensor 't'. It has the same length as shape.
min_value: Minimal value for the quantization bound.
max_value: Maximal value for the quantization bound.
p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
    per_channel: Whether to get the best bound for the entire tensor or per
      channel.
Returns:
The best bound values for 't', that minimize p-mean error.
"""
if per_channel:
return _get_best_bound_per_channel(t, bound, min_value, max_value, p_value)
else:
return _get_best_bound_per_tensor(t, bound, min_value, max_value, p_value)
def get_min_max(
bits: int = 8,
unsigned: bool = False,
use_fp: bool = False,
) -> Tuple[float, float]:
"""Gets the min/max range for a given number of bits.
Args:
bits: Target number of bits for quantization.
    unsigned: If True, compute min and max for an unsigned number, else for a
      signed one.
    use_fp: Whether the range is for a floating point type.
  Returns:
    min/max values for the provided number of bits.
"""
if use_fp:
# TODO: support other fp types.
return -448.0, 448.0
  # Calculation instead of jax.iinfo is used to support bits besides 4 and 8.
if unsigned:
# For unsigned 8 bits precision it is [0, 255]
return 0, 2**bits - 1
else:
# For signed 8 bits precision it is [-128, 127]
return -1 * 2 ** (bits - 1), 2 ** (bits - 1) - 1
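
# For example, get_min_max(8) returns (-128, 127) and get_min_max(4) returns
# (-8, 7); with unsigned=True these become (0, 255) and (0, 15), respectively.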
def pass_through(x: JTensor, fn: Any) -> JTensor:
# Create an exactly-zero expression with Sterbenz lemma that has an
# exactly-one gradient.
return x - jax.lax.stop_gradient(x) + jax.lax.stop_gradient(fn(x))
def reduce_precision(
t: JTensor,
contract_dims: Optional[Sequence[int]],
need_gradient: bool = False,
bits: int = 8,
optimization_on_bound: bool = False,
p_value: float = 1.0,
percentile: float = 1.0,
use_symmetric: bool = True,
use_fp: bool = False,
add_scale_eps: bool = False,
per_channel: bool = False,
random_rounding: bool = False,
key: Optional[jax.Array] = None,
) -> Tuple[JTensor, JTensor, Optional[JTensor]]:
"""Reduce the precision of a tensor.
Generic for all tensors.
Args:
t: Input tensor.
    contract_dims: Specifies contracting dimensions of the input tensor.
need_gradient: If gradient is needed out of this function.
bits: Target number of bits.
optimization_on_bound: If MAE bound optimizer is used.
p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
    percentile: Percentile factor to apply on the min/max range. Setting this
      to a value other than 1.0 disables optimization_on_bound.
use_symmetric: If the input tensor is quantized symmetrically.
use_fp: Use floating point.
add_scale_eps: Add eps value or replace zero value by 1 to avoid division by
zero.
per_channel: use per-channel clipping optimization.
random_rounding: round with uniform random.
key: rng key for rounding.
Returns:
A tuple of quantized tensor, quantization scale
and quantization zero point (optional).
"""
min_value, max_value = get_min_max(bits, use_fp=use_fp)
if use_symmetric:
bound = jnp.max(jnp.abs(t), axis=contract_dims, keepdims=True)
scale_bound = max_value
else:
t_max = jnp.max(t, axis=contract_dims, keepdims=True)
t_min = jnp.min(t, axis=contract_dims, keepdims=True)
bound = t_max - t_min
scale_bound = max_value - min_value
if percentile < 1.0:
bound = jnp.multiply(bound, percentile)
elif optimization_on_bound:
bound = get_best_bound(
t, bound, min_value, max_value, p_value, per_channel=per_channel
)
scale = bound / scale_bound
if add_scale_eps:
# Add epsilon to avoid divide-by-zero.
scale = scale + jnp.finfo(t.dtype).eps
else:
scale = jnp.where(scale == 0.0, 1.0, scale)
if use_symmetric:
zp = None
t = jnp.divide(t, scale)
else:
zp = min_value - t_min / scale
t = jnp.divide(t, scale) + zp
zp = jnp.multiply(scale, zp)
if use_fp:
# No need to round.
t = jnp.clip(t, min_value, max_value).astype(jnp.float8_e4m3fn)
# TODO: refactor to remove this logic.
t = jax.lax.bitcast_convert_type(t, new_dtype=jnp.int8)
else:
if need_gradient:
t = pass_through(t, jnp.round)
t = jnp.clip(t, min_value, max_value)
else:
if random_rounding:
t = t + jax.random.uniform(
key=key, shape=t.shape, minval=-0.5, maxval=0.5
)
t = jnp.round(t)
container_dtype = (
jnp.int8 if bits <= 8 else jnp.int16 if bits <= 16 else jnp.int32
)
t = jnp.clip(t, min_value, max_value).astype(container_dtype)
return t, scale, zp
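
# Worked example (symmetric, 8-bit, no bound optimization): for a row whose
# max |t| is 5.5, the scale is 5.5 / 127 ~= 0.0433; each element becomes
# round(t / scale) clipped to [-128, 127], and multiplying by the scale
# dequantizes the row.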
def quantize_tensor(
var: np.ndarray,
axis: List[int],
factor: float = 1.0,
sym: bool = True,
number_bits: int = 8,
use_fp: bool = False,
add_scale_eps: bool = False,
optimization_on_bound: bool = False,
p_value: float = 1.0,
per_channel: bool = False,
block_size: int = 0,
) -> Union[
Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]
]:
"""Quantize a tensor.
Args:
var: The variable to be quantized.
axis: The axis along which variable will be quantized.
factor: The clipping factor.
sym: Symmetric or asymmetric quantize the variable.
number_bits: Number of bits for quantized value.
    use_fp: Whether to do floating point quantization with the given number of
      bits (i.e. fp8).
    add_scale_eps: If True, add epsilon to the scale to avoid division by
      zero; otherwise replace a zero scale by 1.
optimization_on_bound: If p-mean bound optimizer is used.
p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
per_channel: use per-channel clipping optimization.
block_size: block size for sub-channel quantization. Defaults to 0, which
means off.
Returns:
Quantized tensors, along with scales and zero point.
"""
# TODO: support jnp.float8_e5m2
assert number_bits == 8 or number_bits == 4
jnp_var = jnp.asarray(var)
# When using sub-channel, the contracting dim is split into a sub-channel
# dim followed by the block dim. Therefore the contracting dim
# (quantize_axis) should increment by one, and the corresponding pack_dim
# should also increment by one.
if block_size > 0:
shape = list(jnp_var.shape)
assert len(axis) == 1, 'Only support 1D sub-channel quantization'
sub_channels, rem = divmod(shape[axis[0]], block_size)
assert rem == 0
shape.insert(axis[0], sub_channels)
axis[0] += 1
shape[axis[0]] = block_size
jnp_var = jnp.reshape(jnp_var, shape)
qvar, scale, zp = reduce_precision(
jnp_var,
contract_dims=axis,
need_gradient=False,
bits=number_bits,
optimization_on_bound=optimization_on_bound,
percentile=factor,
use_symmetric=sym,
use_fp=use_fp,
add_scale_eps=add_scale_eps,
p_value=p_value,
per_channel=per_channel,
)
if sym:
return np.array(qvar), np.array(jnp.squeeze(scale, axis=axis)) # pytype: disable=wrong-arg-types # jnp-type
else:
return (
np.array(qvar),
# CAVEAT: the following squeezes should squeeze along the quantization
# axis only.
np.array(jnp.squeeze(scale)),
np.array(jnp.squeeze(zp)),
)
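
# Example (a sketch mirroring the unit tests): symmetric per-row 8-bit
# quantization of a 2x4 matrix returns an int8 tensor plus one scale per row.
#
#   qx, scale = quantize_tensor(
#       np.array([[1.2, 3.1, 5.5, 2.9], [0.2, -1.5, 3.3, 4.0]]), axis=[1]
#   )
#   # Dequantize with qx * scale[:, np.newaxis].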
def pack_4bit(
x: np.ndarray, pack_dim: int, packed_dtype: jnp.dtype = jnp.int32
) -> np.ndarray:
"""Pack int8 or uint8 tensor where its values are actually int4 or uint4, to int32 or int8 nibble format along pack_dim.
Args:
x: Original int8 or uint8 tensor to pack.
    pack_dim: Dimension to pack along. x.shape[pack_dim] must be divisible by
      8 when packed_dtype is int32, and by 2 when packed_dtype is int8. Also
      pack_dim must be < x.ndim - 1.
packed_dtype: Target type to pack to, int32 or int8.
Returns:
    int32 or int8 packed tensor where the pack_dim size is divided by 8
    (int32) or by 2 (int8) relative to the original tensor x.
"""
x = jnp.asarray(x)
if packed_dtype == jnp.int8 and x.dtype == jnp.uint8:
    # It doesn't make sense to pack uint8 numbers into int4, as the range
    # overlap between uint8 and int4 is only [0..7].
raise ValueError(
'only int8 input dtype is supported when packing into int8. '
f'Given {x.dtype}'
)
if x.dtype != jnp.int8 and x.dtype != jnp.uint8:
raise ValueError(
f'input dtype must be either int8 or uint8. Given {x.dtype}'
)
if pack_dim >= x.ndim - 1:
raise ValueError(
f'pack_dim must be < input ndim - 1. input shape {x.shape} and pack_dim'
f' {pack_dim}'
)
if packed_dtype != jnp.int32 and packed_dtype != jnp.int8:
raise ValueError(
f'packed_dtype must be either int32 or int8. Given {packed_dtype}'
)
if packed_dtype == jnp.int32 and x.shape[pack_dim] % 8 != 0:
raise ValueError(
        'input shape[pack_dim] must be divisible by 8 when packed_dtype '
f'is int32. Given shape {x.shape}'
)
if packed_dtype == jnp.int8 and x.shape[pack_dim] % 2 != 0:
raise ValueError(
        'input shape[pack_dim] must be divisible by 2 when packed_dtype '
f'is int8. Given shape {x.shape}'
)
int4s_per_packed_type = 8 if packed_dtype == jnp.int32 else 2
rep_shape = list(x.shape)
rep_shape.insert(pack_dim + 1, int4s_per_packed_type)
rep_shape[pack_dim] //= int4s_per_packed_type
shifts = lax.broadcasted_iota(packed_dtype, rep_shape, pack_dim + 1)
  shifts <<= 2  # Multiply the iota values by 4: each int4 occupies 4 bits.
  # Mask to the low nibble and promote x to packed_dtype.
  x = x & jnp.array(0x0F, packed_dtype)
x = lax.reshape(x, rep_shape)
x = x << shifts
x = lax.reduce(x, jnp.array(0x0, packed_dtype), lax.add, [pack_dim + 1])
return np.asarray(x)
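
# Worked example: with packed_dtype=jnp.int8, two int4 values [1, 2] along
# pack_dim pack into the single byte 1 | (2 << 4) == 33; with jnp.int32,
# eight nibbles pack into one element using shifts 0, 4, ..., 28.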
def update_to_uint4(
qx: np.ndarray, scale: np.ndarray, zp: Optional[np.ndarray] = None
):
"""Updates the quantized weights from int4 to uint4.
This is a conversion function designed for XNNPack as it expects the 4-bit
quantized weight to be represented differently from the original Pax setting.
Specifically, the differences are:
1) The dynamic range of weight values: int4 (Pax) vs. uint4 (XNNPack).
2) The dynamic range of zero-point: float (Pax) vs. uint4 (XNNPack).
3) The number of zero-point: per-channel (Pax) vs. per-tensor (XNNPack).
Args:
    qx: np.array of shape [..., channel] holding the quantized weight values
      from Pax. The values are in the dynamic range of int4 but are hosted as
      int8 type. Note that if the first dimension is 3, it means the qkv
      matrices are concatenated together and should be treated differently.
    scale: np.array of shape [1 (or 3), channel] as np.float type, which are
      the scaling factors for dequantization per channel.
    zp: (optional) np.array of shape [1 (or 3), channel] as np.float type,
      which are the zero points for dequantization per channel.
Returns:
A tuple (qx, scale, zp):
      qx: The updated np.array of shape [..., channel] as np.int8 type, with
        the dynamic range updated to uint4 (with 8 as the default zero point).
      scale: Same as the input scale.
      zp: (optional) np.array of shape [1 (or 3)] as np.int32 type, with the
        updated zero point values in the dynamic range of uint4.
"""
if qx.dtype != np.int8 or ('float' not in str(scale.dtype)):
raise ValueError(
'Unexpected dtype qx:' + str(qx.dtype) + ' scale:' + str(scale.dtype)
)
scale = scale.astype(np.float32)
def get_new_zp(old_zp):
new_zp = old_zp / (scale + np.finfo(np.float32).eps)
per_tensor_zp = np.mean(new_zp)
per_tensor_zp = per_tensor_zp.astype(np.int8) + _UINT4_ZP
return per_tensor_zp
if zp is not None:
if qx.shape[0] == 3:
per_tensor_zp = np.stack([get_new_zp(szp) for szp in zp], axis=0)
else:
per_tensor_zp = get_new_zp(zp)
else:
per_tensor_zp = (
_UINT4_ZP * np.ones(shape=(3)) if qx.shape[0] == 3 else _UINT4_ZP
)
qx = qx + _UINT4_ZP
return qx, scale, np.array(per_tensor_zp, dtype=np.int32)
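
# Worked example: a symmetric int4 weight of -8 maps to -8 + 8 == 0 in uint4,
# and the returned per-tensor zero point is _UINT4_ZP (8), so dequantization
# remains scale * (qx - zp).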
@@ -0,0 +1,259 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for quantization_util."""
from absl.testing import absltest
import jax
from jax import numpy as jnp
import numpy as np
from mediapipe.tasks.python.genai.converter import quantization_util
_dtype = lambda x: getattr(x, 'dtype', None) or np.asarray(x).dtype
class TestCase(absltest.TestCase):
def assertAllClose(
self, x, y, check_dtypes=True, rtol=1e-5, atol=1e-5, **kwargs
):
"""Wrapper for np.testing.assert_allclose()."""
x = np.asarray(x)
y = np.asarray(y)
if check_dtypes:
self.assertDtypesMatch(x, y)
x = x.astype(np.float32) if x.dtype == jnp.bfloat16 else x
y = y.astype(np.float32) if y.dtype == jnp.bfloat16 else y
np.testing.assert_allclose(x, y, rtol=rtol, atol=atol, **kwargs)
def assertDtypesMatch(self, x, y):
self.assertEqual(
jax.dtypes.canonicalize_dtype(_dtype(x)),
jax.dtypes.canonicalize_dtype(_dtype(y)),
)
class Quantize8BTest(TestCase):
def test_quantize_symmetric(self):
inputs = np.array([[1.2, 3.1, 5.5, 2.9], [0.2, -1.5, 3.3, 4.0]])
qx, scale = quantization_util.quantize_tensor(inputs, axis=[1])
self.assertAllClose(
qx, np.array([[28, 72, 127, 67], [6, -48, 105, 127]], dtype=np.int8)
)
self.assertAllClose(
scale, np.array([0.04330709, 0.03149606], dtype=np.float32)
)
def test_quantize_symmetric_with_dimension_size_one_unquantized(self):
# inputs shape: (2, 1, 4), quantization axis 2.
inputs = np.array([[[1.2, 3.1, 5.5, 2.9]], [[0.2, -1.5, 3.3, 4.0]]])
qx, scale = quantization_util.quantize_tensor(inputs, axis=[2])
self.assertAllClose(
qx, np.array([[[28, 72, 127, 67]], [[6, -48, 105, 127]]], dtype=np.int8)
)
# expected scale shape: (2, 1)
self.assertAllClose(
scale, np.array([[0.04330709], [0.03149606]], dtype=np.float32)
)
def test_quantize_asymmetric(self):
inputs = np.array([[1.2, 3.1, 5.5, 2.9], [0.2, -1.5, 3.3, 4.0]])
qx, scale, zp = quantization_util.quantize_tensor(
inputs, axis=[1], sym=False
)
self.assertAllClose(
qx,
np.array([[-128, -15, 127, -27], [-49, -128, 95, 127]], dtype=np.int8),
)
self.assertAllClose(scale, np.array([0.016863, 0.021569], dtype=np.float32))
self.assertAllClose(zp, np.array([-3.358431, -1.260784], dtype=np.float32))
class Quantize8BFPTest(TestCase):
def test_quantize_symmetric(self):
inputs = np.array([[1.0, 2.0, 5.5, 2.9], [0.02, -0.01, 3.3, 4.0]])
qx, scale = quantization_util.quantize_tensor(inputs, axis=[1], use_fp=True)
self.assertAllClose(
qx,
np.array([[106, 114, 126, 119], [65, -71, 124, 126]], dtype=np.int8),
)
self.assertAllClose(
scale, np.array([0.01227679, 0.00892857], dtype=np.float32)
)
def test_quantize_symmetric_with_dimension_size_one_unquantized(self):
# inputs shape: (2, 1, 4), quantization axis 2.
inputs = np.array([[[1.0, 2.0, 5.5, 2.9]], [[0.02, -0.01, 3.3, 4.0]]])
qx, scale = quantization_util.quantize_tensor(inputs, axis=[2], use_fp=True)
self.assertAllClose(
qx,
np.array(
[[[106, 114, 126, 119]], [[65, -71, 124, 126]]], dtype=np.int8
),
)
# expected scale shape: (2, 1)
self.assertAllClose(
scale, np.array([[0.01227679], [0.00892857]], dtype=np.float32)
)
def test_quantize_asymmetric(self):
inputs = np.array([[-1.0, -2.0, -2.01, 2.9], [0.02, -0.15, 3.3, 4.0]])
qx, scale, zp = quantization_util.quantize_tensor(
inputs, axis=[1], sym=False, use_fp=True
)
self.assertAllClose(
qx,
np.array([[-8, -2, -2, 126], [-3, -2, 121, 126]], dtype=np.int8),
)
self.assertAllClose(
scale, np.array([0.00547991, 0.0046317], dtype=np.float32)
)
self.assertAllClose(
zp, np.array([-0.4449999, -1.9250002], dtype=np.float32)
)
def test_quantize_add_scale_eps(self):
inputs = np.array([[0.0, 0.0, 0.0, 0.0], [-4.0, -4.0, -4.0, -4.0]])
_, scale, _ = quantization_util.quantize_tensor(
inputs, axis=[1], sym=False, use_fp=True, add_scale_eps=True
)
self.assertAllClose(
scale, np.array([np.finfo(np.float32).eps, np.finfo(np.float32).eps])
)
_, scale, _ = quantization_util.quantize_tensor(
inputs, axis=[1], sym=False, use_fp=True, add_scale_eps=False
)
self.assertAllClose(scale, np.array([1.0, 1.0]))
class Quantize4BTest(TestCase):
def test_quantize_symmetric(self):
inputs = np.array([[1.2, 3.1, 5.5, 2.9], [0.2, -1.5, 3.3, 4.0]])
qx, scale = quantization_util.quantize_tensor(
inputs, axis=[1], number_bits=4
)
self.assertAllClose(
qx, np.array([[2, 4, 7, 4], [0, -3, 6, 7]], dtype=np.int8)
)
self.assertAllClose(
scale, np.array([0.78571427, 0.5714286], dtype=np.float32)
)
def test_quantize_symmetric_with_dimension_size_one_unquantized(self):
# inputs shape: (2, 1, 4), quantization axis 2.
inputs = np.array([[[1.2, 3.1, 5.5, 2.9]], [[0.2, -1.5, 3.3, 4.0]]])
qx, scale = quantization_util.quantize_tensor(
inputs, axis=[2], number_bits=4
)
self.assertAllClose(
qx, np.array([[[2, 4, 7, 4]], [[0, -3, 6, 7]]], dtype=np.int8)
)
# expected scale shape: (2, 1)
self.assertAllClose(
scale, np.array([[0.78571427], [0.5714286]], dtype=np.float32)
)
def test_quantize_asymmetric(self):
inputs = np.array([[1.2, 3.1, 5.5, 2.9], [0.2, -1.5, 3.3, 4.0]])
qx, scale, zp = quantization_util.quantize_tensor(
inputs, axis=[1], sym=False, number_bits=4
)
self.assertAllClose(
qx,
np.array([[-8, -1, 7, -2], [-3, -8, 5, 7]], dtype=np.int8),
)
self.assertAllClose(
scale, np.array([0.2866667, 0.36666667], dtype=np.float32)
)
self.assertAllClose(
zp, np.array([-3.4933336, -1.4333334], dtype=np.float32)
)
class QuantizationUtilTest(TestCase):
def test_update_to_uint4_sym(self):
inputs = np.array([[1.2, 3.1, -5.5, 2.9], [0.2, -1.5, 3.3, 4.0]])
qx, scale = quantization_util.quantize_tensor(
inputs, axis=[1], sym=True, number_bits=4
)
dequant_from_int4 = qx * np.expand_dims(scale, -1)
qx_n, scale_n, zp_n = quantization_util.update_to_uint4(qx, scale)
self.assertEmpty(zp_n.shape) # A scalar numpy array.
dequant_from_uint4 = np.expand_dims(scale_n, -1) * (qx_n - zp_n)
np.testing.assert_allclose(dequant_from_int4, dequant_from_uint4)
def test_update_to_uint4_sym_combined(self):
inputs = np.array(
[[-1.2, 3.5, -6.2, 1.7], [1.2, 3.1, -5.5, 2.9], [0.2, -1.5, 3.3, 4.0]]
)
qx, scale = quantization_util.quantize_tensor(
inputs, axis=[1], sym=True, number_bits=4
)
dequant_from_int4 = qx * np.expand_dims(scale, -1)
qx_n, scale_n, zp_n = quantization_util.update_to_uint4(qx, scale)
self.assertEqual(zp_n.shape[0], 3)
dequant_from_uint4 = np.expand_dims(scale_n, -1) * (
qx_n - np.expand_dims(zp_n, -1)
)
np.testing.assert_allclose(dequant_from_int4, dequant_from_uint4)
def test_update_to_uint4_asym(self):
inputs = np.array([[1.0, 8.0, -3.0, 2.0], [-3.0, 2.0, 1.0, 8.0]])
qx, scale, zp = quantization_util.quantize_tensor(
inputs, axis=[1], sym=False, number_bits=4
)
qx_n, scale_n, zp_n = quantization_util.update_to_uint4(qx, scale, zp)
expected_dequant = np.array([
[0.0, 7.333333, -3.666667, 1.466667],
[-3.666667, 1.466667, 0.0, 7.333333],
])
dequant_from_uint4 = np.expand_dims(scale_n, -1) * (qx_n - zp_n)
np.testing.assert_allclose(dequant_from_uint4, expected_dequant, rtol=1e-05)
def test_update_to_uint4_asym_combined(self):
inputs = np.array(
[[1.0, 8.0, -3.0, 2.0], [-3.0, 2.0, 1.0, 8.0], [2.0, 1.0, 8.0, -3.0]]
)
qx, scale, zp = quantization_util.quantize_tensor(
inputs, axis=[1], sym=False, number_bits=4
)
qx_n, scale_n, zp_n = quantization_util.update_to_uint4(qx, scale, zp)
self.assertEqual(zp_n.shape[0], 3)
expected_dequant = np.array([
[0.0, 7.333333, -3.666667, 1.466667],
[-3.666667, 1.466667, 0.0, 7.333333],
[1.466667, 0.0, 7.333333, -3.666667],
])
dequant_from_uint4 = np.expand_dims(scale_n, -1) * (
qx_n - np.expand_dims(zp_n, -1)
)
np.testing.assert_allclose(dequant_from_uint4, expected_dequant, rtol=1e-05)
if __name__ == '__main__':
absltest.main()
@@ -0,0 +1,554 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CkptLoader implementation for loading the Safetensors."""
import array
import enum
import glob
import json
import os
from typing import Iterator, List, Optional
import numpy as np
import torch
from mediapipe.tasks.python.genai.converter import converter_base
DTYPE_MAP = {
"F16": torch.float16,
"BF16": torch.bfloat16,
"F32": torch.float32,
}
class _SafetensorsShardReader:
"""Reads a single safetensors shard."""
_HEAD_BYTES = 8
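  # On-disk safetensors layout (as parsed below): the first 8 bytes are a
  # little-endian uint64 holding the byte length N of a JSON header; the next
  # N bytes are that JSON, mapping tensor names to dtype/shape/data_offsets;
  # the remaining bytes are raw tensor data, with data_offsets relative to the
  # end of the header.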
def __init__(self, shard_path: str):
self._shard_path = shard_path
if not os.path.exists(self._shard_path):
raise ValueError(f"{self._shard_path} does not exists.")
with open(self._shard_path, "rb") as f:
head_bytes = f.read(self._HEAD_BYTES)
metadata_bytes_num = np.frombuffer(head_bytes, dtype=np.uint64)[0]
metadata_bytes = f.read(metadata_bytes_num)
self.layers_info = json.loads(metadata_bytes)
self.metadata_bytes_num = metadata_bytes_num
def read_tensor_as_numpy(self, tensor_name) -> np.ndarray:
"""Reads a tensor from the model file as a numpy array with np.float32 type."""
tensor_info = self.layers_info[tensor_name]
with open(self._shard_path, "rb") as f:
shape = tensor_info["shape"]
dtype = tensor_info["dtype"]
if dtype not in DTYPE_MAP:
raise ValueError(f"{dtype} is not supported.")
data_offsets = tensor_info["data_offsets"]
f.seek(int(self._HEAD_BYTES + self.metadata_bytes_num + data_offsets[0]))
tensor_bytes = f.read(data_offsets[1] - data_offsets[0])
raw_tensor = torch.frombuffer(
array.array("b", tensor_bytes), dtype=DTYPE_MAP[dtype]
).reshape(shape)
return raw_tensor.float().t().contiguous().numpy()
def get_tensor_names(self) -> List[str]:
names = list(self.layers_info.keys())
if "__metadata__" in names:
names.remove("__metadata__")
return names
class _SafetensorsReader:
"""Reads all the safetensors shards."""
def __init__(self, ckpt_path: str):
shards = []
if os.path.isdir(ckpt_path):
      # Read all safetensors files within the checkpoint directory.
for shard_path in glob.glob(os.path.join(ckpt_path, "*.safetensors")):
shards.append(_SafetensorsShardReader(shard_path))
else:
# Assume the ckpt_path is a file or a file pattern to match.
for shard_path in glob.glob(ckpt_path):
shards.append(_SafetensorsShardReader(shard_path))
    if not shards:
      raise ValueError(f"No safetensors files found at {ckpt_path}.")
self._ckpt_path = ckpt_path
self._tensors_map = {}
for shard in shards:
tensor_names = shard.get_tensor_names()
for tensor_name in tensor_names:
if tensor_name in self._tensors_map:
raise ValueError(f"Duplicate tensor name: {tensor_name}")
self._tensors_map[tensor_name] = shard
def get_tensor_names(self) -> List[str]:
return list(self._tensors_map.keys())
def read_tensor_as_numpy(self, tensor_name: str) -> np.ndarray:
return self._tensors_map[tensor_name].read_tensor_as_numpy(tensor_name)
class LayerType(enum.Enum):
"""Enum for layer type."""
NONE = 0
ATTENTION = 1 # Layer is part of the attention module.
FEEDFORWARD = 2 # Layer is part of the feedforward module in the Transformer.
EMBEDDING = 3 # Layer is the embedding lookup or final projection layer.
  LAYER_NORM = 4  # Layer is a normalization layer before or after attention.
LORA = 5 # Layer is LoRA weights augmented on the base model layers.
@classmethod
def get_layer_type(cls, layer_name: str):
"""Gets the layer type of the given layer name."""
ffn_layers = [
"mlp",
]
attn_layers = [
"self_attn",
]
emb_layers = [
"embed_tokens",
"lm_head",
]
layer_norms = [
"input_layernorm",
"post_attention_layernorm",
"final_layernorm",
"model.norm.weight",
]
lora_layers = ["lora"]
if any(sub_name in layer_name for sub_name in lora_layers):
return LayerType.LORA
if any(sub_name in layer_name for sub_name in attn_layers):
return LayerType.ATTENTION
if any(sub_name in layer_name for sub_name in ffn_layers):
return LayerType.FEEDFORWARD
if any(sub_name in layer_name for sub_name in emb_layers):
return LayerType.EMBEDDING
if any(sub_name in layer_name for sub_name in layer_norms):
return LayerType.LAYER_NORM
else:
return LayerType.NONE
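# Example (illustrative): LayerType.get_layer_type(
#     "model.layers.0.self_attn.q_proj.weight") matches the "self_attn"
# substring and returns LayerType.ATTENTION, while "model.norm.weight" maps to
# LayerType.LAYER_NORM.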
class StablelmMapper(converter_base.LayerActionMapperBase):
"""LayerActionMapper for handling the StableLM model."""
def __init__(
self,
is_symmetric: bool,
attention_quant_bits: int,
feedforward_quant_bits: int,
embedding_quant_bits: int,
backend: str,
reader: _SafetensorsReader,
):
super().__init__(
is_symmetric=is_symmetric,
attention_quant_bits=attention_quant_bits,
feedforward_quant_bits=feedforward_quant_bits,
embedding_quant_bits=embedding_quant_bits,
backend=backend,
)
self._reader = reader
def map_to_actions(
self, layer_name: str
) -> Optional[List[converter_base.QuantizationAction]]:
"""Map the given layer name to actions."""
tensor_value = self._reader.read_tensor_as_numpy(layer_name)
quantize_axis = None
quantize_bits = None
layer_type = LayerType.get_layer_type(layer_name)
if layer_type != LayerType.LAYER_NORM and layer_name.endswith(".weight"):
quantize_axis = [0]
if layer_type == LayerType.FEEDFORWARD:
quantize_bits = self._feedforward_quant_bits
elif layer_type == LayerType.ATTENTION:
quantize_bits = self._attention_quant_bits
if self._backend == "cpu" and ".o_proj." in layer_name:
tensor_value = np.transpose(tensor_value)
quantize_axis = [1]
elif layer_type == LayerType.EMBEDDING:
quantize_bits = self._embedding_quant_bits
if self._backend == "cpu" and ".embed_tokens." in layer_name:
tensor_value = np.transpose(tensor_value)
quantize_axis = [1]
target_name = self.update_target_name(layer_name)
actions = [
converter_base.QuantizationAction(
tensor_name=layer_name,
tensor_value=tensor_value,
target_name=target_name,
quantize_axis=quantize_axis,
quantize_bits=quantize_bits,
pack_dim=0,
)
]
return actions
def update_target_name(self, target_name: str) -> str:
"""Updates the target name to match the tensor name convention."""
target_name = target_name.replace(
"model.layers.", "params.lm.transformer.x_layers_"
)
target_name = target_name.replace("mlp.up_proj", "ff_layer.ffn_layer1")
target_name = target_name.replace("mlp.down_proj", "ff_layer.ffn_layer2")
target_name = target_name.replace(
"mlp.gate_proj", "ff_layer.ffn_layer1_gate"
)
target_name = target_name.replace("input_layernorm", "pre_layer_norm")
target_name = target_name.replace(
"pre_layer_norm.weight", "pre_layer_norm.scale"
)
if self._backend == "cpu":
target_name = target_name.replace(
"post_attention_layernorm", "ff_layer.pre_layer_norm"
)
target_name = target_name.replace(
"ff_layer.pre_layer_norm.weight", "ff_layer.pre_layer_norm.scale"
)
else:
target_name = target_name.replace(
"post_attention_layernorm", "post_layer_norm"
)
target_name = target_name.replace(
"post_layer_norm.weight", "post_layer_norm.scale"
)
target_name = target_name.replace("self_attn.q_proj", "self_attention.q")
target_name = target_name.replace("self_attn.k_proj", "self_attention.k")
target_name = target_name.replace("self_attn.v_proj", "self_attention.v")
target_name = target_name.replace("self_attn.o_proj", "self_attention.post")
target_name = target_name.replace(
"model.embed_tokens", "params.lm.token_embedding"
)
target_name = target_name.replace("model.norm", "params.lm.final_ln")
target_name = target_name.replace("final_ln.weight", "final_ln.scale")
target_name = target_name.replace("lm_head", "params.lm.softmax.logits_ffn")
target_name = target_name.replace(".weight", ".w")
return target_name
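# Example (illustrative): the chain of replacements above maps
# "model.layers.0.mlp.up_proj.weight" to
# "params.lm.transformer.x_layers_0.ff_layer.ffn_layer1.w".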
class PhiMapper(converter_base.LayerActionMapperBase):
"""LayerActionMapper for handling the Phi model."""
def __init__(
self,
is_symmetric: bool,
attention_quant_bits: int,
feedforward_quant_bits: int,
embedding_quant_bits: int,
backend: str,
reader: _SafetensorsReader,
):
super().__init__(
is_symmetric=is_symmetric,
attention_quant_bits=attention_quant_bits,
feedforward_quant_bits=feedforward_quant_bits,
embedding_quant_bits=embedding_quant_bits,
backend=backend,
)
self._reader = reader
def map_to_actions(
self, layer_name: str
) -> Optional[List[converter_base.QuantizationAction]]:
"""Map the given layer name to actions."""
tensor_value = self._reader.read_tensor_as_numpy(layer_name)
quantize_axis = None
quantize_bits = None
layer_type = LayerType.get_layer_type(layer_name)
if (
layer_type != LayerType.LAYER_NORM
and layer_name.endswith(".weight")
and layer_type != LayerType.LORA
):
quantize_axis = [0]
if layer_type == LayerType.FEEDFORWARD:
quantize_bits = self._feedforward_quant_bits
elif layer_type == LayerType.ATTENTION:
quantize_bits = self._attention_quant_bits
if self._backend == "cpu" and ".dense." in layer_name:
tensor_value = np.transpose(tensor_value)
quantize_axis = [1]
elif layer_type == LayerType.EMBEDDING:
quantize_bits = self._embedding_quant_bits
if self._backend == "cpu" and ".embed_tokens." in layer_name:
tensor_value = np.transpose(tensor_value)
quantize_axis = [1]
target_name = self.update_target_name(layer_name)
actions = [
converter_base.QuantizationAction(
tensor_name=layer_name,
tensor_value=tensor_value,
target_name=target_name,
quantize_axis=quantize_axis,
quantize_bits=quantize_bits,
pack_dim=0,
)
]
return actions
def update_target_name(self, target_name: str) -> str:
"""Updates the target name to match the tensor name convention."""
target_name = target_name.replace("base_model.model.", "")
target_name = target_name.replace(
"model.layers.", "params.lm.transformer.x_layers_"
)
layer_type = LayerType.get_layer_type(target_name)
if layer_type == LayerType.FEEDFORWARD:
target_name = target_name.replace(".weight", ".linear.w")
target_name = target_name.replace(".bias", ".bias.b")
target_name = target_name.replace("mlp.fc1", "ff_layer.ffn_layer1")
target_name = target_name.replace("mlp.fc2", "ff_layer.ffn_layer2")
elif layer_type == LayerType.ATTENTION:
target_name = target_name.replace(".weight", ".linear.w")
target_name = target_name.replace(".bias", ".bias.b")
target_name = target_name.replace("self_attn.q_proj", "self_attention.q")
target_name = target_name.replace("self_attn.k_proj", "self_attention.k")
target_name = target_name.replace("self_attn.v_proj", "self_attention.v")
target_name = target_name.replace(
"self_attn.dense", "self_attention.post"
)
elif layer_type == LayerType.EMBEDDING:
target_name = target_name.replace(
"model.embed_tokens", "params.lm.token_embedding"
)
target_name = target_name.replace(
"lm_head", "params.lm.softmax.logits_ffn"
)
target_name = target_name.replace(
"logits_ffn.weight", "logits_ffn.linear.w"
)
target_name = target_name.replace("logits_ffn.bias", "logits_ffn.bias.b")
elif layer_type == LayerType.LAYER_NORM:
target_name = target_name.replace("input_layernorm", "pre_layer_norm")
target_name = target_name.replace(
"pre_layer_norm.weight", "pre_layer_norm.scale"
)
target_name = target_name.replace(
"model.final_layernorm", "params.lm.final_ln"
)
target_name = target_name.replace("final_ln.weight", "final_ln.scale")
target_name = target_name.replace(".weight", ".w")
# For LoRA weights
if "post" in target_name:
target_name = target_name.replace("lora_A.linear.w", "w_prime_right")
target_name = target_name.replace("lora_B.linear.w", "w_prime_left")
else:
target_name = target_name.replace("lora_A.linear.w", "w_prime_left")
target_name = target_name.replace("lora_B.linear.w", "w_prime_right")
return target_name
class GemmaMapper(converter_base.LayerActionMapperBase):
"""LayerActionMapper for handling the StableLM model."""
def __init__(
self,
is_symmetric: bool,
attention_quant_bits: int,
feedforward_quant_bits: int,
embedding_quant_bits: int,
backend: str,
reader: _SafetensorsReader,
):
super().__init__(
is_symmetric=is_symmetric,
attention_quant_bits=attention_quant_bits,
feedforward_quant_bits=feedforward_quant_bits,
embedding_quant_bits=embedding_quant_bits,
backend=backend,
)
self._reader = reader
def map_to_actions(
self, layer_name: str
) -> Optional[List[converter_base.QuantizationAction]]:
"""Map the given layer name to actions."""
tensor_value = self._reader.read_tensor_as_numpy(layer_name)
quantize_axis = None
quantize_bits = None
layer_type = LayerType.get_layer_type(layer_name)
if (
layer_type != LayerType.LAYER_NORM
and layer_name.endswith(".weight")
and layer_type != LayerType.LORA
):
quantize_axis = [0]
if layer_type == LayerType.FEEDFORWARD:
quantize_bits = self._feedforward_quant_bits
elif layer_type == LayerType.ATTENTION:
quantize_bits = self._attention_quant_bits
if "o_proj" in layer_name:
tensor_value = np.transpose(tensor_value)
quantize_axis = [1]
elif layer_type == LayerType.EMBEDDING:
quantize_bits = self._embedding_quant_bits
target_name = self.update_target_name(layer_name)
actions = [
converter_base.QuantizationAction(
tensor_name=layer_name,
tensor_value=tensor_value,
target_name=target_name,
quantize_axis=quantize_axis,
quantize_bits=quantize_bits,
pack_dim=0,
)
]
return actions
def update_target_name(self, target_name: str) -> str:
"""Updates the target name to match the tensor name convention."""
target_name = target_name.replace("base_model.model.", "")
target_name = target_name.replace(
"model.layers.", "params.lm.transformer.x_layers_"
)
target_name = target_name.replace("mlp.up_proj", "ff_layer.ffn_layer1")
target_name = target_name.replace("mlp.down_proj", "ff_layer.ffn_layer2")
target_name = target_name.replace(
"mlp.gate_proj", "ff_layer.ffn_layer1_gate"
)
target_name = target_name.replace("input_layernorm", "pre_layer_norm")
target_name = target_name.replace(
"pre_layer_norm.weight", "pre_layer_norm.scale"
)
target_name = target_name.replace(
"post_attention_layernorm", "ff_layer.pre_layer_norm"
)
target_name = target_name.replace(
"ff_layer.pre_layer_norm.weight", "ff_layer.pre_layer_norm.scale"
)
target_name = target_name.replace("self_attn.q_proj", "self_attention.q")
target_name = target_name.replace("self_attn.k_proj", "self_attention.k")
target_name = target_name.replace("self_attn.v_proj", "self_attention.v")
target_name = target_name.replace("self_attn.o_proj", "self_attention.post")
target_name = target_name.replace(
"model.embed_tokens", "params.lm.softmax.logits_ffn"
)
target_name = target_name.replace("model.norm", "params.lm.final_ln")
target_name = target_name.replace("final_ln.weight", "final_ln.scale")
target_name = target_name.replace(".weight", ".w")
# For LoRA weights
if "post" in target_name:
target_name = target_name.replace("lora_A.w", "w_prime_right")
target_name = target_name.replace("lora_B.w", "w_prime_left")
else:
target_name = target_name.replace("lora_A.w", "w_prime_left")
target_name = target_name.replace("lora_B.w", "w_prime_right")
return target_name
class SafetensorsCkptLoader(converter_base.CkptLoaderBase):
"""CkptLoader implementation for loading the Safetensors."""
def __init__(
self,
ckpt_path: str,
is_symmetric: bool,
attention_quant_bits: int,
feedforward_quant_bits: int,
embedding_quant_bits: int,
special_model: str,
backend: str,
):
"""Initializes the loader.
Args:
ckpt_path: The filepath to the safetensors file.
is_symmetric: Whether to apply symmetric or asymmetric quantization.
attention_quant_bits: An integer that specify the target quantization bits
(support 8 or 4) for the attention layers.
feedforward_quant_bits: An integer that specify the target quantization
bits (support 8 or 4) for the feedforward layers in each Transformer
blocks.
embedding_quant_bits: An integer that specify the target quantization bits
(support 8 or 4) for the embedding (and the final projection) layers.
special_model: A string that indicates which input model is and whether
any special treatment is needed.
backend: A string indicating the backend used when converting this model.
Valid options are "cpu" and "gpu".
"""
super().__init__(
ckpt_path,
is_symmetric,
attention_quant_bits,
feedforward_quant_bits,
embedding_quant_bits,
)
self._special_model = special_model
self._reader = _SafetensorsReader(ckpt_path)
if special_model in ["STABLELM_4E1T_3B"]:
self.mapper = StablelmMapper(
is_symmetric,
attention_quant_bits,
feedforward_quant_bits,
embedding_quant_bits,
backend,
self._reader,
)
elif special_model in ["PHI_2"]:
self.mapper = PhiMapper(
is_symmetric,
attention_quant_bits,
feedforward_quant_bits,
embedding_quant_bits,
backend,
self._reader,
)
elif special_model in ["GEMMA_2B", "GEMMA_7B"]:
self.mapper = GemmaMapper(
is_symmetric,
attention_quant_bits,
feedforward_quant_bits,
embedding_quant_bits,
backend,
self._reader,
)
else:
raise ValueError(f"Unknown special model: {special_model}")
def load_to_actions(
self,
) -> Iterator[List[converter_base.QuantizationAction]]:
tensor_names = self._reader.get_tensor_names()
for tensor_name in tensor_names:
tensor_actions = self.mapper.map_to_actions(tensor_name)
if tensor_actions is None:
continue
yield tensor_actions
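# Minimal usage sketch (illustrative; the path is hypothetical, and the model
# flavor mirrors the unit test for this loader):
#   loader = SafetensorsCkptLoader(
#       ckpt_path="/path/to/model.safetensors",
#       is_symmetric=True, attention_quant_bits=8, feedforward_quant_bits=8,
#       embedding_quant_bits=8, special_model="STABLELM_4E1T_3B",
#       backend="gpu")
#   for actions in loader.load_to_actions():
#     ...  # each item is a list of converter_base.QuantizationAction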
@@ -0,0 +1,83 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for safetensors_converter."""
import os
from absl.testing import absltest
from absl.testing import parameterized
from mediapipe.tasks.python.genai.converter import safetensors_converter
from mediapipe.tasks.python.test import test_utils
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
_SAFETENSORS_FILE = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, 'stablelm_3b_4e1t_test_weight.safetensors')
)
class SafetensorsConverterTest(parameterized.TestCase):
VARIABLE_NAMES = [
'model.embed_tokens.weight',
'model.layers.0.input_layernorm.bias',
'model.layers.0.input_layernorm.weight',
'model.layers.0.mlp.down_proj.weight',
'model.layers.0.mlp.gate_proj.weight',
'model.layers.0.mlp.up_proj.weight',
'model.layers.0.post_attention_layernorm.bias',
'model.layers.0.post_attention_layernorm.weight',
'model.layers.0.self_attn.k_proj.weight',
'model.layers.0.self_attn.o_proj.weight',
'model.layers.0.self_attn.q_proj.weight',
'model.layers.0.self_attn.v_proj.weight',
'model.norm.bias',
'model.norm.weight',
'lm_head.weight',
]
def test_init(self):
loader = safetensors_converter.SafetensorsCkptLoader(
ckpt_path=_SAFETENSORS_FILE,
is_symmetric=True,
attention_quant_bits=8,
feedforward_quant_bits=8,
embedding_quant_bits=8,
special_model='STABLELM_4E1T_3B',
backend='gpu',
)
self.assertEqual(loader._ckpt_path, _SAFETENSORS_FILE)
self.assertEqual(loader._is_symmetric, True)
self.assertEqual(loader._attention_quant_bits, 8)
self.assertEqual(loader._feedforward_quant_bits, 8)
@parameterized.product(
quant_bits=(4, 8),
)
def test_load_to_actions(self, quant_bits):
loader = safetensors_converter.SafetensorsCkptLoader(
ckpt_path=_SAFETENSORS_FILE,
is_symmetric=True,
attention_quant_bits=8,
feedforward_quant_bits=quant_bits,
embedding_quant_bits=8,
special_model='STABLELM_4E1T_3B',
backend='gpu',
)
actions = loader.load_to_actions()
self.assertLen(list(actions), 15)
if __name__ == '__main__':
absltest.main()
@@ -0,0 +1,112 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ModelWriter for writing a set of weights as binary files."""
import contextlib
import os
from typing import Dict, Tuple
from jax import numpy as jnp
import numpy as np
from mediapipe.tasks.python.genai.converter import converter_base
from mediapipe.tasks.python.genai.converter import quantization_util
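# Note: filemanager() below simply delegates to the built-in open(); it
# appears to serve as an indirection point so a different file API could be
# swapped in without touching the writer logic.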
@contextlib.contextmanager
def filemanager(filename: str, mode: str):
try:
with open(filename, mode) as f:
yield f
finally:
pass
def removeprefix(s, prefix):
"""Removes the prefix from a string."""
if s.startswith(prefix):
return s[len(prefix) :]
return s
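# Note: equivalent to str.removeprefix(), which only exists on Python 3.9+;
# the local helper presumably keeps the converter running on older runtimes.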
class WeightBinsWriter(converter_base.ModelWriterBase):
"""A ModelWriter for writing a set of weights as binary files."""
def get_weight_info(self, var_name: str, weight: np.ndarray) -> str:
"""Gets the string that describes the weights."""
dtype_str = str(weight.dtype)
shape_str = '_'.join(map(str, weight.shape))
return f'mdl_vars.{var_name}.{dtype_str}.{shape_str}\n'
def write_variables(self, variables: Dict[str, Tuple[np.ndarray, bool]]):
"""Writes variable to the binary files. One for each layer.
Args:
variables: A dictionary that maps from the target variable names to the
quantized tensor values along with a boolean that indicates whether to
pack the values (only applicable for the 4-bit quantized tensors).
"""
weights_info = []
for var_name, value in variables.items():
output = value[0]
if value[1]:
# Squeeze the tensor to make sure it is a 1D array for packing.
output = np.expand_dims(np.ravel(output), axis=-1)
# Extra pack needed for 4 bit. We always pack the weights along the
# first dimension since the tensor has already been squeezed.
output = quantization_util.pack_4bit(output, 0, jnp.int8)
if 'combined_qkv' in var_name:
        var_name = removeprefix(var_name, 'mdl_vars.')
var_name_q = var_name.replace('combined_qkv', 'q')
var_name_k = var_name.replace('combined_qkv', 'k')
var_name_v = var_name.replace('combined_qkv', 'v')
if output.shape[0] == 3:
weight_q, weight_k, weight_v = output
assert weight_q.shape == weight_k.shape == weight_v.shape
else: # LoRA right weight is shared across q, k, v
weight_q = weight_k = weight_v = output
weights_info.append(self.get_weight_info(var_name_q, weight_q))
path_q = os.path.join(self._output_dir, var_name_q)
with filemanager(path_q, 'wb') as f:
f.write(weight_q.tobytes())
weights_info.append(self.get_weight_info(var_name_k, weight_k))
path_k = os.path.join(self._output_dir, var_name_k)
with filemanager(path_k, 'wb') as f:
f.write(weight_k.tobytes())
path_v = os.path.join(self._output_dir, var_name_v)
with filemanager(path_v, 'wb') as f:
f.write(weight_v.tobytes())
weights_info.append(self.get_weight_info(var_name_v, weight_v))
else:
if 'key' in var_name:
var_name = var_name.replace('key', 'k')
if 'query' in var_name:
var_name = var_name.replace('query', 'q')
if 'value' in var_name:
var_name = var_name.replace('value', 'v')
path = os.path.join(
self._output_dir, removeprefix(var_name, 'mdl_vars.')
)
with filemanager(path, 'wb') as f:
f.write(output.tobytes())
weights_info.append(self.get_weight_info(var_name, output))
# Sort weights_info
weights_info.sort()
with filemanager(
os.path.join(self._output_dir, 'layer_info.txt'), 'a'
) as finfo:
      for line in weights_info:
        # Entries from get_weight_info() are already newline-terminated.
        finfo.write(line)
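# Minimal usage sketch (illustrative; mirrors the unit test for this writer,
# assuming output_dir already exists):
#   writer = WeightBinsWriter(output_dir="/tmp/out", backend="cpu")
#   writer.write_variables({
#       'mdl_vars.params.lm.softmax.logits_ffn.linear.w': (
#           np.zeros((2, 3), np.float32), False),
#   })
#   # -> one raw binary file per variable, plus a sorted layer_info.txt.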
@@ -0,0 +1,62 @@
# Copyright 2024 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for pax_converter."""
import os
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from mediapipe.tasks.python.genai.converter import weight_bins_writer
class WeightBinsWriterTest(parameterized.TestCase):
def test_get_weight_info(self):
output_dir = os.path.join(flags.FLAGS.test_tmpdir, 'output_dir')
writer = weight_bins_writer.WeightBinsWriter(
output_dir=output_dir, backend='cpu'
)
var_name = 'params.lm.softmax.logits_ffn.linear.w'
weight_info = writer.get_weight_info(
var_name, np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
)
self.assertEqual(
weight_info,
'mdl_vars.params.lm.softmax.logits_ffn.linear.w.float32.2_3\n',
)
  def test_write_variables(self):
output_dir = os.path.join(flags.FLAGS.test_tmpdir, 'output_dir')
writer = weight_bins_writer.WeightBinsWriter(
output_dir=output_dir, backend='cpu'
)
variables = {
'mdl_vars.params.lm.softmax.logits_ffn.linear.w': (
np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
False,
),
}
writer.write_variables(variables)
file_size = os.path.getsize(
os.path.join(output_dir, 'params.lm.softmax.logits_ffn.linear.w')
)
self.assertEqual(file_size, 6 * 4)
if __name__ == '__main__':
absltest.main()