hand
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MediaPipe GenAI Python API."""
|
||||
BIN
Binary file not shown.
@@ -0,0 +1,23 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MediaPipe Tasks GenAI Bundler API."""

# Must be a `from ... import` (not `import mediapipe.tasks...llm_bundler`):
# a plain submodule import only binds the top-level name `mediapipe`, so the
# bare `llm_bundler` references below would raise NameError.
from mediapipe.tasks.python.genai.bundler import llm_bundler

BundleConfig = llm_bundler.BundleConfig
create_bundle = llm_bundler.create_bundle

# Remove unnecessary modules to avoid duplication in API docs.
del llm_bundler
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,84 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Functions to perform llm packing."""
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import List
|
||||
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import model_asset_bundle_utils
|
||||
from mediapipe.tasks.cc.genai.inference.proto import llm_params_pb2
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class BundleConfig:
  """Config for LLM Bundler.

  Attributes:
    tflite_model: Path to the multi-signature tflite model with "prefill" and
      "decode" signatures converted using ODML Transformers APIs.
    tokenizer_model: Path to the tokenizer model. Currently only SentencePiece
      tokenizer is supported. As such, tokenizer.model proto is expected to be
      passed here.
    start_token: Token that will be used to signify the beginning of a
      sequence.
    stop_tokens: Tokens that will be used to signify the end of a sequence.
    output_filename: Name of the generated `.task` file containing the Bundle.
      If the name does not already end with ".task", `create_bundle` appends
      that suffix.
    enable_bytes_to_unicode_mapping: Enables GPT-2 style bytes to unicode
      mapping. For more details see:
      https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9
  """

  tflite_model: str
  tokenizer_model: str
  start_token: str
  stop_tokens: List[str]
  output_filename: str
  enable_bytes_to_unicode_mapping: bool = False
|
||||
|
||||
|
||||
class _BundleTags(enum.Enum):
  """Bundle tags.

  The `.name` of each member is used verbatim as the artifact's entry name
  inside the generated model asset bundle (see `create_bundle`), so renaming
  a member changes the on-disk bundle format.
  """

  TF_LITE_PREFILL_DECODE = 1
  TOKENIZER_MODEL = 2
  METADATA = 3
|
||||
|
||||
|
||||
def create_bundle(config: BundleConfig):
  """Builds a `.task` bundle file described by `config`.

  Reads the tflite model and the tokenizer model from disk, serializes the
  LLM parameters as metadata, and packs the three artifacts into a single
  model asset bundle. A ".task" suffix is appended to the output filename
  when it is missing.
  """

  def _read_bytes(path: str) -> bytes:
    # Small helper so both model files are read the same way.
    with open(path, "rb") as stream:
      return stream.read()

  # Insertion order matters: it determines the entry order inside the bundle.
  artifacts = {
      _BundleTags.TF_LITE_PREFILL_DECODE.name: _read_bytes(config.tflite_model),
      _BundleTags.TOKENIZER_MODEL.name: _read_bytes(config.tokenizer_model),
  }

  params = llm_params_pb2.LlmParameters()
  params.start_token = config.start_token
  params.stop_tokens.extend(config.stop_tokens)
  if config.enable_bytes_to_unicode_mapping:
    params.input_output_normalizations.append(
        llm_params_pb2.LlmParameters.INPUT_OUTPUT_NORMALIZATION_BYTES_TO_UNICODE
    )
  artifacts[_BundleTags.METADATA.name] = params.SerializeToString()

  output_filename = config.output_filename
  if not output_filename.endswith(".task"):
    output_filename += ".task"

  model_asset_bundle_utils.create_model_asset_bundle(
      artifacts,
      output_filename,
  )
|
||||
+64
@@ -0,0 +1,64 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for llm_bundler."""
|
||||
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from mediapipe.tasks.python.genai.bundler import llm_bundler
|
||||
|
||||
|
||||
class LlmBundlerTest(absltest.TestCase):

  def _create_test_bundle(self, out_dir: str):
    """Writes fake model/tokenizer files and bundles them into a .task file."""
    fake_inputs = (
        ("test.tflite", "tflite_model"),
        ("sp.model", "sp_model"),
    )
    written_paths = []
    for filename, contents in fake_inputs:
      path = os.path.join(out_dir, filename)
      with open(path, "w") as stream:
        stream.write(contents)
      written_paths.append(path)

    bundle_path = os.path.join(out_dir, "test.task")
    llm_bundler.create_bundle(
        llm_bundler.BundleConfig(
            tflite_model=written_paths[0],
            tokenizer_model=written_paths[1],
            start_token="BOS",
            stop_tokens=["EOS1", "EOS2"],
            output_filename=bundle_path,
            enable_bytes_to_unicode_mapping=True,
        )
    )
    return bundle_path

  def test_can_create_bundle(self):
    out_dir = self.create_tempdir().full_path
    bundle_path = self._create_test_bundle(out_dir)
    self.assertTrue(os.path.exists(bundle_path))

  def test_verify_content(self):
    out_dir = self.create_tempdir().full_path
    bundle_path = self._create_test_bundle(out_dir)
    # Entry order mirrors the insertion order used by create_bundle.
    expected_entries = ["TF_LITE_PREFILL_DECODE", "TOKENIZER_MODEL", "METADATA"]
    with zipfile.ZipFile(bundle_path) as zip_file:
      self.assertLen(zip_file.filelist, len(expected_entries))
      for entry, expected_name in zip(zip_file.filelist, expected_entries):
        self.assertEqual(entry.filename, expected_name)


if __name__ == "__main__":
  absltest.main()
|
||||
@@ -0,0 +1,24 @@
|
||||
# Copyright 2022 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MediaPipe Tasks GenAI Converter API."""

# Must be a `from ... import` (not `import mediapipe.tasks...llm_converter`):
# a plain submodule import only binds the top-level name `mediapipe`, so the
# bare `llm_converter` references below would raise NameError.
from mediapipe.tasks.python.genai.converter import llm_converter

ConversionConfig = llm_converter.ConversionConfig
convert_checkpoint = llm_converter.convert_checkpoint

# Remove unnecessary modules to avoid duplication in API docs. `mediapipe`
# itself is never bound by the from-import above, so only the submodule alias
# needs deleting.
del llm_converter
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
+175
@@ -0,0 +1,175 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Defines a couple base classes for the conversion/quantization process."""
|
||||
|
||||
from typing import Iterator
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import numpy as np
|
||||
|
||||
|
||||
class QuantizationAction:
  """Container of the tensor values and its corresponding quantization settings.

  Hosts everything required to execute weight-only quantization for a single
  tensor.

  Attributes:
    tensor_name: A string that represents the input tensor name.
    tensor_value: A numpy array that contains the unquantized tensor values.
    target_name: A string that represents the updated tensor name.
    quantize_axis: A list of integers representing the dimensions to be
      quantized along. For example, if an input tensor has shape [128, 256]
      and the quantize_axis==[0], it means the quantization happens along the
      0-th dimension, resulting in [256] scaling factors.
    quantize_bits: An integer that specifies the target quantization bits. It
      currently only supports either 8 or 4 bits.
    pack_dim: An integer specifying which dimension to pack the quantized
      bits. This is only applicable when the quantize_bits == 4.
  """

  def __init__(
      self,
      tensor_name: str,
      tensor_value: Optional[np.ndarray] = None,
      target_name: Optional[str] = None,
      quantize_axis: Optional[List[int]] = None,
      quantize_bits: Optional[int] = None,
      pack_dim: Optional[int] = 0,
  ):
    """Initializes the model attributes."""
    self.tensor_name = tensor_name
    self.tensor_value = tensor_value
    self.target_name = target_name
    self.quantize_axis = quantize_axis
    self.quantize_bits = quantize_bits
    self.pack_dim = pack_dim

  def __str__(self) -> str:
    """Returns a multi-line human-readable summary of the action."""
    lines = [
        "QuantizationAction(",
        f"  tensor_name: {self.tensor_name}",
        f"  target_name: {self.target_name}",
        f"  quantize_axis: {self.quantize_axis}",
        f"  quantize_bits: {self.quantize_bits}",
        f"  pack_dim: {self.pack_dim}",
    ]
    # Only the shape is printed; dumping the full tensor would be noisy.
    if self.tensor_value is not None:
      lines.append(f"  tensor_value: {self.tensor_value.shape}")
    lines.append(")")
    return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
class CkptLoaderBase:
  """Base class for loading the checkpoint.

  This class is responsible for loading the checkpoint files into the layer
  weight tensors (as numpy arrays) + quantization setting information (8/4
  bits). The returned data should be a list of QuantizationAction that
  describes how to quantize each layer's weights.
  """

  def __init__(
      self,
      ckpt_path: str,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
  ):
    """Initializes the loader.

    Args:
      ckpt_path: The filepath to the checkpoint.
      is_symmetric: Whether to apply symmetric or asymmetric quantization.
      attention_quant_bits: An integer that specify the target quantization
        bits (support 8 or 4) for the attention layers.
      feedforward_quant_bits: An integer that specify the target quantization
        bits (support 8 or 4) for the feedforward layers in each Transformer
        blocks.
      embedding_quant_bits: An integer that specify the target quantization
        bits (support 8 or 4) for the embedding (and the final projection)
        layers.
    """
    self._ckpt_path = ckpt_path
    self._is_symmetric = is_symmetric
    self._attention_quant_bits = attention_quant_bits
    self._feedforward_quant_bits = feedforward_quant_bits
    self._embedding_quant_bits = embedding_quant_bits

  def load_to_actions(
      self,
  ) -> Iterator[Optional[List[QuantizationAction]]]:
    """Loads the checkpoint and returns the quantization actions.

    Child classes must override this; each yielded item is a batch of
    QuantizationAction (or None) consumed by the conversion driver.
    """
    raise NotImplementedError("The load_to_actions method is not implemented.")
|
||||
|
||||
|
||||
class LayerActionMapperBase:
  """Base class for mapping the layer weights to quantization actions.

  This class is responsible for mapping from each layer to its corresponding
  quantization information (e.g. target quantization bits / updated tensor
  name...).
  """

  def __init__(
      self,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
      backend: str,
  ):
    """Initializes the mapper.

    Args:
      is_symmetric: Whether to apply symmetric or asymmetric quantization.
      attention_quant_bits: Target quantization bits (8 or 4) for the
        attention layers.
      feedforward_quant_bits: Target quantization bits (8 or 4) for the
        feedforward layers in each Transformer block.
      embedding_quant_bits: Target quantization bits (8 or 4) for the
        embedding (and the final projection) layers.
      backend: Target backend to run the model ("cpu" or "gpu", per the
        conversion config).
    """
    self._is_symmetric = is_symmetric
    self._attention_quant_bits = attention_quant_bits
    self._feedforward_quant_bits = feedforward_quant_bits
    self._embedding_quant_bits = embedding_quant_bits
    self._backend = backend

  def map_to_actions(
      self, layer_name: str
  ) -> Optional[List[QuantizationAction]]:
    """Maps the layer weights to quantization actions.

    Args:
      layer_name: A string representing the name of the layer weight. Note
        that it is expected the layer information is contained in the name
        which is enough to determine the target quantization information. Any
        child class is expected to implement this function.
    """
    raise NotImplementedError("The map_to_actions method is not implemented.")
|
||||
|
||||
|
||||
class ModelWriterBase:
  """Base class for writing the quantized model.

  This class is responsible for taking a dictionary of the quantized
  tensors/names and writing them into the format that can be loaded by the
  on-device inference engine.
  """

  def __init__(self, output_dir: str, backend: str):
    """Initializes the class.

    Args:
      output_dir: A string that represents the output directory to write the
        resulting file(s). The directory (including any missing parent
        directories) is created if it does not already exist.
      backend: A string that represents the target backend to run the output
        file(s).
    """
    self._output_dir = output_dir
    # makedirs (rather than mkdir) also creates missing parent directories,
    # and exist_ok=True avoids a check-then-create race. This matches how
    # the rest of the converter package creates output directories.
    os.makedirs(self._output_dir, exist_ok=True)
    self._backend = backend

  def write_variables(self, variables: Dict[str, Tuple[np.ndarray, bool]]):
    """Writes the given variables to the output location.

    Args:
      variables: A dict mapping each updated tensor name to a tuple of (tensor
        values, whether the values need to be packed for 4-bit storage).

    Raises:
      NotImplementedError: Always; child classes must implement this method.
    """
    raise NotImplementedError("The write_variables method is not implemented.")
|
||||
+79
@@ -0,0 +1,79 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utility library that helps create the converter instances."""
|
||||
from mediapipe.tasks.python.genai.converter import converter_base
|
||||
from mediapipe.tasks.python.genai.converter import pytorch_converter
|
||||
from mediapipe.tasks.python.genai.converter import safetensors_converter
|
||||
from mediapipe.tasks.python.genai.converter import weight_bins_writer
|
||||
|
||||
|
||||
def create_ckpt_loader(
    ckpt_format: str, *args, **kwargs
) -> converter_base.CkptLoaderBase:
  """Creates the checkpoint loader.

  Args:
    ckpt_format: A string that indicates which input checkpoint format is.
    *args: Additional arguments to be passed into the loader.
    **kwargs: Additional arguments to be passed into the loader.

  Returns:
    A created CkptLoader instance.
  """
  del args
  if ckpt_format not in ("pytorch", "safetensors"):
    raise ValueError(f"Unknown checkpoint format: {ckpt_format}")

  # Both loader classes take the same constructor arguments.
  loader_kwargs = dict(
      ckpt_path=kwargs["ckpt_path"],
      is_symmetric=kwargs["is_symmetric"],
      attention_quant_bits=kwargs["attention_quant_bits"],
      feedforward_quant_bits=kwargs["feedforward_quant_bits"],
      embedding_quant_bits=kwargs["embedding_quant_bits"],
      special_model=kwargs["special_model"],
      backend=kwargs["backend"],
  )
  if ckpt_format == "pytorch":
    return pytorch_converter.PytorchCkptLoader(**loader_kwargs)
  return safetensors_converter.SafetensorsCkptLoader(**loader_kwargs)
