This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,53 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Tasks Components Containers API."""
import mediapipe.tasks.python.components.containers.audio_data
import mediapipe.tasks.python.components.containers.bounding_box
import mediapipe.tasks.python.components.containers.category
import mediapipe.tasks.python.components.containers.classification_result
import mediapipe.tasks.python.components.containers.detections
import mediapipe.tasks.python.components.containers.embedding_result
import mediapipe.tasks.python.components.containers.landmark
import mediapipe.tasks.python.components.containers.landmark_detection_result
import mediapipe.tasks.python.components.containers.rect
# Re-export the container classes at package level so that callers can refer to
# them without spelling out the defining submodule.
# NOTE(review): the bare submodule names used below (e.g. `audio_data`) are not
# bound by the dotted `import a.b.c` statements themselves; this presumably
# relies on this file being the package `__init__`, where importing a submodule
# also sets it as an attribute of the package namespace - confirm.
AudioDataFormat = audio_data.AudioDataFormat
AudioData = audio_data.AudioData
BoundingBox = bounding_box.BoundingBox
Category = category.Category
Classifications = classification_result.Classifications
ClassificationResult = classification_result.ClassificationResult
Detection = detections.Detection
DetectionResult = detections.DetectionResult
Embedding = embedding_result.Embedding
EmbeddingResult = embedding_result.EmbeddingResult
Landmark = landmark.Landmark
NormalizedLandmark = landmark.NormalizedLandmark
LandmarksDetectionResult = landmark_detection_result.LandmarksDetectionResult
Rect = rect.Rect
NormalizedRect = rect.NormalizedRect
# Remove unnecessary modules to avoid duplication in API docs.
# NOTE(review): `detections` also imports the sibling `keypoint` module, which
# is neither re-exported above nor deleted here - confirm whether that is
# intentional.
del audio_data
del bounding_box
del category
del classification_result
del detections
del embedding_result
del landmark
del landmark_detection_result
del rect
# Drop the top-level `mediapipe` name bound by the dotted imports above.
del mediapipe
@@ -0,0 +1,137 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe audio data."""
import dataclasses
from typing import Optional
import numpy as np
@dataclasses.dataclass
class AudioDataFormat:
  """Audio format metadata.

  Attributes:
    num_channels: the number of channels of the audio data.
    sample_rate: the audio sample rate.
  """

  num_channels: int = 1
  sample_rate: Optional[float] = None


class AudioData(object):
  """MediaPipe Tasks' audio container.

  Wraps a NumPy buffer of `buffer_length` samples. Loading at least
  `buffer_length` samples replaces the buffer content with the newest samples
  of the requested slice; loading fewer samples shifts the existing content
  towards the start of the buffer and appends the new samples at the end.
  """

  def __init__(
      self,
      buffer_length: int,
      audio_format: Optional[AudioDataFormat] = None,
  ) -> None:
    """Initializes the `AudioData` object.

    Args:
      buffer_length: the length of the audio buffer.
      audio_format: the audio format metadata. When omitted, a new default
        `AudioDataFormat` is created for this instance.
    """
    # Build the default per instance: a default constructed in the signature
    # would be one mutable object shared by every `AudioData` created without
    # an explicit format.
    if audio_format is None:
      audio_format = AudioDataFormat()
    self._audio_format = audio_format
    self._buffer = np.zeros(
        [buffer_length, self._audio_format.num_channels], dtype=np.float32)

  def clear(self):
    """Clears the internal buffer and fill it with zeros."""
    self._buffer.fill(0)

  def load_from_array(self,
                      src: np.ndarray,
                      offset: int = 0,
                      size: int = -1) -> None:
    """Loads the audio data from a NumPy array.

    Args:
      src: A NumPy source array contains the input audio. Either 1-D (mono) or
        2-D with one column per channel.
      offset: An optional offset for loading a slice of the `src` array to the
        buffer.
      size: An optional size parameter denoting the number of samples to load
        from the `src` array. A negative value (the default) loads everything
        from `offset` to the end of `src`.

    Raises:
      ValueError: If the input array has an incorrect shape, if `offset` is
        negative, or if `offset` + `size` exceeds the length of the `src`
        array.
    """
    num_channels = self._audio_format.num_channels
    if src.ndim == 1:
      if num_channels != 1:
        raise ValueError(f"Input audio is mono, but the audio data is expected "
                         f"to have {num_channels} channels.")
    elif src.ndim != 2 or src.shape[1] != num_channels:
      raise ValueError(f"Input audio contains an invalid number of channels. "
                       f"Expect {num_channels}.")

    src_length = len(src)
    if offset < 0:
      raise ValueError(f"Index out of range. offset {offset} should be >= 0.")
    if size < 0:
      # Default to the remainder of `src`; using the full length here would
      # make every call with a non-zero offset fail the range check below.
      size = max(src_length - offset, 0)
    if offset + size > src_length:
      raise ValueError(
          f"Index out of range. offset {offset} + size {size} should be <= "
          f"src's length: {src_length}")

    chunk = src[offset:offset + size]
    buffer_length = len(self._buffer)
    # Branch on the number of samples actually requested, not on the length of
    # `src`: a short slice taken from a long array must be appended, otherwise
    # the start index of the copied window can become negative.
    if size >= buffer_length:
      # The slice covers the whole buffer: keep its newest `buffer_length`
      # samples.
      self._buffer = chunk[size - buffer_length:].copy()
    else:
      # Shift the internal buffer backward and add the incoming data to the end
      # of the buffer.
      self._buffer = np.roll(self._buffer, -size, axis=0)
      # Index from the front so that `size == 0` selects an empty tail (a
      # `-0:` slice would select the whole buffer).
      tail = self._buffer[buffer_length - size:]
      # Reshape so that a 1-D mono chunk fits a (n, 1) buffer and vice versa.
      tail[...] = chunk.reshape(tail.shape)

  @classmethod
  def create_from_array(cls,
                        src: np.ndarray,
                        sample_rate: Optional[float] = None) -> "AudioData":
    """Creates an `AudioData` object from a NumPy array.

    Args:
      src: A NumPy source array contains the input audio.
      sample_rate: the optional audio sample rate.

    Returns:
      An `AudioData` object that contains a copy of the NumPy source array as
      the data.
    """
    obj = cls(
        buffer_length=src.shape[0],
        audio_format=AudioDataFormat(
            num_channels=1 if len(src.shape) == 1 else src.shape[1],
            sample_rate=sample_rate))
    obj.load_from_array(src)
    return obj

  @property
  def audio_format(self) -> AudioDataFormat:
    """Gets the audio format of the audio."""
    return self._audio_format

  @property
  def buffer_length(self) -> int:
    """Gets the sample count of the audio."""
    return self._buffer.shape[0]

  @property
  def buffer(self) -> np.ndarray:
    """Gets the internal buffer."""
    return self._buffer