|
||||
|
||||
|
||||
def create_writer(
    writer_type: str, *args, **kwargs
) -> converter_base.ModelWriterBase:
  """Creates the model writer.

  Args:
    writer_type: A string the indicates which model writer to create.
    *args: Additional arguments to be passed into the loader.
    **kwargs: Additional arguments to be passed into the loader.

  Returns:
    A created ModelWriter instance.
  """
  del args
  # Guard clause: only one writer type is currently supported.
  if writer_type != "weight_bins":
    raise ValueError(f"Unknown writer type: {writer_type}")
  return weight_bins_writer.WeightBinsWriter(
      output_dir=kwargs["output_dir"], backend=kwargs["backend"]
  )
|
||||
+316
@@ -0,0 +1,316 @@
|
||||
"""Functions to perform the checkpoint conversion."""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from absl import logging
|
||||
|
||||
from mediapipe.python._framework_bindings import model_ckpt_util
|
||||
from mediapipe.tasks.python.genai.converter import converter_base
|
||||
from mediapipe.tasks.python.genai.converter import converter_factory
|
||||
from mediapipe.tasks.python.genai.converter import quantization_util
|
||||
|
||||
|
||||
class ConversionConfig(object):
  """Config for checkpoint conversion.

  Validates paths and option combinations at construction time, creating the
  output directories as needed.

  Attributes:
    input_ckpt: Directory or path for the input checkpoint.
    ckpt_format: Checkpoint format, e.g. 'safetensors', 'pytorch'.
    model_type: Name of the model, e.g. GEMMA_2B.
    backend: Target backend to run the model. Can be either "cpu" or "gpu".
    output_dir: Where the output file(s) to be stored.
    is_symmetric: Whether to quantize symmetrically.
    attention_quant_bits: Target quantization bits for the attention layers.
    feedforward_quant_bits: Target quantization bits for the feedforward
      layers.
    embedding_quant_bits: Target quantization bits for the embedding layers.
    combine_file_only: Whether to combine the weight files only (assuming the
      weight files already exist).
    vocab_model_file: The file path to the 1) SentencePiece vocab model; 2)
      Hugging Face BPE tokenizer files; 1) is applicable for the Gemma model
      and 2) is applicable for other models. When 2) is used, the provided
      path is expected to point to a directory that contains both
      tokenizer.json and tokenizer_config.json files.
    output_tflite_file: (optional) the output tflite filename. If not
      provided, the output will be `model.tflite` stored in the output_dir.
    fp16_scale: A scalar value between [0, 1]. Some models can run into
      activation overflow issue when running in 16-bit floating point mode.
      To solve this, we need to scale down the weights of certain layers. See
      go/llm-on-device-fp16 for more detailed explanation. NOTE(review): this
      argument is currently discarded in __init__ (see below).
    lora_ckpt: The directory or path for the lora checkpoint. Required in
      order to convert the lora weights.
    lora_rank: An integer representing the rank of LoRA. Required in order to
      convert the lora weights. If not provided, then the converter assumes
      there is no LoRA weights. Note that only the GPU backend supports LoRA.
    lora_output_tflite_file: A string indicating the name of the generated
      tflite file for the LoRA weight. Only applicable when the lora_rank is
      not zero.
  """

  def __init__(
      self,
      input_ckpt: str,
      ckpt_format: str,
      model_type: str,
      backend: str,
      output_dir: str,
      is_symmetric: bool = True,
      attention_quant_bits: int = 8,
      feedforward_quant_bits: int = 8,
      embedding_quant_bits: int = 8,
      combine_file_only: bool = False,
      vocab_model_file: str = '',
      output_tflite_file: Optional[str] = None,
      fp16_scale: Optional[float] = None,
      lora_ckpt: Optional[str] = None,
      lora_rank: Optional[int] = None,
      lora_output_tflite_file: Optional[str] = None,
  ):
    self.input_ckpt = input_ckpt
    self.ckpt_format = ckpt_format
    self.model_type = model_type
    self.backend = backend
    # The output location must be a (possibly new) directory, never a file.
    # NOTE(review): "mush" in the message below should read "must".
    if os.path.isfile(output_dir):
      raise ValueError('Output directory mush not point to an existing file.')
    if not os.path.isdir(output_dir):
      logging.info('Creating output directory: %s', output_dir)
      os.makedirs(output_dir, exist_ok=True)
    self.output_dir = output_dir
    self.is_symmetric = is_symmetric
    self.attention_quant_bits = attention_quant_bits
    self.feedforward_quant_bits = feedforward_quant_bits
    self.embedding_quant_bits = embedding_quant_bits
    self.combine_file_only = combine_file_only
    self.vocab_model_file = vocab_model_file
    if output_tflite_file:
      # Make sure the tflite file's parent directory exists up front so the
      # final write cannot fail on a missing path.
      parent_dir = os.path.dirname(output_tflite_file)
      if not os.path.isdir(parent_dir):
        logging.info('Creating tflite parent directory: %s', parent_dir)
        os.makedirs(parent_dir, exist_ok=True)
      self.output_tflite_file = output_tflite_file
    else:
      self.output_tflite_file = os.path.join(output_dir, 'model.tflite')

    # NOTE(review): the fp16_scale constructor argument is accepted but
    # discarded here -- self.fp16_scale is always None. Confirm whether this
    # is an intentional stub or a dropped assignment.
    self.fp16_scale = None
    self.lora_ckpt = lora_ckpt
    self.lora_rank = lora_rank
    self.lora_output_tflite_file = lora_output_tflite_file
    # XOR: providing exactly one of lora_ckpt / lora_rank is invalid.
    if (self.lora_ckpt is None) ^ (self.lora_rank is None):
      raise ValueError(
          'lora_ckpt and lora_rank must be either both provided or both not'
          ' provided.'
      )
    if self.lora_rank is not None:
      if backend == 'cpu':
        raise ValueError('LoRA is not supported for CPU backend.')
      lora_applicable_models = ['GEMMA_2B', 'PHI_2']
      if model_type not in lora_applicable_models:
        raise ValueError(
            'LoRA is only applicable for the model_type:'
            f' {", ".join(lora_applicable_models)}, but get model_type:'
            f' {model_type}.'
        )
|
||||
|
||||
|
||||
def quantize_by_actions(
    actions: List[converter_base.QuantizationAction],
    backend: str,
    is_symmetric: bool,
):
  """Quantizes the weights by actions.

  Args:
    actions: A list of QuantizationAction that contains the information and
      tensor values to be quantized.
    backend: Target backend to run the model. Can be either "cpu" or "gpu".
    is_symmetric: Whether to quantize symmetrically.

  Returns:
    A dictionary that maps from the updated tensor names to the quantized
    tensor values + a boolean that indicates whether the tensor values need to
    be packed (only applicable for the 4-bit quantized weights).
  """
  output_tensors = {}
  for action in actions:
    if not action.quantize_axis:
      # No quantization axis: pass the tensor through unquantized.
      output_tensors[action.target_name] = (action.tensor_value, False)
      continue
    pack = action.quantize_bits == 4  # 4-bit tensors are flagged for packing.
    if is_symmetric:
      target_var, scale = quantization_util.quantize_tensor(
          var=action.tensor_value,
          axis=action.quantize_axis,
          sym=is_symmetric,
          number_bits=action.quantize_bits,
      )
      zp = None  # Symmetric quantization produces no zero point.
    else:
      target_var, scale, zp = quantization_util.quantize_tensor(
          var=action.tensor_value,
          axis=action.quantize_axis,
          sym=is_symmetric,
          number_bits=action.quantize_bits,
      )
    if backend == 'cpu' and pack:
      # NOTE(review): presumably re-encodes 4-bit values into the uint4 layout
      # the CPU runtime expects -- confirm against quantization_util.
      target_var, scale, zp = quantization_util.update_to_uint4(
          target_var, scale, zp
      )
    # The original code also wrote these two entries inside the symmetric
    # branch above; those writes were always overwritten here (dead stores),
    # so they have been removed. Final dictionary contents are unchanged.
    output_tensors[action.target_name] = (target_var, pack)
    output_tensors[action.target_name + '_quantized_scale'] = (scale, False)
    if zp is not None:
      output_tensors[action.target_name + '_quantized_zp'] = (zp, False)
  return output_tensors
|
||||
|
||||
|
||||
def combined_weight_bins_to_tflite(
    model_type: str,
    backend: str,
    weight_path: str,
    output_tflite_file: str,
    vocab_model_file: str,
    lora_rank: Optional[int] = None,
    lora_weight_path: Optional[str] = None,
    lora_output_tflite_file: Optional[str] = None,
):
  """Combines weight files to tflite file."""
  # Validate the backend up front; only "cpu" and "gpu" are supported.
  if backend not in ('cpu', 'gpu'):
    raise ValueError('Unsupported backend: %s' % backend)
  if backend == 'cpu':
    if lora_rank is not None:
      logging.fatal('LoRA is not supported for CPU backend.')
    model_ckpt_util.GenerateCpuTfLite(
        model_type,
        weight_path,
        vocab_model_file,
        True,
        output_tflite_file,
    )
    return
  # GPU path; LoRA-related arguments are normalized to sentinel values when
  # absent.
  model_ckpt_util.GenerateGpuTfLite(
      model_type,
      weight_path,
      vocab_model_file,
      True,
      output_tflite_file,
      lora_rank if lora_rank is not None else 0,
      lora_weight_path if lora_weight_path is not None else '',
      lora_output_tflite_file if lora_output_tflite_file is not None else '',
  )
|
||||
|
||||
|
||||
def convert_bpe_vocab(vocab_model_file: str, output_dir: str) -> str:
  """Converts a Hugging Face BPE tokenizer directory to an spm.model file.

  Args:
    vocab_model_file: Path to a directory that contains both tokenizer.json
      and tokenizer_config.json files of the Hugging Face BPE tokenizer.
    output_dir: Directory where the converted vocab file is written.

  Returns:
    The path of the converted vocab file (`spm.model` inside output_dir).

  Raises:
    ValueError: If vocab_model_file is not a directory.
  """
  if not os.path.isdir(vocab_model_file):
    # Fixed typo in the user-facing message: "conatins" -> "contains".
    raise ValueError(
        'The input BPE vocab model file path is expected to be a directory that'
        ' contains both tokenizer.json and tokenizer_config.json files.'
    )
  output_vocab_file = os.path.join(output_dir, 'spm.model')
  model_ckpt_util.ConvertHfTokenizer(vocab_model_file, output_vocab_file)
  return output_vocab_file
|
||||
|
||||
|
||||
@contextlib.contextmanager
def filemanager(filename: str, mode: str):
  """Opens `filename` in `mode` and yields the open file object.

  The `with` statement already guarantees the file is closed on both normal
  exit and error, so the previous try/finally with an empty finally-clause
  was redundant and has been dropped.
  """
  with open(filename, mode) as f:
    yield f
|
||||
|
||||
|
||||
def sort_layer_info(layer_info_file: str) -> None:
  """Deduplicates and sorts the entries in the layer info file, in place."""
  with filemanager(layer_info_file, 'r') as src:
    # Strip whitespace, drop blank lines, and dedupe via a set.
    entries = sorted({ln.strip() for ln in src if ln.strip()})
  with filemanager(layer_info_file, 'w') as dst:
    for entry in entries:
      dst.write(entry + '\n')
    # Trailing blank line, matching the file format expected downstream.
    dst.write('\n')
|
||||
|
||||
|
||||
def maybe_quantize_and_write_tensors_to_bins(
    ckpt_loader: converter_base.CkptLoaderBase,
    config: ConversionConfig,
) -> None:
  """Quantizes the weight tensors according to the loader and writes them to bins."""
  # Each item yielded by the loader is a batch (list) of QuantizationActions.
  for action_batch in ckpt_loader.load_to_actions():
    # Quantize this batch of weights.
    quantized_tensors = quantize_by_actions(
        action_batch, config.backend, config.is_symmetric
    )
    del action_batch
    # Write the tensors into file(s). References are dropped eagerly after
    # each batch, presumably to keep peak memory low on large checkpoints.
    writer = converter_factory.create_writer(
        writer_type='weight_bins',
        output_dir=config.output_dir,
        backend=config.backend,
    )
    writer.write_variables(quantized_tensors)
    del quantized_tensors
    del writer
|
||||
|
||||
|
||||
def convert_checkpoint(config: ConversionConfig) -> None:
  """Converts the checkpoint to tflite file.

  End-to-end driver: (1) converts a Hugging Face BPE vocab directory to an
  spm.model when needed, (2) unless combine_file_only is set, loads and
  quantizes the checkpoint (and optional LoRA) weights into bin files under
  config.output_dir, (3) sorts layer_info.txt, and (4) combines the weight
  bins into the final tflite file.
  """
  logging.info('input folder: %s', config.input_ckpt)

  # A directory path means a Hugging Face BPE tokenizer that must be
  # converted first; a file path is used as-is (SentencePiece model).
  if os.path.isdir(config.vocab_model_file):
    vocab_model_path = convert_bpe_vocab(
        config.vocab_model_file, config.output_dir
    )
  else:
    vocab_model_path = config.vocab_model_file

  if not config.combine_file_only:
    # Load the layer weights and prepare the quantization configurations.
    loader = converter_factory.create_ckpt_loader(
        config.ckpt_format,
        ckpt_path=config.input_ckpt,
        is_symmetric=config.is_symmetric,
        backend=config.backend,
        attention_quant_bits=config.attention_quant_bits,
        feedforward_quant_bits=config.feedforward_quant_bits,
        embedding_quant_bits=config.embedding_quant_bits,
        special_model=config.model_type,
        # NOTE(review): ConversionConfig.__init__ currently always sets
        # fp16_scale to None, so this forwards None today.
        fp16_scale=config.fp16_scale,
    )
    maybe_quantize_and_write_tensors_to_bins(loader, config)

    if config.lora_ckpt is not None and config.lora_ckpt != config.input_ckpt:
      # If the LoRA ckpt and the input ckpt are the same file, the LoRA
      # conversion was already handled by the loader above.
      lora_loader = converter_factory.create_ckpt_loader(
          config.ckpt_format,
          ckpt_path=config.lora_ckpt,
          is_symmetric=config.is_symmetric,
          backend=config.backend,
          attention_quant_bits=None,
          feedforward_quant_bits=None,
          embedding_quant_bits=None,
          special_model=config.model_type,
      )
      maybe_quantize_and_write_tensors_to_bins(lora_loader, config)

  sort_layer_info(os.path.join(config.output_dir, 'layer_info.txt'))

  combined_weight_bins_to_tflite(
      config.model_type,
      config.backend,
      weight_path=config.output_dir,
      output_tflite_file=config.output_tflite_file,
      vocab_model_file=vocab_model_path,
      lora_rank=config.lora_rank,
      # LoRA bins are written into the same output directory as the base
      # weights.
      lora_weight_path=config.output_dir,
      lora_output_tflite_file=config.lora_output_tflite_file,
  )
|
||||
+318
@@ -0,0 +1,318 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""CkptLoader implementation for loading the Pytorch file."""
|
||||
|
||||
from typing import Iterator
|
||||
import enum
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mediapipe.tasks.python.genai.converter import converter_base
|
||||
|
||||
|
||||
class _PytorchReader:
|
||||
"""Pytorch reader."""
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
if not os.path.exists(model_path):
|
||||
raise ValueError(f"{model_path} does not exists.")
|
||||
self._model = torch.load(model_path, map_location=torch.device("cpu"))
|
||||
|
||||
def read_tensor_as_numpy(self, tensor_name) -> np.ndarray:
|
||||
tensor = (
|
||||
self._model[tensor_name]
|
||||
.to(torch.float32)
|
||||
.t()
|
||||
.contiguous()
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
return tensor
|
||||
|
||||
def get_tensor_names(self) -> List[str]:
|
||||
names = list(self._model.keys())
|
||||
return names
|
||||
|
||||
|
||||
class LayerType(enum.Enum):
  """Enum for layer type."""

  NONE = 0
  ATTENTION = 1  # Layer is part of the attention module.
  FEEDFORWARD = 2  # Layer is part of the feedforward module in the Transformer.
  EMBEDDING = 3  # Layer is the embedding lookup or final projection layer.
  LAYER_NORM = (
      4  # Layer is layer normalization before and after attention layer.
  )

  @classmethod
  def get_layer_type(cls, layer_name: str):
    """Gets the layer type of the given layer name."""
    # Substring markers, checked in priority order (attention before
    # feedforward before embedding before layer norm), mirroring the tensor
    # naming of the supported checkpoints.
    markers = (
        (cls.ATTENTION, ("self_attention",)),
        (cls.FEEDFORWARD, ("mlp",)),
        (cls.EMBEDDING, ("word_embeddings", "lm_head")),
        (
            cls.LAYER_NORM,
            ("input_layernorm", "post_attention_layernorm", "ln_f"),
        ),
    )
    for layer_type, substrings in markers:
      if any(marker in layer_name for marker in substrings):
        return layer_type
    return cls.NONE
|
||||
|
||||
|
||||
class FalconMapper(converter_base.LayerActionMapperBase):
  """LayerActionMapper for handling the Falcon-rw-1b model.

  Maps each checkpoint tensor name to one or more QuantizationActions,
  decomposing the fused query_key_value tensor into separate q/k/v tensors
  and renaming tensors to the target (params.lm.*) convention.
  """

  def __init__(
      self,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
      backend: str,
      reader: _PytorchReader,
  ):
    super().__init__(
        is_symmetric=is_symmetric,
        attention_quant_bits=attention_quant_bits,
        feedforward_quant_bits=feedforward_quant_bits,
        embedding_quant_bits=embedding_quant_bits,
        backend=backend,
    )
    self._reader = reader

  def map_to_actions(
      self, layer_name: str
  ) -> Optional[List[converter_base.QuantizationAction]]:
    """Maps the given layer name to quantization actions.

    A fused query_key_value tensor yields three actions (one each for the
    decomposed q, k and v tensors); every other tensor yields one action.

    Args:
      layer_name: Name of the tensor in the checkpoint.

    Returns:
      A list of quantization actions for the layer.
    """
    actions = []
    tensor_value = self._reader.read_tensor_as_numpy(layer_name)
    if "query_key_value" in layer_name:
      qkv_tensors = self._decompose_falcon_qkv(tensor_value)
      for tensor, name in zip(qkv_tensors, ["q", "k", "v"]):
        decomposed_name = layer_name.replace("query_key_value", name)
        actions.append(self._map_to_action_helper(tensor, decomposed_name))
    else:
      actions.append(self._map_to_action_helper(tensor_value, layer_name))
    return actions

  def _map_to_action_helper(
      self, tensor_value: np.ndarray, layer_name: str
  ) -> converter_base.QuantizationAction:
    """Builds the quantization action for a single (already decomposed) tensor."""
    quantize_axis = None
    quantize_bits = None
    layer_type = LayerType.get_layer_type(layer_name)

    # Only .weight tensors are quantized; biases and layer norms stay float.
    if layer_type != LayerType.LAYER_NORM and layer_name.endswith(".weight"):
      quantize_axis = [0]
      if layer_type == LayerType.FEEDFORWARD:
        quantize_bits = self._feedforward_quant_bits
      elif layer_type == LayerType.ATTENTION:
        quantize_bits = self._attention_quant_bits
        # The CPU backend expects the attention output projection transposed,
        # so the quantization axis moves with it.
        if self._backend == "cpu" and ".dense." in layer_name:
          tensor_value = np.transpose(tensor_value)
          quantize_axis = [1]
      elif layer_type == LayerType.EMBEDDING:
        quantize_bits = self._embedding_quant_bits
        if self._backend == "cpu" and "word_embeddings" in layer_name:
          tensor_value = np.transpose(tensor_value)
          quantize_axis = [1]
    target_name = self.update_target_name(layer_name)

    return converter_base.QuantizationAction(
        tensor_name=layer_name,
        tensor_value=tensor_value,
        target_name=target_name,
        quantize_axis=quantize_axis,
        quantize_bits=quantize_bits,
        pack_dim=0,
    )

  def _decompose_falcon_qkv(self, tensor_value: np.ndarray) -> List[np.ndarray]:
    """Decomposes combined qkv tensor used in falcon model into separate q, k and v tensors."""
    # Falcon-RW-1B constants: the fused tensor interleaves per-head q, k, v
    # chunks of `chunk_size` rows each, for a `hidden_size`-wide model.
    chunk_size = 64
    hidden_size = 2048

    tensor_value = tensor_value.transpose()

    # Weight input is 2-D (rows become (hidden, hidden)); bias input is 1-D.
    q_tensor = np.zeros(
        (hidden_size,)
        + ((hidden_size,) if len(tensor_value.shape) == 2 else ()),
        dtype=tensor_value.dtype,
    )
    k_tensor = np.zeros_like(q_tensor, dtype=tensor_value.dtype)
    v_tensor = np.zeros_like(k_tensor, dtype=tensor_value.dtype)

    # Chunk at offset 0 (mod 3) belongs to q, offset 1 to k, offset 2 to v.
    for offset, out_tensor in ((0, q_tensor), (1, k_tensor), (2, v_tensor)):
      j = 0
      for i in range(offset * chunk_size, hidden_size * 3, chunk_size * 3):
        out_tensor[j : j + chunk_size] = tensor_value[i : i + chunk_size]
        j += chunk_size

    return [
        np.ascontiguousarray(q_tensor.transpose()),
        np.ascontiguousarray(k_tensor.transpose()),
        np.ascontiguousarray(v_tensor.transpose()),
    ]

  def update_target_name(self, target_name: str) -> str:
    """Updates the target name to match the tensor name convention."""
    layer_type = LayerType.get_layer_type(target_name)

    target_name = target_name.replace(
        "transformer.h.", "params.lm.transformer.x_layers_"
    )

    if layer_type == LayerType.FEEDFORWARD:
      target_name = target_name.replace(".weight", ".linear.w")
      target_name = target_name.replace(".bias", ".bias.b")
      target_name = target_name.replace(
          "mlp.dense_h_to_4h", "ff_layer.ffn_layer1"
      )
      target_name = target_name.replace(
          "mlp.dense_4h_to_h", "ff_layer.ffn_layer2"
      )
    elif layer_type == LayerType.ATTENTION:
      target_name = target_name.replace("dense", "post")
      target_name = target_name.replace(".weight", ".linear.w")
      target_name = target_name.replace(".bias", ".bias.b")
    elif layer_type == LayerType.EMBEDDING:
      target_name = target_name.replace(
          "transformer.word_embeddings", "params.lm.token_embedding"
      )
      target_name = target_name.replace(
          "lm_head", "params.lm.softmax.logits_ffn"
      )
      target_name = target_name.replace(".weight", ".w")
    elif layer_type == LayerType.LAYER_NORM:
      target_name = target_name.replace("input_layernorm", "pre_layer_norm")
      target_name = target_name.replace(
          "pre_layer_norm.weight", "pre_layer_norm.scale"
      )
      # CPU and GPU backends use different names for the norm that follows
      # attention.
      if self._backend == "cpu":
        target_name = target_name.replace(
            "post_attention_layernorm", "ff_layer.pre_layer_norm"
        )
        target_name = target_name.replace(
            "ff_layer.pre_layer_norm.weight", "ff_layer.pre_layer_norm.scale"
        )
      else:
        target_name = target_name.replace(
            "post_attention_layernorm", "post_layer_norm"
        )
        target_name = target_name.replace(
            "post_layer_norm.weight", "post_layer_norm.scale"
        )
      target_name = target_name.replace(
          "transformer.ln_f.weight", "params.lm.final_ln.scale"
      )
      target_name = target_name.replace(
          "transformer.ln_f.bias", "params.lm.final_ln.bias"
      )

    return target_name
|
||||
|
||||
|
||||
class PytorchCkptLoader(converter_base.CkptLoaderBase):
  """CkptLoader implementation for loading the Pytorch model."""

  def __init__(
      self,
      ckpt_path: str,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
      special_model: str,
      backend: str,
  ):
    """Initializes the loader.

    Args:
      ckpt_path: The filepath to the PyTorch checkpoint file.
      is_symmetric: Whether to apply symmetric or asymmetric quantization.
      attention_quant_bits: An integer that specify the target quantization bits
        (support 8 or 4) for the attention layers.
      feedforward_quant_bits: An integer that specify the target quantization
        bits (support 8 or 4) for the feedforward layers in each Transformer
        blocks.
      embedding_quant_bits: An integer that specify the target quantization bits
        (support 8 or 4) for the embedding (and the final projection) layers.
      special_model: A string that indicates which input model is and whether
        any special treatment is needed.
      backend: A string indicating the backend used when converting this model.
        Valid options are "cpu" and "gpu".
    """
    super().__init__(
        ckpt_path,
        is_symmetric,
        attention_quant_bits,
        feedforward_quant_bits,
        embedding_quant_bits,
    )

    self._special_model = special_model
    self._reader = _PytorchReader(ckpt_path)
    # Only Falcon-RW-1B currently has a layer mapper; other model names are
    # rejected outright.
    if special_model in ["FALCON_RW_1B"]:
      self.mapper = FalconMapper(
          is_symmetric,
          attention_quant_bits,
          feedforward_quant_bits,
          embedding_quant_bits,
          backend,
          self._reader,
      )
    else:
      raise ValueError(f"Unknown special model: {special_model}")

  def load_to_actions(
      self,
  ) -> Iterator[List[converter_base.QuantizationAction]]:
    """Yields one list of quantization actions per tensor in the checkpoint.

    Tensors the mapper declines to map (returns None for) are skipped.
    """
    tensor_names = self._reader.get_tensor_names()
    for tensor_name in tensor_names:
      tensor_actions = self.mapper.map_to_actions(tensor_name)
      if tensor_actions is None:
        continue
      yield tensor_actions
|
||||
+86
@@ -0,0 +1,86 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for pytorch_converter."""
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from mediapipe.tasks.python.genai.converter import pytorch_converter
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
# Small Falcon-RW-1B checkpoint fixture exercised by the tests below.
_PYTORCH_FILE = test_utils.get_test_data_path(
    os.path.join(_TEST_DATA_DIR, 'falcon_rw_1b_test_weight.pt')
)
|
||||
|
||||
|
||||
class PytorchConverterTest(parameterized.TestCase):
  """Tests for the PyTorch (Falcon-RW-1B) checkpoint loader."""

  # Tensor names present in the Falcon-RW-1B test checkpoint.
  # NOTE(review): not referenced by the tests visible here — confirm whether
  # it is used elsewhere before removing.
  VARIABLE_NAMES = [
      'transformer.word_embeddings.weight',
      'transformer.h.0.input_layernorm.weight',
      'transformer.h.0.input_layernorm.bias',
      'transformer.h.0.self_attention.query_key_value.weight',
      'transformer.h.0.self_attention.query_key_value.bias',
      'transformer.h.0.self_attention.dense.weight',
      'transformer.h.0.self_attention.dense.bias',
      'transformer.h.0.post_attention_layernorm.weight',
      'transformer.h.0.post_attention_layernorm.bias',
      'transformer.h.0.mlp.dense_h_to_4h.weight',
      'transformer.h.0.mlp.dense_h_to_4h.bias',
      'transformer.h.0.mlp.dense_4h_to_h.weight',
      'transformer.h.0.mlp.dense_4h_to_h.bias',
      'transformer.ln_f.weight',
      'transformer.ln_f.bias',
      'lm_head.weight',
  ]

  def test_init(self):
    """Checks that constructor arguments are stored on the loader."""
    loader = pytorch_converter.PytorchCkptLoader(
        ckpt_path=_PYTORCH_FILE,
        is_symmetric=True,
        attention_quant_bits=8,
        feedforward_quant_bits=8,
        embedding_quant_bits=8,
        special_model='FALCON_RW_1B',
        backend='cpu',
    )
    self.assertEqual(loader._ckpt_path, _PYTORCH_FILE)
    self.assertEqual(loader._is_symmetric, True)
    self.assertEqual(loader._attention_quant_bits, 8)
    self.assertEqual(loader._feedforward_quant_bits, 8)

  @parameterized.product(
      quant_bits=(4, 8),
  )
  def test_load_to_actions(self, quant_bits):
    """Checks the loader yields the expected total number of actions."""
    loader = pytorch_converter.PytorchCkptLoader(
        ckpt_path=_PYTORCH_FILE,
        is_symmetric=True,
        attention_quant_bits=8,
        feedforward_quant_bits=quant_bits,
        embedding_quant_bits=8,
        special_model='FALCON_RW_1B',
        backend='cpu',
    )
    actions = loader.load_to_actions()
    # There are 16 layers in the model, but qkv weight and bias would be
    # decomposed to q, k, v tensors, so there would be 20 quantization actions.
    self.assertEqual(sum(len(action) for action in actions), 20)