@@ -0,0 +1,73 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bounding box data class."""
import dataclasses
from typing import Any
from mediapipe.framework.formats import location_data_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
# Private shorthand for the `BoundingBox` message nested in `LocationData`.
_BoundingBoxProto = location_data_pb2.LocationData.BoundingBox
@dataclasses.dataclass
class BoundingBox:
  """An axis-aligned bounding box with integer coordinates.

  Attributes:
    origin_x: The X coordinate of the top-left corner, in pixels.
    origin_y: The Y coordinate of the top-left corner, in pixels.
    width: The width of the bounding box, in pixels.
    height: The height of the bounding box, in pixels.
  """

  origin_x: int
  origin_y: int
  width: int
  height: int

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _BoundingBoxProto:
    """Converts this box into a `BoundingBox` protobuf message."""
    # The proto names the top-left corner `xmin` / `ymin`.
    proto_fields = dict(
        xmin=self.origin_x,
        ymin=self.origin_y,
        width=self.width,
        height=self.height,
    )
    return _BoundingBoxProto(**proto_fields)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _BoundingBoxProto) -> 'BoundingBox':
    """Builds a `BoundingBox` from the given protobuf message."""
    # Positional order follows the field declaration order above.
    return BoundingBox(
        pb2_obj.xmin, pb2_obj.ymin, pb2_obj.width, pb2_obj.height)

  def __eq__(self, other: Any) -> bool:
    """Compares this box with another object.

    Args:
      other: The object to be compared with.

    Returns:
      False if `other` is not a `BoundingBox`; otherwise the result of
      comparing the protobuf forms of both boxes.
    """
    if isinstance(other, BoundingBox):
      return self.to_pb2().__eq__(other.to_pb2())
    return False
@@ -0,0 +1,78 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Category data class."""
import dataclasses
from typing import Any, Optional
from mediapipe.framework.formats import classification_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
# Private shorthand for the `Classification` protobuf message.
_ClassificationProto = classification_pb2.Classification
@dataclasses.dataclass
class Category:
  """A single labelled outcome of a classification.

  Bundles a label, its display name, a float score and the index of the label
  in the corresponding label file. Typically used as the result of
  classification tasks.

  Attributes:
    index: The index of the label in the corresponding label file.
    score: The probability score of this label category.
    display_name: The display name of the label, which may be translated for
      different locales. For example, a label, "apple", may be translated into
      Spanish for display purpose, so that the `display_name` is "manzana".
    category_name: The label of this category object.
  """

  index: Optional[int] = None
  score: Optional[float] = None
  display_name: Optional[str] = None
  category_name: Optional[str] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _ClassificationProto:
    """Converts this category into a `Classification` protobuf message."""
    proto_fields = dict(
        index=self.index,
        score=self.score,
        label=self.category_name,  # The proto calls the category name `label`.
        display_name=self.display_name,
    )
    return _ClassificationProto(**proto_fields)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Category':
    """Builds a `Category` from the given protobuf message."""
    field_values = dict(
        index=pb2_obj.index,
        score=pb2_obj.score,
        display_name=pb2_obj.display_name,
        category_name=pb2_obj.label,
    )
    return Category(**field_values)

  def __eq__(self, other: Any) -> bool:
    """Compares this category with another object.

    Args:
      other: The object to be compared with.

    Returns:
      False if `other` is not a `Category`; otherwise the result of comparing
      the protobuf forms of both categories.
    """
    if isinstance(other, Category):
      return self.to_pb2().__eq__(other.to_pb2())
    return False
@@ -0,0 +1,111 @@
# Copyright 2022 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifications data class."""
import dataclasses
from typing import List, Optional
from mediapipe.framework.formats import classification_pb2
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
# Private shorthands for the protobuf messages used in this module.
_ClassificationProto = classification_pb2.Classification
_ClassificationListProto = classification_pb2.ClassificationList
_ClassificationsProto = classifications_pb2.Classifications
_ClassificationResultProto = classifications_pb2.ClassificationResult
@dataclasses.dataclass
class Classifications:
  """The classification results produced by one classifier head.

  Attributes:
    categories: The array of predicted categories, usually sorted by descending
      scores (e.g. from high to low probability).
    head_index: The index of the classifier head these categories refer to. This
      is useful for multi-head models.
    head_name: The name of the classifier head, which is the corresponding
      tensor metadata name.
  """

  categories: List[category_module.Category]
  head_index: int
  head_name: Optional[str] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _ClassificationsProto:
    """Converts this object into a `Classifications` protobuf message."""
    category_list = _ClassificationListProto()
    # Each category becomes one entry of the nested `ClassificationList`.
    category_list.classification.extend(
        [entry.to_pb2() for entry in self.categories])
    return _ClassificationsProto(
        classification_list=category_list,
        head_index=self.head_index,
        head_name=self.head_name)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
    """Builds a `Classifications` object from the given protobuf message."""
    converted = [
        category_module.Category.create_from_pb2(entry)
        for entry in pb2_obj.classification_list.classification
    ]
    return Classifications(
        categories=converted,
        head_index=pb2_obj.head_index,
        head_name=pb2_obj.head_name)