|
||||
|
||||
|
||||
# Allows running this test file directly (outside the test runner).
if __name__ == '__main__':
  absltest.main()
|
||||
+516
@@ -0,0 +1,516 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for quantizing tensors.
|
||||
|
||||
Note that this is a reduced fork version of the praxis libraries to provide a
|
||||
self-contained library for packaging.
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
JTensor = jax.Array
|
||||
_UINT4_ZP = 8 # Default zero point for unsigned 4-bit.
|
||||
|
||||
|
||||
def _get_scan_range() -> np.ndarray:
|
||||
# Produce candidate scan values.
|
||||
return np.linspace(1.0, 0.5, num=11)
|
||||
|
||||
|
||||
def _get_mean_error(bound, t, min_value, max_value, p_value):
|
||||
scale = bound / max_value
|
||||
candidate = jnp.divide(t, scale)
|
||||
candidate = jnp.clip(jnp.round(candidate), min_value, max_value)
|
||||
candidate = jnp.multiply(candidate, scale)
|
||||
pmean_error = jnp.mean(jnp.abs(jnp.subtract(candidate, t)) ** p_value)
|
||||
return pmean_error
|
||||
|
||||
|
||||
def _get_best_bound_per_tensor(
    t: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float = 1.0,
) -> JTensor:
  """Scan around [0.5, 1] * hard max value to get bound value for whole tensor.

  This does a scan to get bound value(s) that minimize mean absolute error (MAE)
  between original tensor 't' and quantized tensor. It's (almost) equivalent to
  maximizing entropy.

  Args:
    t: The input float tensor.
    bound: The hard max value for tensor 't'. It has the same length as shape.
    min_value: Minimal value for the quantization bound.
    max_value: Maximal value for the quantization bound.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.

  Returns:
    The best bound values for 't', that minimize p-mean error.
  """

  def _quant(scaling_factors):
    # Error of quantizing the whole tensor when the hard bound is shrunk by
    # this candidate factor.
    return _get_mean_error(
        bound * scaling_factors, t, min_value, max_value, p_value
    )

  # Evaluate all candidate factors in one vectorized pass and keep the one
  # producing the smallest p-mean error.
  scaling_factors = _get_scan_range()
  diffs = jax.vmap(_quant)(scaling_factors)
  best_scaling = scaling_factors[jnp.argmin(diffs)].astype(bound.dtype)
  return bound * best_scaling
|
||||
|
||||
|
||||
def _quantrow(
    vec: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float,
    factors: np.ndarray,
) -> JTensor:
  """Get best rescaling factor from a list of factors applied a channel.

  Args:
    vec: The vector in a channel.
    bound: The hard bound (max(abs(vec))) of the vector.
    min_value: The target min value.
    max_value: The target max value.
    p_value: Exponent of the p-mean error metric.
    factors: The values to be applied on top of bound.

  Returns:
    adjusted bound value out of the list of factors applied to bound.
  """

  def _quant(bounds):
    # p-mean quantization error of this channel for one candidate bound.
    return _get_mean_error(bounds, vec, min_value, max_value, p_value)

  # Evaluate every candidate bound at once; keep the factor with least error.
  diffs = jax.vmap(_quant)(bound * factors)
  best_scaling = factors[jnp.argmin(diffs)]
  return bound * best_scaling
|
||||
|
||||
|
||||
def _get_best_bound_per_channel(
    t: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float = 1.0,
) -> JTensor:
  """Scan around [0.5, 1] * hard max value to get bound value for each channel.

  This does a scan to get bound value(s) that minimize mean absolute error (MAE)
  between original tensor 't' and quantized tensor. It's (almost) equivalent to
  maximizing entropy.

  Args:
    t: The input float tensor. Must be rank-2 with channels on the last axis.
    bound: The hard max value for tensor 't', shaped [1, channels].
    min_value: Minimal value for the quantization bound.
    max_value: Maximal value for the quantization bound.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.

  Returns:
    The best bound values for 't', that minimize p-mean error.
  """
  assert len(t.shape) == 2
  assert len(bound.shape) == 2
  assert t.shape[1] == bound.shape[1]
  assert bound.shape[0] == 1
  scans = _get_scan_range()

  def _quant(tensor, bound, min_value, max_value, p_value, factors):
    # Plain Python loop over channels, each getting its own best bound.
    # NOTE(review): this runs outside jit — presumably fine for an offline
    # conversion path, but confirm before reusing on a hot path.
    ret = np.zeros(bound.shape)
    for i in range(len(tensor)):
      best = _quantrow(
          tensor[i], bound[i], min_value, max_value, p_value, factors
      )
      ret[i] = best
    return ret

  # Channels are on the last axis of `t`; transpose so each row is a channel.
  t = t.transpose()
  t_split = list(t)
  res = _quant(t_split, bound[0, :], min_value, max_value, p_value, scans)
  res = res.reshape(bound.shape)
  return res
|
||||
|
||||
|
||||
def get_best_bound(
    t: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float = 1.0,
    per_channel: bool = False,
) -> JTensor:
  """Scans multiple factors on the max value to get the best bound value.

  Scans to find bound value(s) that minimize the p-mean error (MAE when
  p_value is 1.0) between the original tensor 't' and its quantized form;
  (almost) equivalent to maximizing entropy.

  Args:
    t: The input float tensor.
    bound: The hard max value for tensor 't'. It has the same length as shape.
    min_value: Minimal value for the quantization bound.
    max_value: Maximal value for the quantization bound.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
    per_channel: if get best bound for entire tensor or per channel.

  Returns:
    The best bound values for 't', that minimize p-mean error.
  """
  # Dispatch to the per-channel or whole-tensor search.
  search = (
      _get_best_bound_per_channel if per_channel else _get_best_bound_per_tensor
  )
  return search(t, bound, min_value, max_value, p_value)
|
||||
|
||||
|
||||
def get_min_max(
    bits: int = 8,
    unsigned: bool = False,
    use_fp: bool = False,
) -> Tuple[float, float]:
  """Gets the min/max range for a given number of bits.

  Args:
    bits: Target number of bits for quantization.
    unsigned: If True compute min and max for unsigned number, else for signed.
    use_fp: in floating point.

  Returns:
    min/max values for the provide number of bits.
  """
  if use_fp:
    # TODO: support other fp types.
    return -448.0, 448.0
  # Computed arithmetically (instead of via jax.iinfo) so that bit widths
  # other than 4 and 8 also work.
  if unsigned:
    # e.g. 8 unsigned bits -> [0, 255].
    return 0, 2**bits - 1
  # e.g. 8 signed bits -> [-128, 127].
  half_range = 2 ** (bits - 1)
  return -half_range, half_range - 1
|
||||
|
||||
|
||||
def pass_through(x: JTensor, fn: Any) -> JTensor:
  """Applies `fn` in the forward pass while keeping an identity gradient."""
  # x - stop_gradient(x) is exactly zero by the Sterbenz lemma, yet carries a
  # gradient of exactly one; adding stop_gradient(fn(x)) makes the forward
  # value equal fn(x).
  zero_with_unit_grad = x - jax.lax.stop_gradient(x)
  return zero_with_unit_grad + jax.lax.stop_gradient(fn(x))
|
||||
|
||||
|
||||
def reduce_precision(
    t: JTensor,
    contract_dims: Optional[Sequence[int]],
    need_gradient: bool = False,
    bits: int = 8,
    optimization_on_bound: bool = False,
    p_value: float = 1.0,
    percentile: float = 1.0,
    use_symmetric: bool = True,
    use_fp: bool = False,
    add_scale_eps: bool = False,
    per_channel: bool = False,
    random_rounding: bool = False,
    key: Optional[jax.Array] = None,
) -> Tuple[JTensor, JTensor, Optional[JTensor]]:
  """Reduce the precision of a tensor.

  Generic for all tensors.

  Args:
    t: Input tensor.
    contract_dims: Specifies contracting dimensions of the input tensor.
    need_gradient: If gradient is needed out of this function.
    bits: Target number of bits.
    optimization_on_bound: If MAE bound optimizer is used.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
    percentile: Percentile Factor to apply on the min/max range. Setting this to
      other than 1.0 disables optimization_on_bound.
    use_symmetric: If the input tensor is quantized symmetrically.
    use_fp: Use floating point.
    add_scale_eps: Add eps value or replace zero value by 1 to avoid division by
      zero.
    per_channel: use per-channel clipping optimization.
    random_rounding: round with uniform random.
    key: rng key for rounding.

  Returns:
    A tuple of quantized tensor, quantization scale
    and quantization zero point (optional).
  """
  min_value, max_value = get_min_max(bits, use_fp=use_fp)

  # Symmetric: bound is the max magnitude; asymmetric: bound is the value
  # range, and a zero point is produced further down.
  if use_symmetric:
    bound = jnp.max(jnp.abs(t), axis=contract_dims, keepdims=True)
    scale_bound = max_value
  else:
    t_max = jnp.max(t, axis=contract_dims, keepdims=True)
    t_min = jnp.min(t, axis=contract_dims, keepdims=True)
    bound = t_max - t_min
    scale_bound = max_value - min_value

  # Percentile clipping takes precedence over the bound-search optimizer.
  if percentile < 1.0:
    bound = jnp.multiply(bound, percentile)
  elif optimization_on_bound:
    bound = get_best_bound(
        t, bound, min_value, max_value, p_value, per_channel=per_channel
    )

  scale = bound / scale_bound

  if add_scale_eps:
    # Add epsilon to avoid divide-by-zero.
    scale = scale + jnp.finfo(t.dtype).eps
  else:
    scale = jnp.where(scale == 0.0, 1.0, scale)

  if use_symmetric:
    zp = None
    t = jnp.divide(t, scale)
  else:
    # zp is first computed in quantized units, then rescaled back to the
    # original value domain before being returned.
    zp = min_value - t_min / scale
    t = jnp.divide(t, scale) + zp
    zp = jnp.multiply(scale, zp)

  if use_fp:
    # No need to round.
    t = jnp.clip(t, min_value, max_value).astype(jnp.float8_e4m3fn)
    # TODO: refactor to remove this logic.
    t = jax.lax.bitcast_convert_type(t, new_dtype=jnp.int8)
  else:
    if need_gradient:
      # Straight-through estimator: rounds forward, identity gradient back.
      t = pass_through(t, jnp.round)
      t = jnp.clip(t, min_value, max_value)
    else:
      if random_rounding:
        t = t + jax.random.uniform(
            key=key, shape=t.shape, minval=-0.5, maxval=0.5
        )
      t = jnp.round(t)
      # Smallest integer container that holds the requested bit width.
      container_dtype = (
          jnp.int8 if bits <= 8 else jnp.int16 if bits <= 16 else jnp.int32
      )
      t = jnp.clip(t, min_value, max_value).astype(container_dtype)

  return t, scale, zp
|
||||
|
||||
|
||||
def quantize_tensor(
    var: np.ndarray,
    axis: List[int],
    factor: float = 1.0,
    sym: bool = True,
    number_bits: int = 8,
    use_fp: bool = False,
    add_scale_eps: bool = False,
    optimization_on_bound: bool = False,
    p_value: float = 1.0,
    per_channel: bool = False,
    block_size: int = 0,
) -> Union[
    Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]
]:
  """Quantize a tensor.

  Args:
    var: The variable to be quantized.
    axis: The axis along which variable will be quantized. NOTE: when
      `block_size` > 0 this list is modified in place (the axis index is
      incremented by one) so the caller observes the shifted quantize axis.
    factor: The clipping factor.
    sym: Symmetric or asymmetric quantize the variable.
    number_bits: Number of bits for quantized value.
    use_fp: do fp with number of bits (i.e. fp8)
    add_scale_eps: add epsilon to scale to avoid division by zero, else it will
      replace zero scale by 1.
    optimization_on_bound: If p-mean bound optimizer is used.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
    per_channel: use per-channel clipping optimization.
    block_size: block size for sub-channel quantization. Defaults to 0, which
      means off.

  Returns:
    Quantized tensors, along with scales and zero point.
  """
  # TODO: support jnp.float8_e5m2
  assert number_bits == 8 or number_bits == 4
  jnp_var = jnp.asarray(var)
  # When using sub-channel, the contracting dim is split into a sub-channel
  # dim followed by the block dim. Therefore the contracting dim
  # (quantize_axis) should increment by one, and the corresponding pack_dim
  # should also increment by one.
  if block_size > 0:
    shape = list(jnp_var.shape)
    assert len(axis) == 1, 'Only support 1D sub-channel quantization'
    sub_channels, rem = divmod(shape[axis[0]], block_size)
    assert rem == 0
    shape.insert(axis[0], sub_channels)
    # NOTE(review): this mutates the caller's `axis` list — it looks
    # deliberate (the comment above says pack_dim must shift too), but confirm
    # that callers depend on it before changing to a local copy.
    axis[0] += 1
    shape[axis[0]] = block_size
    jnp_var = jnp.reshape(jnp_var, shape)

  qvar, scale, zp = reduce_precision(
      jnp_var,
      contract_dims=axis,
      need_gradient=False,
      bits=number_bits,
      optimization_on_bound=optimization_on_bound,
      percentile=factor,
      use_symmetric=sym,
      use_fp=use_fp,
      add_scale_eps=add_scale_eps,
      p_value=p_value,
      per_channel=per_channel,
  )
  if sym:
    return np.array(qvar), np.array(jnp.squeeze(scale, axis=axis))  # pytype: disable=wrong-arg-types  # jnp-type
  else:
    return (
        np.array(qvar),
        # CAVEAT: the following squeezes should squeeze along the quantization
        # axis only.
        np.array(jnp.squeeze(scale)),
        np.array(jnp.squeeze(zp)),
    )
|
||||
|
||||
|
||||
def pack_4bit(
    x: np.ndarray, pack_dim: int, packed_dtype: jnp.dtype = jnp.int32
) -> np.ndarray:
  """Pack int8 or uint8 tensor where its values are actually int4 or uint4, to int32 or int8 nibble format along pack_dim.

  Args:
    x: Original int8 or uint8 tensor to pack.
    pack_dim: Dimension to pack along. x.shape[pack_dim] must be divisible by 8,
      when packed_dtype is int32 and divisible by 2 when target_type is int8.
      Also pack_dim must be < x.ndim - 1.
    packed_dtype: Target type to pack to, int32 or int8.

  Returns:
    int32 or int8 packed tensor where the pack_dim size is divided by 8
    from the original tensor x.
  """
  x = jnp.asarray(x)
  if packed_dtype == jnp.int8 and x.dtype == jnp.uint8:
    # It doesn't make sense to pack uint8 numbers into int4 as we'll
    # the range overlap between uint8 and int4 is [0..7].
    raise ValueError(
        'only int8 input dtype is supported when packing into int8. '
        f'Given {x.dtype}'
    )

  if x.dtype != jnp.int8 and x.dtype != jnp.uint8:
    raise ValueError(
        f'input dtype must be either int8 or uint8. Given {x.dtype}'
    )
  if pack_dim >= x.ndim - 1:
    raise ValueError(
        f'pack_dim must be < input ndim - 1. input shape {x.shape} and pack_dim'
        f' {pack_dim}'
    )
  if packed_dtype != jnp.int32 and packed_dtype != jnp.int8:
    raise ValueError(
        f'packed_dtype must be either int32 or int8. Given {packed_dtype}'
    )
  if packed_dtype == jnp.int32 and x.shape[pack_dim] % 8 != 0:
    raise ValueError(
        'input shape[pack_dim] must be divisible by 8 when target_type '
        f'is int32. Given shape {x.shape}'
    )
  if packed_dtype == jnp.int8 and x.shape[pack_dim] % 2 != 0:
    raise ValueError(
        'input shape[pack_dim] must be divisible by 2 when target_type '
        f'is int8. Given shape {x.shape}'
    )

  # 8 nibbles fit in an int32 word, 2 in an int8 byte.
  int4s_per_packed_type = 8 if packed_dtype == jnp.int32 else 2

  # Split pack_dim into (packed words, nibbles-per-word).
  rep_shape = list(x.shape)
  rep_shape.insert(pack_dim + 1, int4s_per_packed_type)
  rep_shape[pack_dim] //= int4s_per_packed_type

  # Bit offset of each nibble within its packed word: nibble index * 4
  # (the << 2 multiplies the iota by 4).
  shifts = lax.broadcasted_iota(packed_dtype, rep_shape, pack_dim + 1)
  shifts <<= 2

  # Promote x to packed_dtype
  x = x & jnp.array(0x0F, packed_dtype)  # mask to the low nibble
  x = lax.reshape(x, rep_shape)
  x = x << shifts
  # Sum the shifted nibbles into one packed word per group.
  x = lax.reduce(x, jnp.array(0x0, packed_dtype), lax.add, [pack_dim + 1])
  return np.asarray(x)
|
||||
|
||||
|
||||
def update_to_uint4(
    qx: np.ndarray, scale: np.ndarray, zp: Optional[np.ndarray] = None
):
  """Updates the quantized weights from int4 to uint4.

  This is a conversion function designed for XNNPack as it expects the 4-bit
  quantized weight to be represented differently from the original Pax setting.
  Specifically, the differences are:
    1) The dynamic range of weight values: int4 (Pax) vs. uint4 (XNNPack).
    2) The dynamic range of zero-point: float (Pax) vs. uint4 (XNNPack).
    3) The number of zero-point: per-channel (Pax) vs. per-tensor (XNNPack).

  Args:
    qx: np.array of shape [..., channel], which is the quantized weight values
      from Pax in the shape of. The values are in the dynamic range of int4 but
      are hosted as int8 type. Note that if the first dimension is 3, it means
      the qkv matrices are concatenated together and should be treated
      differently.
    scale: np.array of shape [1(3), channel] as np.float type, which are the
      scaling factors for dequantization per channel.
    zp: (optional) np.array of shape [1 (or 3), channel] as np.float type, which
      are the zero points for dequantization per channel.

  Returns:
    A tuple (qx, scale, zp):
      qx: The updated np.array of shape [..., channel] as np.int8 type with
        updated dynamic range as uint4 (with 8 as the default zero points).
      scale: Same as the input scale (cast to np.float32).
      zp: (optional) np.array of shape [1 (or 3)] as np.int32 type with the
        updated zero point values in the dynamic range as uint4.
  """
  if qx.dtype != np.int8 or ('float' not in str(scale.dtype)):
    raise ValueError(
        'Unexpected dtype qx:' + str(qx.dtype) + ' scale:' + str(scale.dtype)
    )

  scale = scale.astype(np.float32)

  def get_new_zp(old_zp):
    # Convert the float zero point back to quantized units, average it into a
    # single per-tensor value, and re-center from int4 to uint4.
    new_zp = old_zp / (scale + np.finfo(np.float32).eps)
    per_tensor_zp = np.mean(new_zp)
    per_tensor_zp = per_tensor_zp.astype(np.int8) + _UINT4_ZP
    return per_tensor_zp

  if zp is not None:
    # Leading dim 3 means fused q/k/v: one per-tensor zero point for each.
    if qx.shape[0] == 3:
      per_tensor_zp = np.stack([get_new_zp(szp) for szp in zp], axis=0)
    else:
      per_tensor_zp = get_new_zp(zp)
  else:
    per_tensor_zp = (
        _UINT4_ZP * np.ones(shape=(3)) if qx.shape[0] == 3 else _UINT4_ZP
    )

  # Shift the weight values from the int4 range into the uint4 range.
  qx = qx + _UINT4_ZP
  return qx, scale, np.array(per_tensor_zp, dtype=np.int32)
|
||||
+259
@@ -0,0 +1,259 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for quantization_util."""
|
||||
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.tasks.python.genai.converter import quantization_util
|
||||
|
||||
|
||||
_dtype = lambda x: getattr(x, 'dtype', None) or np.asarray(x).dtype
|
||||
|
||||
|
||||
class TestCase(absltest.TestCase):
  """Base test case with numpy/JAX-aware comparison helpers."""

  def assertAllClose(
      self, x, y, check_dtypes=True, rtol=1e-5, atol=1e-5, **kwargs
  ):
    """Wrapper for np.testing.assert_allclose()."""
    x, y = np.asarray(x), np.asarray(y)
    if check_dtypes:
      self.assertDtypesMatch(x, y)
    # np.testing cannot compare bfloat16 arrays directly; widen to float32.
    if x.dtype == jnp.bfloat16:
      x = x.astype(np.float32)
    if y.dtype == jnp.bfloat16:
      y = y.astype(np.float32)
    np.testing.assert_allclose(x, y, rtol=rtol, atol=atol, **kwargs)

  def assertDtypesMatch(self, x, y):
    """Asserts x and y share the same canonical JAX dtype."""
    canonical = jax.dtypes.canonicalize_dtype
    self.assertEqual(canonical(_dtype(x)), canonical(_dtype(y)))
|
||||
|
||||
|
||||
class Quantize8BTest(TestCase):
  """Tests for 8-bit integer quantization via quantize_tensor."""

  def test_quantize_symmetric(self):
    """Symmetric 8-bit quantization yields int8 values and per-row scales."""
    inputs = np.array([[1.2, 3.1, 5.5, 2.9], [0.2, -1.5, 3.3, 4.0]])
    qx, scale = quantization_util.quantize_tensor(inputs, axis=[1])

    self.assertAllClose(
        qx, np.array([[28, 72, 127, 67], [6, -48, 105, 127]], dtype=np.int8)
    )
    self.assertAllClose(
        scale, np.array([0.04330709, 0.03149606], dtype=np.float32)
    )

  def test_quantize_symmetric_with_dimension_size_one_unquantized(self):
    """Size-1 middle dimension is preserved in the scale's shape."""
    # inputs shape: (2, 1, 4), quantization axis 2.
    inputs = np.array([[[1.2, 3.1, 5.5, 2.9]], [[0.2, -1.5, 3.3, 4.0]]])
    qx, scale = quantization_util.quantize_tensor(inputs, axis=[2])

    self.assertAllClose(
        qx, np.array([[[28, 72, 127, 67]], [[6, -48, 105, 127]]], dtype=np.int8)
    )
    # expected scale shape: (2, 1)
    self.assertAllClose(
        scale, np.array([[0.04330709], [0.03149606]], dtype=np.float32)
    )

  def test_quantize_asymmetric(self):
    """Asymmetric quantization additionally returns float zero points."""
    inputs = np.array([[1.2, 3.1, 5.5, 2.9], [0.2, -1.5, 3.3, 4.0]])
    qx, scale, zp = quantization_util.quantize_tensor(
        inputs, axis=[1], sym=False
    )

    self.assertAllClose(
        qx,
        np.array([[-128, -15, 127, -27], [-49, -128, 95, 127]], dtype=np.int8),
    )
    self.assertAllClose(scale, np.array([0.016863, 0.021569], dtype=np.float32))
    self.assertAllClose(zp, np.array([-3.358431, -1.260784], dtype=np.float32))