@dataclasses.dataclass
class ClassificationResult:
  """The classification results of a model, one entry per head.

  Attributes:
    classifications: A list of `Classifications` objects, each for a head of the
      model.
    timestamp_ms: The optional timestamp (in milliseconds) of the start of the
      chunk of data corresponding to these results. This is only used for
      classification on time series (e.g. audio classification). In these use
      cases, the amount of data to process might exceed the maximum size that
      the model can process: to solve this, the input data is split into
      multiple chunks starting at different timestamps.
  """

  classifications: List[Classifications]
  timestamp_ms: Optional[int] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _ClassificationResultProto:
    """Converts this result into a `ClassificationResult` protobuf message."""
    head_protos = []
    for head in self.classifications:
      head_protos.append(head.to_pb2())
    return _ClassificationResultProto(
        classifications=head_protos, timestamp_ms=self.timestamp_ms)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
    """Builds a `ClassificationResult` from the given protobuf message."""
    heads = []
    for head_proto in pb2_obj.classifications:
      heads.append(Classifications.create_from_pb2(head_proto))
    return ClassificationResult(
        classifications=heads, timestamp_ms=pb2_obj.timestamp_ms)
@@ -0,0 +1,181 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detections data class."""
import dataclasses
from typing import Any, List, Optional
from mediapipe.framework.formats import detection_pb2
from mediapipe.framework.formats import location_data_pb2
from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import keypoint as keypoint_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
# Private shorthands for the protobuf messages used in this module.
_DetectionListProto = detection_pb2.DetectionList
_DetectionProto = detection_pb2.Detection
_LocationDataProto = location_data_pb2.LocationData
@dataclasses.dataclass
class Detection:
  """Represents one detected object in the object detector's results.

  Attributes:
    bounding_box: A BoundingBox object.
    categories: A list of Category objects.
    keypoints: A list of NormalizedKeypoint objects.
  """

  bounding_box: bounding_box_module.BoundingBox
  categories: List[category_module.Category]
  keypoints: Optional[List[keypoint_module.NormalizedKeypoint]] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _DetectionProto:
    """Generates a Detection protobuf object.

    Each category contributes its score, plus its index, name and display name
    when they are set, to the parallel repeated fields of the proto.
    """
    labels = []
    label_ids = []
    scores = []
    display_names = []
    for category in self.categories:
      scores.append(category.score)
      # Test against None instead of truthiness: index 0 and empty strings are
      # real values, and dropping them shifts the remaining entries of the
      # parallel lists out of step with `scores`.
      if category.index is not None:
        label_ids.append(category.index)
      if category.category_name is not None:
        labels.append(category.category_name)
      if category.display_name is not None:
        display_names.append(category.display_name)

    # Delegate to the keypoint's own converter so that coordinates and scores
    # equal to 0.0 are written out rather than skipped as falsy.
    relative_keypoints = [
        keypoint.to_pb2() for keypoint in self.keypoints or []
    ]

    return _DetectionProto(
        label=labels,
        label_id=label_ids,
        score=scores,
        display_name=display_names,
        location_data=_LocationDataProto(
            format=_LocationDataProto.Format.BOUNDING_BOX,
            bounding_box=self.bounding_box.to_pb2(),
            relative_keypoints=relative_keypoints,
        ),
    )

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection':
    """Creates a `Detection` object from the given protobuf object.

    One `Category` is created per score; its index, name and display name are
    taken from the entry at the same position of the corresponding repeated
    field, or left as None when that field is shorter.
    """
    label_ids = pb2_obj.label_id
    labels = pb2_obj.label
    display_names = pb2_obj.display_name
    categories = [
        category_module.Category(
            score=score,
            index=label_ids[idx] if idx < len(label_ids) else None,
            category_name=labels[idx] if idx < len(labels) else None,
            display_name=(
                display_names[idx] if idx < len(display_names) else None),
        )
        for idx, score in enumerate(pb2_obj.score)
    ]
    keypoints = [
        keypoint_module.NormalizedKeypoint.create_from_pb2(keypoint_proto)
        for keypoint_proto in pb2_obj.location_data.relative_keypoints
    ]
    return Detection(
        bounding_box=bounding_box_module.BoundingBox.create_from_pb2(
            pb2_obj.location_data.bounding_box
        ),
        categories=categories,
        keypoints=keypoints,
    )

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Args:
      other: The object to be compared with.

    Returns:
      True if `other` is a `Detection` whose protobuf form equals this one's.
    """
    if not isinstance(other, Detection):
      return False
    return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class DetectionResult:
  """The list of objects found by a detector.

  Attributes:
    detections: A list of `Detection` objects.
  """

  detections: List[Detection]

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _DetectionListProto:
    """Converts this result into a `DetectionList` protobuf message."""
    detection_protos = []
    for item in self.detections:
      detection_protos.append(item.to_pb2())
    return _DetectionListProto(detection=detection_protos)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _DetectionListProto) -> 'DetectionResult':
    """Builds a `DetectionResult` from the given protobuf message."""
    converted = []
    for detection_proto in pb2_obj.detection:
      converted.append(Detection.create_from_pb2(detection_proto))
    return DetectionResult(detections=converted)

  def __eq__(self, other: Any) -> bool:
    """Compares this result with another object.

    Args:
      other: The object to be compared with.

    Returns:
      False if `other` is not a `DetectionResult`; otherwise the result of
      comparing the protobuf forms of both results.
    """
    if isinstance(other, DetectionResult):
      return self.to_pb2().__eq__(other.to_pb2())
    return False