|
||||
|
||||
|
||||
class Quantize8BFPTest(TestCase):
  """Tests for 8-bit quantization with use_fp=True (floating-point encoding)."""

  def test_quantize_symmetric(self):
    """Symmetric fp quantization: int8 codes plus per-row scales."""
    inputs = np.array([[1.0, 2.0, 5.5, 2.9], [0.02, -0.01, 3.3, 4.0]])
    qx, scale = quantization_util.quantize_tensor(inputs, axis=[1], use_fp=True)

    self.assertAllClose(
        qx,
        np.array([[106, 114, 126, 119], [65, -71, 124, 126]], dtype=np.int8),
    )
    self.assertAllClose(
        scale, np.array([0.01227679, 0.00892857], dtype=np.float32)
    )

  def test_quantize_symmetric_with_dimension_size_one_unquantized(self):
    """Size-1 middle dimension is preserved in the scale's shape."""
    # inputs shape: (2, 1, 4), quantization axis 2.
    inputs = np.array([[[1.0, 2.0, 5.5, 2.9]], [[0.02, -0.01, 3.3, 4.0]]])
    qx, scale = quantization_util.quantize_tensor(inputs, axis=[2], use_fp=True)

    self.assertAllClose(
        qx,
        np.array(
            [[[106, 114, 126, 119]], [[65, -71, 124, 126]]], dtype=np.int8
        ),
    )
    # expected scale shape: (2, 1)
    self.assertAllClose(
        scale, np.array([[0.01227679], [0.00892857]], dtype=np.float32)
    )

  def test_quantize_asymmetric(self):
    """Asymmetric fp quantization returns codes, scales and zero points."""
    inputs = np.array([[-1.0, -2.0, -2.01, 2.9], [0.02, -0.15, 3.3, 4.0]])
    qx, scale, zp = quantization_util.quantize_tensor(
        inputs, axis=[1], sym=False, use_fp=True
    )

    self.assertAllClose(
        qx,
        np.array([[-8, -2, -2, 126], [-3, -2, 121, 126]], dtype=np.int8),
    )
    self.assertAllClose(
        scale, np.array([0.00547991, 0.0046317], dtype=np.float32)
    )
    self.assertAllClose(
        zp, np.array([-0.4449999, -1.9250002], dtype=np.float32)
    )

  def test_quantize_add_scale_eps(self):
    """add_scale_eps replaces zero scales with float32 eps instead of 1.0."""
    inputs = np.array([[0.0, 0.0, 0.0, 0.0], [-4.0, -4.0, -4.0, -4.0]])
    _, scale, _ = quantization_util.quantize_tensor(
        inputs, axis=[1], sym=False, use_fp=True, add_scale_eps=True
    )
    self.assertAllClose(
        scale, np.array([np.finfo(np.float32).eps, np.finfo(np.float32).eps])
    )
    _, scale, _ = quantization_util.quantize_tensor(
        inputs, axis=[1], sym=False, use_fp=True, add_scale_eps=False
    )
    self.assertAllClose(scale, np.array([1.0, 1.0]))
|
||||
|
||||
|
||||
class Quantize4BTest(TestCase):
  """Tests for 4-bit integer quantization (number_bits=4)."""

  def test_quantize_symmetric(self):
    """Symmetric 4-bit codes stay within [-8, 7], hosted in int8."""
    inputs = np.array([[1.2, 3.1, 5.5, 2.9], [0.2, -1.5, 3.3, 4.0]])
    qx, scale = quantization_util.quantize_tensor(
        inputs, axis=[1], number_bits=4
    )
    self.assertAllClose(
        qx, np.array([[2, 4, 7, 4], [0, -3, 6, 7]], dtype=np.int8)
    )
    self.assertAllClose(
        scale, np.array([0.78571427, 0.5714286], dtype=np.float32)
    )

  def test_quantize_symmetric_with_dimension_size_one_unquantized(self):
    """Size-1 middle dimension is preserved in the scale's shape."""
    # inputs shape: (2, 1, 4), quantization axis 2.
    inputs = np.array([[[1.2, 3.1, 5.5, 2.9]], [[0.2, -1.5, 3.3, 4.0]]])
    qx, scale = quantization_util.quantize_tensor(
        inputs, axis=[2], number_bits=4
    )

    self.assertAllClose(
        qx, np.array([[[2, 4, 7, 4]], [[0, -3, 6, 7]]], dtype=np.int8)
    )
    # expected scale shape: (2, 1)
    self.assertAllClose(
        scale, np.array([[0.78571427], [0.5714286]], dtype=np.float32)
    )

  def test_quantize_asymmetric(self):
    """Asymmetric 4-bit quantization returns codes, scales and zero points."""
    inputs = np.array([[1.2, 3.1, 5.5, 2.9], [0.2, -1.5, 3.3, 4.0]])
    qx, scale, zp = quantization_util.quantize_tensor(
        inputs, axis=[1], sym=False, number_bits=4
    )

    self.assertAllClose(
        qx,
        np.array([[-8, -1, 7, -2], [-3, -8, 5, 7]], dtype=np.int8),
    )
    self.assertAllClose(
        scale, np.array([0.2866667, 0.36666667], dtype=np.float32)
    )
    self.assertAllClose(
        zp, np.array([-3.4933336, -1.4333334], dtype=np.float32)
    )
|
||||
|
||||
|
||||
class QuantizationUtilTest(TestCase):
  """Tests for update_to_uint4: the int4 -> uint4 representation change."""

  def test_update_to_uint4_sym(self):
    """Dequantized values are identical before and after the uint4 update."""
    inputs = np.array([[1.2, 3.1, -5.5, 2.9], [0.2, -1.5, 3.3, 4.0]])
    qx, scale = quantization_util.quantize_tensor(
        inputs, axis=[1], sym=True, number_bits=4
    )
    dequant_from_int4 = qx * np.expand_dims(scale, -1)
    qx_n, scale_n, zp_n = quantization_util.update_to_uint4(qx, scale)
    self.assertEmpty(zp_n.shape)  # A scalar numpy array.
    dequant_from_uint4 = np.expand_dims(scale_n, -1) * (qx_n - zp_n)
    np.testing.assert_allclose(dequant_from_int4, dequant_from_uint4)

  def test_update_to_uint4_sym_combined(self):
    """A leading dimension of 3 (fused QKV) yields one zero point per matrix."""
    inputs = np.array(
        [[-1.2, 3.5, -6.2, 1.7], [1.2, 3.1, -5.5, 2.9], [0.2, -1.5, 3.3, 4.0]]
    )
    qx, scale = quantization_util.quantize_tensor(
        inputs, axis=[1], sym=True, number_bits=4
    )
    dequant_from_int4 = qx * np.expand_dims(scale, -1)
    qx_n, scale_n, zp_n = quantization_util.update_to_uint4(qx, scale)
    self.assertEqual(zp_n.shape[0], 3)
    dequant_from_uint4 = np.expand_dims(scale_n, -1) * (
        qx_n - np.expand_dims(zp_n, -1)
    )
    np.testing.assert_allclose(dequant_from_int4, dequant_from_uint4)

  def test_update_to_uint4_asym(self):
    """Asymmetric case: per-channel float zps collapse to one per-tensor zp."""
    inputs = np.array([[1.0, 8.0, -3.0, 2.0], [-3.0, 2.0, 1.0, 8.0]])
    qx, scale, zp = quantization_util.quantize_tensor(
        inputs, axis=[1], sym=False, number_bits=4
    )
    qx_n, scale_n, zp_n = quantization_util.update_to_uint4(qx, scale, zp)
    expected_dequant = np.array([
        [0.0, 7.333333, -3.666667, 1.466667],
        [-3.666667, 1.466667, 0.0, 7.333333],
    ])
    dequant_from_uint4 = np.expand_dims(scale_n, -1) * (qx_n - zp_n)
    np.testing.assert_allclose(dequant_from_uint4, expected_dequant, rtol=1e-05)

  def test_update_to_uint4_asym_combined(self):
    """Asymmetric fused-QKV case: one per-tensor zp per sub-matrix."""
    inputs = np.array(
        [[1.0, 8.0, -3.0, 2.0], [-3.0, 2.0, 1.0, 8.0], [2.0, 1.0, 8.0, -3.0]]
    )
    qx, scale, zp = quantization_util.quantize_tensor(
        inputs, axis=[1], sym=False, number_bits=4
    )
    qx_n, scale_n, zp_n = quantization_util.update_to_uint4(qx, scale, zp)
    self.assertEqual(zp_n.shape[0], 3)
    expected_dequant = np.array([
        [0.0, 7.333333, -3.666667, 1.466667],
        [-3.666667, 1.466667, 0.0, 7.333333],
        [1.466667, 0.0, 7.333333, -3.666667],
    ])
    dequant_from_uint4 = np.expand_dims(scale_n, -1) * (
        qx_n - np.expand_dims(zp_n, -1)
    )
    np.testing.assert_allclose(dequant_from_uint4, expected_dequant, rtol=1e-05)
|
||||
|
||||
|
||||
if __name__ == '__main__':
  # Discover and run all absltest test cases defined in this module.
  absltest.main()