@@ -0,0 +1,89 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Embeddings data class."""
import dataclasses
from typing import Optional, List
import numpy as np
from mediapipe.tasks.cc.components.containers.proto import embeddings_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
# Private shorthands for the protobuf messages of the embeddings proto module.
_FloatEmbeddingProto = embeddings_pb2.FloatEmbedding
_QuantizedEmbeddingProto = embeddings_pb2.QuantizedEmbedding
_EmbeddingProto = embeddings_pb2.Embedding
_EmbeddingResultProto = embeddings_pb2.EmbeddingResult
@dataclasses.dataclass
class Embedding:
  """The embedding produced by one embedder head.

  Attributes:
    embedding: The actual embedding, either floating-point or scalar-quantized.
    head_index: The index of the embedder head that produced this embedding.
      This is useful for multi-head models.
    head_name: The name of the embedder head, which is the corresponding tensor
      metadata name (if any). This is useful for multi-head models.
  """

  embedding: np.ndarray
  head_index: Optional[int] = None
  head_name: Optional[str] = None

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _EmbeddingProto) -> 'Embedding':
    """Builds an `Embedding` from the given protobuf message.

    The quantized values are used when they are non-empty; otherwise the
    floating-point values are used.
    """
    quantized_values = pb2_obj.quantized_embedding.values
    if quantized_values:
      # Going through `bytearray` yields one array element per byte.
      values = np.array(bytearray(quantized_values))
    else:
      values = np.array(pb2_obj.float_embedding.values, dtype=float)
    return Embedding(
        embedding=values,
        head_index=pb2_obj.head_index,
        head_name=pb2_obj.head_name)
@dataclasses.dataclass
class EmbeddingResult:
  """Embedding results for a given embedder model.

  Attributes:
    embeddings: A list of `Embedding` objects.
    timestamp_ms: The optional timestamp (in milliseconds) of the start of the
      chunk of data corresponding to these results. This is only used for
      embedding extraction on time series (e.g. audio embedding). In these use
      cases, the amount of data to process might exceed the maximum size that
      the model can process: to solve this, the input data is split into
      multiple chunks starting at different timestamps.
  """

  embeddings: List[Embedding]
  timestamp_ms: Optional[int] = None

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _EmbeddingResultProto) -> 'EmbeddingResult':
    """Creates a `EmbeddingResult` object from the given protobuf object.

    The timestamp is carried over when the proto has one and is left as None
    otherwise, so results without a timestamp look the same as before.
    """
    # NOTE(review): assumes the `EmbeddingResult` proto declares `timestamp_ms`
    # with explicit presence (so `HasField` is valid), mirroring the field that
    # `ClassificationResult.create_from_pb2` reads - confirm against the
    # .proto definition.
    timestamp_ms = (
        pb2_obj.timestamp_ms if pb2_obj.HasField('timestamp_ms') else None)
    return EmbeddingResult(
        embeddings=[
            Embedding.create_from_pb2(embedding)
            for embedding in pb2_obj.embeddings
        ],
        timestamp_ms=timestamp_ms)
@@ -0,0 +1,77 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keypoint data class."""
import dataclasses
from typing import Any, Optional
from mediapipe.framework.formats import location_data_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
# Private shorthand for the `RelativeKeypoint` message nested in `LocationData`.
_RelativeKeypointProto = location_data_pb2.LocationData.RelativeKeypoint
@dataclasses.dataclass
class NormalizedKeypoint:
  """A keypoint in normalized 2D coordinates.

  Normalized keypoint represents a point in 2D space with x, y coordinates.
  x and y are normalized to [0.0, 1.0] by the image width and height
  respectively.

  Attributes:
    x: The x coordinates of the normalized keypoint.
    y: The y coordinates of the normalized keypoint.
    label: The optional label of the keypoint.
    score: The score of the keypoint.
  """

  x: Optional[float] = None
  y: Optional[float] = None
  label: Optional[str] = None
  score: Optional[float] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _RelativeKeypointProto:
    """Converts this keypoint into a `RelativeKeypoint` protobuf message."""
    proto_fields = dict(
        x=self.x,
        y=self.y,
        keypoint_label=self.label,  # The proto calls the label `keypoint_label`.
        score=self.score,
    )
    return _RelativeKeypointProto(**proto_fields)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls, pb2_obj: _RelativeKeypointProto
  ) -> 'NormalizedKeypoint':
    """Builds a `NormalizedKeypoint` from the given protobuf message."""
    field_values = dict(
        x=pb2_obj.x,
        y=pb2_obj.y,
        label=pb2_obj.keypoint_label,
        score=pb2_obj.score,
    )
    return NormalizedKeypoint(**field_values)

  def __eq__(self, other: Any) -> bool:
    """Compares this keypoint with another object.

    Args:
      other: The object to be compared with.

    Returns:
      False if `other` is not a `NormalizedKeypoint`; otherwise the result of
      comparing the protobuf forms of both keypoints.
    """
    if isinstance(other, NormalizedKeypoint):
      return self.to_pb2().__eq__(other.to_pb2())
    return False