|
||||
+554
@@ -0,0 +1,554 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""CkptLoader implementation for loading the Safetensors."""
|
||||
|
||||
import array
|
||||
from typing import Iterator
|
||||
import enum
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mediapipe.tasks.python.genai.converter import converter_base
|
||||
|
||||
|
||||
DTYPE_MAP = {
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"F32": torch.float32,
|
||||
}
|
||||
|
||||
|
||||
class _SafetensorsShardReader:
|
||||
"""Reads a single safetensors shard."""
|
||||
|
||||
_HEAD_BYTES = 8
|
||||
|
||||
def __init__(self, shard_path: str):
|
||||
self._shard_path = shard_path
|
||||
if not os.path.exists(self._shard_path):
|
||||
raise ValueError(f"{self._shard_path} does not exists.")
|
||||
with open(self._shard_path, "rb") as f:
|
||||
head_bytes = f.read(self._HEAD_BYTES)
|
||||
metadata_bytes_num = np.frombuffer(head_bytes, dtype=np.uint64)[0]
|
||||
metadata_bytes = f.read(metadata_bytes_num)
|
||||
self.layers_info = json.loads(metadata_bytes)
|
||||
self.metadata_bytes_num = metadata_bytes_num
|
||||
|
||||
def read_tensor_as_numpy(self, tensor_name) -> np.ndarray:
|
||||
"""Reads a tensor from the model file as a numpy array with np.float32 type."""
|
||||
tensor_info = self.layers_info[tensor_name]
|
||||
with open(self._shard_path, "rb") as f:
|
||||
shape = tensor_info["shape"]
|
||||
dtype = tensor_info["dtype"]
|
||||
if dtype not in DTYPE_MAP:
|
||||
raise ValueError(f"{dtype} is not supported.")
|
||||
data_offsets = tensor_info["data_offsets"]
|
||||
f.seek(int(self._HEAD_BYTES + self.metadata_bytes_num + data_offsets[0]))
|
||||
tensor_bytes = f.read(data_offsets[1] - data_offsets[0])
|
||||
raw_tensor = torch.frombuffer(
|
||||
array.array("b", tensor_bytes), dtype=DTYPE_MAP[dtype]
|
||||
).reshape(shape)
|
||||
return raw_tensor.float().t().contiguous().numpy()
|
||||
|
||||
def get_tensor_names(self) -> List[str]:
|
||||
names = list(self.layers_info.keys())
|
||||
if "__metadata__" in names:
|
||||
names.remove("__metadata__")
|
||||
return names
|
||||
|
||||
|
||||
class _SafetensorsReader:
  """Reads all the safetensors shards of a checkpoint as one tensor namespace."""

  def __init__(self, ckpt_path: str):
    """Discovers and opens every shard under `ckpt_path`.

    Args:
      ckpt_path: A directory containing `*.safetensors` files, a single
        safetensors file path, or a glob pattern matching shard files.

    Raises:
      ValueError: If no shard is found, or if the same tensor name appears in
        more than one shard.
    """
    if os.path.isdir(ckpt_path):
      # Read all safetensors files within the checkpoint directory.
      shard_paths = glob.glob(os.path.join(ckpt_path, "*.safetensors"))
    else:
      # Assume the ckpt_path is a file or a file pattern to match.
      shard_paths = glob.glob(ckpt_path)
    shards = [_SafetensorsShardReader(p) for p in shard_paths]
    # The previous `assert shards is not None` could never fire (a list is
    # never None); fail loudly when nothing matched instead.
    if not shards:
      raise ValueError(f"No safetensors file found at {ckpt_path}.")

    self._ckpt_path = ckpt_path
    self._tensors_map = {}
    for shard in shards:
      for tensor_name in shard.get_tensor_names():
        if tensor_name in self._tensors_map:
          raise ValueError(f"Duplicate tensor name: {tensor_name}")
        self._tensors_map[tensor_name] = shard

  def get_tensor_names(self) -> List[str]:
    """Returns the names of all tensors across all shards."""
    return list(self._tensors_map.keys())

  def read_tensor_as_numpy(self, tensor_name: str) -> np.ndarray:
    """Reads the named tensor from whichever shard holds it."""
    return self._tensors_map[tensor_name].read_tensor_as_numpy(tensor_name)
|
||||
|
||||
|
||||
class LayerType(enum.Enum):
  """Enum for layer type."""

  NONE = 0
  ATTENTION = 1  # Layer is part of the attention module.
  FEEDFORWARD = 2  # Layer is part of the feedforward module in the Transformer.
  EMBEDDING = 3  # Layer is the embedding lookup or final projection layer.
  LAYER_NORM = 4  # Layer is layer normalization before and after attention.
  LORA = 5  # Layer is LoRA weights augmented on the base model layers.

  @classmethod
  def get_layer_type(cls, layer_name: str):
    """Gets the layer type of the given layer name."""
    # Checked in priority order: the first category whose substring occurs in
    # the layer name wins (so a LoRA-augmented attention layer is LORA).
    categories = (
        (cls.LORA, ("lora",)),
        (cls.ATTENTION, ("self_attn",)),
        (cls.FEEDFORWARD, ("mlp",)),
        (cls.EMBEDDING, ("embed_tokens", "lm_head")),
        (
            cls.LAYER_NORM,
            (
                "input_layernorm",
                "post_attention_layernorm",
                "final_layernorm",
                "model.norm.weight",
            ),
        ),
    )
    for layer_type, substrings in categories:
      if any(sub in layer_name for sub in substrings):
        return layer_type
    return cls.NONE
|
||||
|
||||
|
||||
class StablelmMapper(converter_base.LayerActionMapperBase):
  """LayerActionMapper for handling the StableLM model."""

  def __init__(
      self,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
      backend: str,
      reader: _SafetensorsReader,
  ):
    # `reader` supplies tensor values; everything else is forwarded to the
    # LayerActionMapperBase constructor.
    super().__init__(
        is_symmetric=is_symmetric,
        attention_quant_bits=attention_quant_bits,
        feedforward_quant_bits=feedforward_quant_bits,
        embedding_quant_bits=embedding_quant_bits,
        backend=backend,
    )
    self._reader = reader

  def map_to_actions(
      self, layer_name: str
  ) -> Optional[List[converter_base.QuantizationAction]]:
    """Map the given layer name to actions.

    Only `.weight` tensors outside layer norms are marked for quantization;
    the bit width depends on the layer category. Returns a single-element
    list of QuantizationAction for the tensor.
    """
    tensor_value = self._reader.read_tensor_as_numpy(layer_name)
    quantize_axis = None  # None leaves the tensor unquantized.
    quantize_bits = None
    layer_type = LayerType.get_layer_type(layer_name)

    if layer_type != LayerType.LAYER_NORM and layer_name.endswith(".weight"):
      quantize_axis = [0]
      if layer_type == LayerType.FEEDFORWARD:
        quantize_bits = self._feedforward_quant_bits
      elif layer_type == LayerType.ATTENTION:
        quantize_bits = self._attention_quant_bits
        # NOTE(review): the CPU backend appears to expect the attention output
        # projection transposed, hence the axis flip — confirm against the
        # runtime's weight layout.
        if self._backend == "cpu" and ".o_proj." in layer_name:
          tensor_value = np.transpose(tensor_value)
          quantize_axis = [1]
      elif layer_type == LayerType.EMBEDDING:
        quantize_bits = self._embedding_quant_bits
        if self._backend == "cpu" and ".embed_tokens." in layer_name:
          tensor_value = np.transpose(tensor_value)
          quantize_axis = [1]
    target_name = self.update_target_name(layer_name)

    actions = [
        converter_base.QuantizationAction(
            tensor_name=layer_name,
            tensor_value=tensor_value,
            target_name=target_name,
            quantize_axis=quantize_axis,
            quantize_bits=quantize_bits,
            pack_dim=0,
        )
    ]
    return actions

  def update_target_name(self, target_name: str) -> str:
    """Updates the target name to match the tensor name convention."""
    # Rewrite the checkpoint's tensor names into the converter's params.lm.*
    # naming scheme via ordered substring replacements.
    target_name = target_name.replace(
        "model.layers.", "params.lm.transformer.x_layers_"
    )
    target_name = target_name.replace("mlp.up_proj", "ff_layer.ffn_layer1")
    target_name = target_name.replace("mlp.down_proj", "ff_layer.ffn_layer2")
    target_name = target_name.replace(
        "mlp.gate_proj", "ff_layer.ffn_layer1_gate"
    )
    target_name = target_name.replace("input_layernorm", "pre_layer_norm")
    target_name = target_name.replace(
        "pre_layer_norm.weight", "pre_layer_norm.scale"
    )
    # The two backends name the post-attention layer norm differently.
    if self._backend == "cpu":
      target_name = target_name.replace(
          "post_attention_layernorm", "ff_layer.pre_layer_norm"
      )
      target_name = target_name.replace(
          "ff_layer.pre_layer_norm.weight", "ff_layer.pre_layer_norm.scale"
      )
    else:
      target_name = target_name.replace(
          "post_attention_layernorm", "post_layer_norm"
      )
      target_name = target_name.replace(
          "post_layer_norm.weight", "post_layer_norm.scale"
      )
    target_name = target_name.replace("self_attn.q_proj", "self_attention.q")
    target_name = target_name.replace("self_attn.k_proj", "self_attention.k")
    target_name = target_name.replace("self_attn.v_proj", "self_attention.v")
    target_name = target_name.replace("self_attn.o_proj", "self_attention.post")
    target_name = target_name.replace(
        "model.embed_tokens", "params.lm.token_embedding"
    )
    target_name = target_name.replace("model.norm", "params.lm.final_ln")
    target_name = target_name.replace("final_ln.weight", "final_ln.scale")
    target_name = target_name.replace("lm_head", "params.lm.softmax.logits_ffn")
    target_name = target_name.replace(".weight", ".w")

    return target_name
|
||||
|
||||
|
||||
class PhiMapper(converter_base.LayerActionMapperBase):
  """LayerActionMapper for handling the Phi model."""

  def __init__(
      self,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
      backend: str,
      reader: _SafetensorsReader,
  ):
    # `reader` supplies tensor values; everything else is forwarded to the
    # LayerActionMapperBase constructor.
    super().__init__(
        is_symmetric=is_symmetric,
        attention_quant_bits=attention_quant_bits,
        feedforward_quant_bits=feedforward_quant_bits,
        embedding_quant_bits=embedding_quant_bits,
        backend=backend,
    )
    self._reader = reader

  def map_to_actions(
      self, layer_name: str
  ) -> Optional[List[converter_base.QuantizationAction]]:
    """Map the given layer name to actions.

    `.weight` tensors are quantized except layer norms and LoRA weights
    (quantize_axis/quantize_bits stay None for those). Returns a
    single-element list of QuantizationAction for the tensor.
    """
    tensor_value = self._reader.read_tensor_as_numpy(layer_name)
    quantize_axis = None  # None leaves the tensor unquantized.
    quantize_bits = None
    layer_type = LayerType.get_layer_type(layer_name)

    if (
        layer_type != LayerType.LAYER_NORM
        and layer_name.endswith(".weight")
        and layer_type != LayerType.LORA
    ):
      quantize_axis = [0]
      if layer_type == LayerType.FEEDFORWARD:
        quantize_bits = self._feedforward_quant_bits
      elif layer_type == LayerType.ATTENTION:
        quantize_bits = self._attention_quant_bits
        # NOTE(review): the CPU backend appears to expect the attention dense
        # (output) projection transposed, hence the axis flip — confirm.
        if self._backend == "cpu" and ".dense." in layer_name:
          tensor_value = np.transpose(tensor_value)
          quantize_axis = [1]
      elif layer_type == LayerType.EMBEDDING:
        quantize_bits = self._embedding_quant_bits
        if self._backend == "cpu" and ".embed_tokens." in layer_name:
          tensor_value = np.transpose(tensor_value)
          quantize_axis = [1]
    target_name = self.update_target_name(layer_name)

    actions = [
        converter_base.QuantizationAction(
            tensor_name=layer_name,
            tensor_value=tensor_value,
            target_name=target_name,
            quantize_axis=quantize_axis,
            quantize_bits=quantize_bits,
            pack_dim=0,
        )
    ]
    return actions

  def update_target_name(self, target_name: str) -> str:
    """Updates the target name to match the tensor name convention."""
    # Strip the PEFT/LoRA wrapper prefix, then rewrite into params.lm.* names.
    target_name = target_name.replace("base_model.model.", "")
    target_name = target_name.replace(
        "model.layers.", "params.lm.transformer.x_layers_"
    )

    # Unlike the other mappers, Phi renames per layer category because its
    # linear layers carry separate `.linear.w` / `.bias.b` suffixes.
    layer_type = LayerType.get_layer_type(target_name)
    if layer_type == LayerType.FEEDFORWARD:
      target_name = target_name.replace(".weight", ".linear.w")
      target_name = target_name.replace(".bias", ".bias.b")
      target_name = target_name.replace("mlp.fc1", "ff_layer.ffn_layer1")
      target_name = target_name.replace("mlp.fc2", "ff_layer.ffn_layer2")

    elif layer_type == LayerType.ATTENTION:
      target_name = target_name.replace(".weight", ".linear.w")
      target_name = target_name.replace(".bias", ".bias.b")
      target_name = target_name.replace("self_attn.q_proj", "self_attention.q")
      target_name = target_name.replace("self_attn.k_proj", "self_attention.k")
      target_name = target_name.replace("self_attn.v_proj", "self_attention.v")
      target_name = target_name.replace(
          "self_attn.dense", "self_attention.post"
      )
    elif layer_type == LayerType.EMBEDDING:
      target_name = target_name.replace(
          "model.embed_tokens", "params.lm.token_embedding"
      )
      target_name = target_name.replace(
          "lm_head", "params.lm.softmax.logits_ffn"
      )
      target_name = target_name.replace(
          "logits_ffn.weight", "logits_ffn.linear.w"
      )
      target_name = target_name.replace("logits_ffn.bias", "logits_ffn.bias.b")
    elif layer_type == LayerType.LAYER_NORM:
      target_name = target_name.replace("input_layernorm", "pre_layer_norm")
      target_name = target_name.replace(
          "pre_layer_norm.weight", "pre_layer_norm.scale"
      )
      target_name = target_name.replace(
          "model.final_layernorm", "params.lm.final_ln"
      )
      target_name = target_name.replace("final_ln.weight", "final_ln.scale")
    target_name = target_name.replace(".weight", ".w")

    # For LoRA weights: the A/B factor naming is swapped for the attention
    # output ("post") projection relative to the other projections.
    if "post" in target_name:
      target_name = target_name.replace("lora_A.linear.w", "w_prime_right")
      target_name = target_name.replace("lora_B.linear.w", "w_prime_left")
    else:
      target_name = target_name.replace("lora_A.linear.w", "w_prime_left")
      target_name = target_name.replace("lora_B.linear.w", "w_prime_right")

    return target_name
|
||||
|
||||
|
||||
class GemmaMapper(converter_base.LayerActionMapperBase):
  """LayerActionMapper for handling the Gemma model."""

  def __init__(
      self,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
      backend: str,
      reader: _SafetensorsReader,
  ):
    # `reader` supplies tensor values; everything else is forwarded to the
    # LayerActionMapperBase constructor.
    super().__init__(
        is_symmetric=is_symmetric,
        attention_quant_bits=attention_quant_bits,
        feedforward_quant_bits=feedforward_quant_bits,
        embedding_quant_bits=embedding_quant_bits,
        backend=backend,
    )
    self._reader = reader

  def map_to_actions(
      self, layer_name: str
  ) -> Optional[List[converter_base.QuantizationAction]]:
    """Map the given layer name to actions.

    `.weight` tensors are quantized except layer norms and LoRA weights.
    Unlike the StableLM/Phi mappers, the attention output projection is
    transposed for all backends, not just CPU.
    """
    tensor_value = self._reader.read_tensor_as_numpy(layer_name)
    quantize_axis = None  # None leaves the tensor unquantized.
    quantize_bits = None
    layer_type = LayerType.get_layer_type(layer_name)

    if (
        layer_type != LayerType.LAYER_NORM
        and layer_name.endswith(".weight")
        and layer_type != LayerType.LORA
    ):
      quantize_axis = [0]
      if layer_type == LayerType.FEEDFORWARD:
        quantize_bits = self._feedforward_quant_bits
      elif layer_type == LayerType.ATTENTION:
        quantize_bits = self._attention_quant_bits
        if "o_proj" in layer_name:
          tensor_value = np.transpose(tensor_value)
          quantize_axis = [1]
      elif layer_type == LayerType.EMBEDDING:
        quantize_bits = self._embedding_quant_bits

    target_name = self.update_target_name(layer_name)

    actions = [
        converter_base.QuantizationAction(
            tensor_name=layer_name,
            tensor_value=tensor_value,
            target_name=target_name,
            quantize_axis=quantize_axis,
            quantize_bits=quantize_bits,
            pack_dim=0,
        )
    ]
    return actions

  def update_target_name(self, target_name: str) -> str:
    """Updates the target name to match the tensor name convention."""
    # Strip the PEFT/LoRA wrapper prefix, then rewrite into params.lm.* names.
    target_name = target_name.replace("base_model.model.", "")
    target_name = target_name.replace(
        "model.layers.", "params.lm.transformer.x_layers_"
    )
    target_name = target_name.replace("mlp.up_proj", "ff_layer.ffn_layer1")
    target_name = target_name.replace("mlp.down_proj", "ff_layer.ffn_layer2")
    target_name = target_name.replace(
        "mlp.gate_proj", "ff_layer.ffn_layer1_gate"
    )
    target_name = target_name.replace("input_layernorm", "pre_layer_norm")
    target_name = target_name.replace(
        "pre_layer_norm.weight", "pre_layer_norm.scale"
    )
    target_name = target_name.replace(
        "post_attention_layernorm", "ff_layer.pre_layer_norm"
    )
    target_name = target_name.replace(
        "ff_layer.pre_layer_norm.weight", "ff_layer.pre_layer_norm.scale"
    )
    target_name = target_name.replace("self_attn.q_proj", "self_attention.q")
    target_name = target_name.replace("self_attn.k_proj", "self_attention.k")
    target_name = target_name.replace("self_attn.v_proj", "self_attention.v")
    target_name = target_name.replace("self_attn.o_proj", "self_attention.post")
    # NOTE(review): the embedding table is mapped to the softmax projection
    # here (not to token_embedding as in the other mappers) — presumably
    # because Gemma ties the input embedding and output projection; confirm.
    target_name = target_name.replace(
        "model.embed_tokens", "params.lm.softmax.logits_ffn"
    )
    target_name = target_name.replace("model.norm", "params.lm.final_ln")
    target_name = target_name.replace("final_ln.weight", "final_ln.scale")
    target_name = target_name.replace(".weight", ".w")

    # For LoRA weights: the A/B factor naming is swapped for the attention
    # output ("post") projection relative to the other projections.
    if "post" in target_name:
      target_name = target_name.replace("lora_A.w", "w_prime_right")
      target_name = target_name.replace("lora_B.w", "w_prime_left")
    else:
      target_name = target_name.replace("lora_A.w", "w_prime_left")
      target_name = target_name.replace("lora_B.w", "w_prime_right")

    return target_name
|
||||
|
||||
|
||||
class SafetensorsCkptLoader(converter_base.CkptLoaderBase):
  """CkptLoader implementation for loading the Safetensors."""

  def __init__(
      self,
      ckpt_path: str,
      is_symmetric: bool,
      attention_quant_bits: int,
      feedforward_quant_bits: int,
      embedding_quant_bits: int,
      special_model: str,
      backend: str,
  ):
    """Initializes the loader.

    Args:
      ckpt_path: The filepath to the safetensors file.
      is_symmetric: Whether to apply symmetric or asymmetric quantization.
      attention_quant_bits: An integer that specify the target quantization bits
        (support 8 or 4) for the attention layers.
      feedforward_quant_bits: An integer that specify the target quantization
        bits (support 8 or 4) for the feedforward layers in each Transformer
        blocks.
      embedding_quant_bits: An integer that specify the target quantization bits
        (support 8 or 4) for the embedding (and the final projection) layers.
      special_model: A string that indicates which input model is and whether
        any special treatment is needed.
      backend: A string indicating the backend used when converting this model.
        Valid options are "cpu" and "gpu".

    Raises:
      ValueError: If `special_model` is not one of the supported model names.
    """
    super().__init__(
        ckpt_path,
        is_symmetric,
        attention_quant_bits,
        feedforward_quant_bits,
        embedding_quant_bits,
    )

    self._special_model = special_model
    self._reader = _SafetensorsReader(ckpt_path)
    # Select the per-model layer mapper based on the declared model family.
    if special_model in ["STABLELM_4E1T_3B"]:
      self.mapper = StablelmMapper(
          is_symmetric,
          attention_quant_bits,
          feedforward_quant_bits,
          embedding_quant_bits,
          backend,
          self._reader,
      )
    elif special_model in ["PHI_2"]:
      self.mapper = PhiMapper(
          is_symmetric,
          attention_quant_bits,
          feedforward_quant_bits,
          embedding_quant_bits,
          backend,
          self._reader,
      )
    elif special_model in ["GEMMA_2B", "GEMMA_7B"]:
      self.mapper = GemmaMapper(
          is_symmetric,
          attention_quant_bits,
          feedforward_quant_bits,
          embedding_quant_bits,
          backend,
          self._reader,
      )
    else:
      raise ValueError(f"Unknown special model: {special_model}")

  def load_to_actions(
      self,
  ) -> Iterator[List[converter_base.QuantizationAction]]:
    """Yields one list of quantization actions per tensor in the checkpoint."""
    tensor_names = self._reader.get_tensor_names()
    for tensor_name in tensor_names:
      tensor_actions = self.mapper.map_to_actions(tensor_name)
      # Mappers may return None to indicate the tensor should be skipped.
      if tensor_actions is None:
        continue
      yield tensor_actions
|
||||
+83
@@ -0,0 +1,83 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for safetensors_converter."""
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from mediapipe.tasks.python.genai.converter import safetensors_converter
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
# Location of the checkpoint fixtures within the MediaPipe test-data tree.
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
# Small StableLM-style safetensors checkpoint used as the test fixture.
_SAFETENSORS_FILE = test_utils.get_test_data_path(
    os.path.join(_TEST_DATA_DIR, 'stablelm_3b_4e1t_test_weight.safetensors')
)
|
||||
|
||||
|
||||
class SafetensorsConverterTest(parameterized.TestCase):
  """Tests for the safetensors checkpoint loader."""

  # Tensor names expected in the StableLM test checkpoint fixture.
  VARIABLE_NAMES = [
      'model.embed_tokens.weight',
      'model.layers.0.input_layernorm.bias',
      'model.layers.0.input_layernorm.weight',
      'model.layers.0.mlp.down_proj.weight',
      'model.layers.0.mlp.gate_proj.weight',
      'model.layers.0.mlp.up_proj.weight',
      'model.layers.0.post_attention_layernorm.bias',
      'model.layers.0.post_attention_layernorm.weight',
      'model.layers.0.self_attn.k_proj.weight',
      'model.layers.0.self_attn.o_proj.weight',
      'model.layers.0.self_attn.q_proj.weight',
      'model.layers.0.self_attn.v_proj.weight',
      'model.norm.bias',
      'model.norm.weight',
      'lm_head.weight',
  ]

  def _build_loader(self, feedforward_quant_bits):
    """Constructs a loader over the fixture with the given feedforward bits."""
    return safetensors_converter.SafetensorsCkptLoader(
        ckpt_path=_SAFETENSORS_FILE,
        is_symmetric=True,
        attention_quant_bits=8,
        feedforward_quant_bits=feedforward_quant_bits,
        embedding_quant_bits=8,
        special_model='STABLELM_4E1T_3B',
        backend='gpu',
    )

  def test_init(self):
    loader = self._build_loader(8)
    self.assertEqual(loader._ckpt_path, _SAFETENSORS_FILE)
    self.assertEqual(loader._is_symmetric, True)
    self.assertEqual(loader._attention_quant_bits, 8)
    self.assertEqual(loader._feedforward_quant_bits, 8)

  @parameterized.product(
      quant_bits=(4, 8),
  )
  def test_load_to_actions(self, quant_bits):
    loader = self._build_loader(quant_bits)
    actions = loader.load_to_actions()
    # One action list per variable in the fixture checkpoint.
    self.assertLen(list(actions), 15)
|
||||
|
||||
|
||||
# Allows running this test file directly, outside the Bazel test runner.
if __name__ == '__main__':
  absltest.main()
|
||||
+112
@@ -0,0 +1,112 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""ModelWriter for writing a set of weights as binary files."""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from jax import numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.tasks.python.genai.converter import converter_base
|
||||
from mediapipe.tasks.python.genai.converter import quantization_util
|
||||
|
||||
|
||||
@contextlib.contextmanager
def filemanager(filename: str, mode: str):
  """Opens `filename` in `mode` and yields the file object.

  Thin wrapper around `open` kept as a single indirection point so file
  access can be swapped out in one place if needed.

  Args:
    filename: Path of the file to open.
    mode: Mode string passed through to `open` (e.g. 'wb', 'a').

  Yields:
    The open file object; it is closed when the context exits.
  """
  # The previous try/finally around the inner `with` was a no-op: the
  # `with` statement already guarantees the file is closed on exit.
  with open(filename, mode) as f:
    yield f
|
||||
|
||||
|
||||
def removeprefix(s, prefix):
  """Returns `s` without a leading `prefix`, or `s` unchanged if absent."""
  return s[len(prefix):] if s.startswith(prefix) else s
|
||||
|
||||
|
||||
class WeightBinsWriter(converter_base.ModelWriterBase):
  """A ModelWriter for writing a set of weights as binary files."""

  def get_weight_info(self, var_name: str, weight: np.ndarray) -> str:
    """Gets the string that describes the weights.

    Args:
      var_name: Variable name without the 'mdl_vars.' prefix.
      weight: Tensor whose dtype and shape are recorded.

    Returns:
      A newline-terminated line 'mdl_vars.<var_name>.<dtype>.<d0>_<d1>_...'.
    """
    dtype_str = str(weight.dtype)
    shape_str = '_'.join(map(str, weight.shape))
    return f'mdl_vars.{var_name}.{dtype_str}.{shape_str}\n'

  def write_variables(self, variables: Dict[str, Tuple[np.ndarray, bool]]):
    """Writes variable to the binary files. One for each layer.

    Args:
      variables: A dictionary that maps from the target variable names to the
        quantized tensor values along with a boolean that indicates whether to
        pack the values (only applicable for the 4-bit quantized tensors).
    """
    weights_info = []
    for var_name, value in variables.items():
      output = value[0]
      if value[1]:
        # Squeeze the tensor to make sure it is a 1D array for packing.
        output = np.expand_dims(np.ravel(output), axis=-1)
        # Extra pack needed for 4 bit. We always pack the weights along the
        # first dimension since the tensor has already been squeezed.
        output = quantization_util.pack_4bit(output, 0, jnp.int8)
      if 'combined_qkv' in var_name:
        # Bug fix: the prefix was previously misspelled 'mld_vars.', so the
        # 'mdl_vars.' prefix was never stripped from combined-qkv filenames.
        var_name = removeprefix(var_name, 'mdl_vars.')
        var_name_q = var_name.replace('combined_qkv', 'q')
        var_name_k = var_name.replace('combined_qkv', 'k')
        var_name_v = var_name.replace('combined_qkv', 'v')
        if output.shape[0] == 3:
          weight_q, weight_k, weight_v = output
          assert weight_q.shape == weight_k.shape == weight_v.shape
        else:  # LoRA right weight is shared across q, k, v
          weight_q = weight_k = weight_v = output
        # Write the three projections uniformly (the original interleaved
        # info-appends and writes in differing orders; weights_info is
        # sorted below, so the order here does not affect the output file).
        for name, weight in (
            (var_name_q, weight_q),
            (var_name_k, weight_k),
            (var_name_v, weight_v),
        ):
          weights_info.append(self.get_weight_info(name, weight))
          with filemanager(os.path.join(self._output_dir, name), 'wb') as f:
            f.write(weight.tobytes())
      else:
        # Shorten spelled-out projection names to their single-letter forms.
        if 'key' in var_name:
          var_name = var_name.replace('key', 'k')
        if 'query' in var_name:
          var_name = var_name.replace('query', 'q')
        if 'value' in var_name:
          var_name = var_name.replace('value', 'v')
        path = os.path.join(
            self._output_dir, removeprefix(var_name, 'mdl_vars.')
        )
        with filemanager(path, 'wb') as f:
          f.write(output.tobytes())
        weights_info.append(self.get_weight_info(var_name, output))

    # Sort weights_info for a deterministic layer_info.txt ordering.
    weights_info.sort()
    with filemanager(
        os.path.join(self._output_dir, 'layer_info.txt'), 'a'
    ) as finfo:
      for line in weights_info:
        # Bug fix: get_weight_info already terminates each entry with '\n';
        # appending another '\n' produced blank lines between entries.
        finfo.write(line)
|
||||
+62
@@ -0,0 +1,62 @@
|
||||
# Copyright 2024 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for pax_converter."""
|
||||
|
||||
import os
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.tasks.python.genai.converter import weight_bins_writer
|
||||
|
||||
|
||||
class WeightBinsWriterTest(parameterized.TestCase):
  """Tests for WeightBinsWriter."""

  def _make_writer(self):
    """Builds a writer rooted under the test temp directory."""
    out_dir = os.path.join(flags.FLAGS.test_tmpdir, 'output_dir')
    writer = weight_bins_writer.WeightBinsWriter(
        output_dir=out_dir, backend='cpu'
    )
    return out_dir, writer

  def test_get_weight_info(self):
    _, writer = self._make_writer()
    weight = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
    info = writer.get_weight_info(
        'params.lm.softmax.logits_ffn.linear.w', weight
    )
    self.assertEqual(
        info,
        'mdl_vars.params.lm.softmax.logits_ffn.linear.w.float32.2_3\n',
    )

  # NOTE(review): despite its name, this test exercises write_variables,
  # not load_to_actions.
  def test_load_to_actions(self):
    out_dir, writer = self._make_writer()
    variables = {
        'mdl_vars.params.lm.softmax.logits_ffn.linear.w': (
            np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
            False,
        ),
    }
    writer.write_variables(variables)
    written = os.path.join(out_dir, 'params.lm.softmax.logits_ffn.linear.w')
    # Six float32 values -> 6 * 4 bytes on disk.
    self.assertEqual(os.path.getsize(written), 6 * 4)
|
||||
|
||||
|
||||
# Allows running this test file directly, outside the Bazel test runner.
if __name__ == '__main__':
  absltest.main()
|
||||
Reference in New Issue
Block a user