@@ -0,0 +1,122 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Landmark data class."""
import dataclasses
from typing import Optional
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
# Private shorthands for the landmark protobuf messages used in this module.
_LandmarkProto = landmark_pb2.Landmark
_NormalizedLandmarkProto = landmark_pb2.NormalizedLandmark
@dataclasses.dataclass
class Landmark:
  """A point with one to three coordinates.

  Use x for 1D points, (x, y) for 2D points and (x, y, z) for 3D points.

  Attributes:
    x: The x coordinate.
    y: The y coordinate.
    z: The z coordinate.
    visibility: Landmark visibility. Should stay unset if not supported. Float
      score of whether landmark is visible or occluded by other objects.
      Landmark considered as invisible also if it is not present on the screen
      (out of scene bounds). Depending on the model, visibility value is either
      a sigmoid or an argument of sigmoid.
    presence: Landmark presence. Should stay unset if not supported. Float score
      of whether landmark is present on the scene (located within scene bounds).
      Depending on the model, presence value is either a result of sigmoid or an
      argument of sigmoid function to get landmark presence probability.
  """

  x: Optional[float] = None
  y: Optional[float] = None
  z: Optional[float] = None
  visibility: Optional[float] = None
  presence: Optional[float] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _LandmarkProto:
    """Converts this landmark into a `Landmark` protobuf message."""
    proto_fields = dict(
        x=self.x,
        y=self.y,
        z=self.z,
        visibility=self.visibility,
        presence=self.presence,
    )
    return _LandmarkProto(**proto_fields)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _LandmarkProto) -> 'Landmark':
    """Builds a `Landmark` from the given protobuf message."""
    # Positional order follows the field declaration order above.
    return Landmark(
        pb2_obj.x,
        pb2_obj.y,
        pb2_obj.z,
        pb2_obj.visibility,
        pb2_obj.presence,
    )
@dataclasses.dataclass
class NormalizedLandmark:
  """A landmark whose coordinates are normalized.

  All coordinates should be within [0, 1].

  Attributes:
    x: The normalized x coordinate.
    y: The normalized y coordinate.
    z: The normalized z coordinate.
    visibility: Landmark visibility. Should stay unset if not supported. Float
      score of whether landmark is visible or occluded by other objects.
      Landmark considered as invisible also if it is not present on the screen
      (out of scene bounds). Depending on the model, visibility value is either
      a sigmoid or an argument of sigmoid.
    presence: Landmark presence. Should stay unset if not supported. Float score
      of whether landmark is present on the scene (located within scene bounds).
      Depending on the model, presence value is either a result of sigmoid or an
      argument of sigmoid function to get landmark presence probability.
  """

  x: Optional[float] = None
  y: Optional[float] = None
  z: Optional[float] = None
  visibility: Optional[float] = None
  presence: Optional[float] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _NormalizedLandmarkProto:
    """Converts this landmark into a `NormalizedLandmark` protobuf message."""
    proto_fields = dict(
        x=self.x,
        y=self.y,
        z=self.z,
        visibility=self.visibility,
        presence=self.presence,
    )
    return _NormalizedLandmarkProto(**proto_fields)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls, pb2_obj: _NormalizedLandmarkProto) -> 'NormalizedLandmark':
    """Builds a `NormalizedLandmark` from the given protobuf message."""
    # Positional order follows the field declaration order above.
    return NormalizedLandmark(
        pb2_obj.x,
        pb2_obj.y,
        pb2_obj.z,
        pb2_obj.visibility,
        pb2_obj.presence,
    )
@@ -0,0 +1,106 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Landmarks Detection Result data class."""
import dataclasses
from typing import Optional, List
from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
# Private shorthands for the protobuf messages used in this module.
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
_ClassificationProto = classification_pb2.Classification
_ClassificationListProto = classification_pb2.ClassificationList
_LandmarkListProto = landmark_pb2.LandmarkList
_NormalizedLandmarkListProto = landmark_pb2.NormalizedLandmarkList
# Private shorthands for the container classes defined in sibling modules.
_NormalizedRect = rect_module.NormalizedRect
_Category = category_module.Category
_NormalizedLandmark = landmark_module.NormalizedLandmark
_Landmark = landmark_module.Landmark
@dataclasses.dataclass
class LandmarksDetectionResult:
  """Represents the landmarks detection result.

  Attributes:
    landmarks: A list of `NormalizedLandmark` objects.
    categories: A list of `Category` objects.
    world_landmarks: A list of `Landmark` objects.
    rect: A `NormalizedRect` object.
  """

  landmarks: Optional[List[_NormalizedLandmark]]
  categories: Optional[List[_Category]]
  world_landmarks: Optional[List[_Landmark]]
  rect: _NormalizedRect

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _LandmarksDetectionResultProto:
    """Generates a LandmarksDetectionResult protobuf object.

    List fields that are None are written as empty lists.
    """
    landmarks = _NormalizedLandmarkListProto()
    classifications = _ClassificationListProto()
    world_landmarks = _LandmarkListProto()

    # The list fields are declared Optional, so treat None as "no entries"
    # instead of failing while iterating.
    for landmark in self.landmarks or []:
      landmarks.landmark.append(landmark.to_pb2())

    for category in self.categories or []:
      classifications.classification.append(category.to_pb2())

    # World landmarks are read back by `create_from_pb2`, so they have to be
    # written here too; otherwise they are lost on a round trip.
    for world_landmark in self.world_landmarks or []:
      world_landmarks.landmark.append(world_landmark.to_pb2())

    return _LandmarksDetectionResultProto(
        landmarks=landmarks,
        classifications=classifications,
        world_landmarks=world_landmarks,
        rect=self.rect.to_pb2())

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls,
      pb2_obj: _LandmarksDetectionResultProto) -> 'LandmarksDetectionResult':
    """Creates a `LandmarksDetectionResult` object from the given protobuf object."""
    categories = [
        category_module.Category(
            score=classification.score,
            index=classification.index,
            category_name=classification.label,
            display_name=classification.display_name)
        for classification in pb2_obj.classifications.classification
    ]
    landmarks = [
        _NormalizedLandmark.create_from_pb2(landmark)
        for landmark in pb2_obj.landmarks.landmark
    ]
    world_landmarks = [
        _Landmark.create_from_pb2(landmark)
        for landmark in pb2_obj.world_landmarks.landmark
    ]
    return LandmarksDetectionResult(
        landmarks=landmarks,
        categories=categories,
        world_landmarks=world_landmarks,
        rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
@@ -0,0 +1,109 @@
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rect data class."""
import dataclasses
from typing import Any, Optional
from mediapipe.framework.formats import rect_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
# Private shorthand alias for the protobuf message class that `NormalizedRect`
# converts to and from.
_NormalizedRectProto = rect_pb2.NormalizedRect
@dataclasses.dataclass
class Rect:
  """Axis-aligned rectangle for detection results or input regions-of-interest.

  Coordinates are normalized with respect to the image dimensions, so they
  generally fall in [0, 1]; values outside that range are possible when the
  region overlaps the image border. The origin is the top-left corner of the
  image.

  Attributes:
    left: X coordinate of the rectangle's left edge.
    top: Y coordinate of the rectangle's top edge.
    right: X coordinate of the rectangle's right edge.
    bottom: Y coordinate of the rectangle's bottom edge.
  """

  left: float
  top: float
  right: float
  bottom: float
@dataclasses.dataclass
class NormalizedRect:
  """Rotated rectangle expressed in normalized image coordinates.

  The rectangle is described by the location of its center in image
  coordinates, where the point (0.0, 0.0) is the (top, left) corner. Center
  location and size values lie within [0, 1].

  Attributes:
    x_center: Normalized X coordinate of the rectangle, in image coordinates.
    y_center: Normalized Y coordinate of the rectangle, in image coordinates.
    width: Width of the rectangle.
    height: Height of the rectangle.
    rotation: Clockwise rotation angle, in radians.
    rect_id: Optional unique id used to associate different rectangles with
      each other.
  """

  x_center: float
  y_center: float
  width: float
  height: float
  rotation: Optional[float] = 0.0
  rect_id: Optional[int] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _NormalizedRectProto:
    """Builds a NormalizedRect protobuf message from this object's fields."""
    proto_kwargs = {
        'x_center': self.x_center,
        'y_center': self.y_center,
        'width': self.width,
        'height': self.height,
        'rotation': self.rotation,
        'rect_id': self.rect_id,
    }
    return _NormalizedRectProto(**proto_kwargs)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _NormalizedRectProto) -> 'NormalizedRect':
    """Builds a `NormalizedRect` by copying each field of the protobuf object."""
    field_names = (
        'x_center', 'y_center', 'width', 'height', 'rotation', 'rect_id')
    field_values = {name: getattr(pb2_obj, name) for name in field_names}
    return NormalizedRect(**field_values)

  def __eq__(self, other: Any) -> bool:
    """Compares this rectangle with another object.

    Two rectangles are compared through their protobuf representations.

    Args:
      other: The object to compare against.

    Returns:
      False when `other` is not a `NormalizedRect`; otherwise the result of
      comparing the protobuf messages generated from both objects.
    """
    if isinstance(other, NormalizedRect):
      return self.to_pb2().__eq__(other.to_pb2())
    return False