hand
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
# Copyright 2022 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
+190
@@ -0,0 +1,190 @@
|
||||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for face aligner."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import face_aligner
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Rect = rect.Rect
|
||||
_Image = image_module.Image
|
||||
_FaceAligner = face_aligner.FaceAligner
|
||||
_FaceAlignerOptions = face_aligner.FaceAlignerOptions
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_MODEL = 'face_landmarker_v2.task'
|
||||
_LARGE_FACE_IMAGE = 'portrait.jpg'
|
||||
_MODEL_IMAGE_SIZE = 256
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
FILE_CONTENT = 1
|
||||
FILE_NAME = 2
|
||||
|
||||
|
||||
class FaceAlignerTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
|
||||
)
|
||||
)
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _MODEL)
|
||||
)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
with _FaceAligner.create_from_model_path(self.model_path) as aligner:
|
||||
self.assertIsInstance(aligner, _FaceAligner)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||
# Creates with options containing model file successfully.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _FaceAlignerOptions(base_options=base_options)
|
||||
with _FaceAligner.create_from_options(options) as aligner:
|
||||
self.assertIsInstance(aligner, _FaceAligner)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
|
||||
):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite'
|
||||
)
|
||||
options = _FaceAlignerOptions(base_options=base_options)
|
||||
_FaceAligner.create_from_options(options)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||
# Creates with options containing model content successfully.
|
||||
with open(self.model_path, 'rb') as f:
|
||||
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||
options = _FaceAlignerOptions(base_options=base_options)
|
||||
aligner = _FaceAligner.create_from_options(options)
|
||||
self.assertIsInstance(aligner, _FaceAligner)
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
|
||||
(ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
|
||||
)
|
||||
def test_align(self, model_file_type, image_file_name):
|
||||
# Load the test image.
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, image_file_name)
|
||||
)
|
||||
)
|
||||
# Creates aligner.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _FaceAlignerOptions(base_options=base_options)
|
||||
aligner = _FaceAligner.create_from_options(options)
|
||||
|
||||
# Performs face alignment on the input.
|
||||
aligned_image = aligner.align(self.test_image)
|
||||
self.assertIsInstance(aligned_image, _Image)
|
||||
# Closes the aligner explicitly when the aligner is not used in
|
||||
# a context.
|
||||
aligner.close()
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
|
||||
(ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
|
||||
)
|
||||
def test_align_in_context(self, model_file_type, image_file_name):
|
||||
# Load the test image.
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, image_file_name)
|
||||
)
|
||||
)
|
||||
# Creates aligner.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _FaceAlignerOptions(base_options=base_options)
|
||||
with _FaceAligner.create_from_options(options) as aligner:
|
||||
# Performs face alignment on the input.
|
||||
aligned_image = aligner.align(self.test_image)
|
||||
self.assertIsInstance(aligned_image, _Image)
|
||||
self.assertEqual(aligned_image.width, _MODEL_IMAGE_SIZE)
|
||||
self.assertEqual(aligned_image.height, _MODEL_IMAGE_SIZE)
|
||||
|
||||
def test_align_succeeds_with_region_of_interest(self):
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _FaceAlignerOptions(base_options=base_options)
|
||||
with _FaceAligner.create_from_options(options) as aligner:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
|
||||
)
|
||||
)
|
||||
# Region-of-interest around the face.
|
||||
roi = _Rect(left=0.32, top=0.02, right=0.67, bottom=0.32)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
# Performs face alignment on the input.
|
||||
aligned_image = aligner.align(test_image, image_processing_options)
|
||||
self.assertIsInstance(aligned_image, _Image)
|
||||
self.assertEqual(aligned_image.width, _MODEL_IMAGE_SIZE)
|
||||
self.assertEqual(aligned_image.height, _MODEL_IMAGE_SIZE)
|
||||
|
||||
def test_align_succeeds_with_no_face_detected(self):
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _FaceAlignerOptions(base_options=base_options)
|
||||
with _FaceAligner.create_from_options(options) as aligner:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
|
||||
)
|
||||
)
|
||||
# Region-of-interest that doesn't contain a human face.
|
||||
roi = _Rect(left=0.1, top=0.1, right=0.2, bottom=0.2)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
# Performs face alignment on the input.
|
||||
aligned_image = aligner.align(test_image, image_processing_options)
|
||||
self.assertIsNone(aligned_image)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
+523
@@ -0,0 +1,523 @@
|
||||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for face detector."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from google.protobuf import text_format
|
||||
from mediapipe.framework.formats import detection_pb2
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import detections as detections_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import face_detector
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
FaceDetectorResult = detections_module.DetectionResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Category = category_module.Category
|
||||
_BoundingBox = bounding_box_module.BoundingBox
|
||||
_Detection = detections_module.Detection
|
||||
_Image = image_module.Image
|
||||
_FaceDetector = face_detector.FaceDetector
|
||||
_FaceDetectorOptions = face_detector.FaceDetectorOptions
|
||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_SHORT_RANGE_BLAZE_FACE_MODEL = 'face_detection_short_range.tflite'
|
||||
_PORTRAIT_IMAGE = 'portrait.jpg'
|
||||
_PORTRAIT_EXPECTED_DETECTION = 'portrait_expected_detection.pbtxt'
|
||||
_PORTRAIT_ROTATED_IMAGE = 'portrait_rotated.jpg'
|
||||
_PORTRAIT_ROTATED_EXPECTED_DETECTION = (
|
||||
'portrait_rotated_expected_detection.pbtxt'
|
||||
)
|
||||
_CAT_IMAGE = 'cat.jpg'
|
||||
_KEYPOINT_ERROR_THRESHOLD = 1e-2
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
||||
|
||||
def _get_expected_face_detector_result(file_name: str) -> FaceDetectorResult:
|
||||
face_detection_result_file_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, file_name)
|
||||
)
|
||||
with open(face_detection_result_file_path, 'rb') as f:
|
||||
face_detection_proto = detection_pb2.Detection()
|
||||
text_format.Parse(f.read(), face_detection_proto)
|
||||
face_detection = detections_module.Detection.create_from_pb2(
|
||||
face_detection_proto
|
||||
)
|
||||
return FaceDetectorResult(detections=[face_detection])
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
FILE_CONTENT = 1
|
||||
FILE_NAME = 2
|
||||
|
||||
|
||||
class FaceDetectorTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _PORTRAIT_IMAGE)
|
||||
)
|
||||
)
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _SHORT_RANGE_BLAZE_FACE_MODEL)
|
||||
)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
with _FaceDetector.create_from_model_path(self.model_path) as detector:
|
||||
self.assertIsInstance(detector, _FaceDetector)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||
# Creates with options containing model file successfully.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _FaceDetectorOptions(base_options=base_options)
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
self.assertIsInstance(detector, _FaceDetector)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
|
||||
):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite'
|
||||
)
|
||||
options = _FaceDetectorOptions(base_options=base_options)
|
||||
_FaceDetector.create_from_options(options)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||
# Creates with options containing model content successfully.
|
||||
with open(self.model_path, 'rb') as f:
|
||||
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||
options = _FaceDetectorOptions(base_options=base_options)
|
||||
detector = _FaceDetector.create_from_options(options)
|
||||
self.assertIsInstance(detector, _FaceDetector)
|
||||
|
||||
def _expect_keypoints_correct(self, actual_keypoints, expected_keypoints):
|
||||
self.assertLen(actual_keypoints, len(expected_keypoints))
|
||||
for i in range(len(actual_keypoints)):
|
||||
self.assertAlmostEqual(
|
||||
actual_keypoints[i].x,
|
||||
expected_keypoints[i].x,
|
||||
delta=_KEYPOINT_ERROR_THRESHOLD,
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
actual_keypoints[i].y,
|
||||
expected_keypoints[i].y,
|
||||
delta=_KEYPOINT_ERROR_THRESHOLD,
|
||||
)
|
||||
|
||||
def _expect_face_detector_results_correct(
|
||||
self, actual_results, expected_results
|
||||
):
|
||||
self.assertLen(actual_results.detections, len(expected_results.detections))
|
||||
for i in range(len(actual_results.detections)):
|
||||
actual_bbox = actual_results.detections[i].bounding_box
|
||||
expected_bbox = expected_results.detections[i].bounding_box
|
||||
self.assertEqual(actual_bbox, expected_bbox)
|
||||
self.assertNotEmpty(actual_results.detections[i].keypoints)
|
||||
self._expect_keypoints_correct(
|
||||
actual_results.detections[i].keypoints,
|
||||
expected_results.detections[i].keypoints,
|
||||
)
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, _PORTRAIT_EXPECTED_DETECTION),
|
||||
(ModelFileType.FILE_CONTENT, _PORTRAIT_EXPECTED_DETECTION),
|
||||
)
|
||||
def test_detect(self, model_file_type, expected_detection_result_file):
|
||||
# Creates detector.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _FaceDetectorOptions(base_options=base_options)
|
||||
detector = _FaceDetector.create_from_options(options)
|
||||
|
||||
# Performs face detection on the input.
|
||||
detection_result = detector.detect(self.test_image)
|
||||
# Comparing results.
|
||||
expected_detection_result = _get_expected_face_detector_result(
|
||||
expected_detection_result_file
|
||||
)
|
||||
self._expect_face_detector_results_correct(
|
||||
detection_result, expected_detection_result
|
||||
)
|
||||
# Closes the detector explicitly when the detector is not used in
|
||||
# a context.
|
||||
detector.close()
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, _PORTRAIT_EXPECTED_DETECTION),
|
||||
(ModelFileType.FILE_CONTENT, _PORTRAIT_EXPECTED_DETECTION),
|
||||
)
|
||||
def test_detect_in_context(
|
||||
self, model_file_type, expected_detection_result_file
|
||||
):
|
||||
# Creates detector.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _FaceDetectorOptions(base_options=base_options)
|
||||
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
# Performs face detection on the input.
|
||||
detection_result = detector.detect(self.test_image)
|
||||
# Comparing results.
|
||||
expected_detection_result = _get_expected_face_detector_result(
|
||||
expected_detection_result_file
|
||||
)
|
||||
self._expect_face_detector_results_correct(
|
||||
detection_result, expected_detection_result
|
||||
)
|
||||
|
||||
def test_detect_succeeds_with_rotated_image(self):
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _FaceDetectorOptions(base_options=base_options)
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _PORTRAIT_ROTATED_IMAGE)
|
||||
)
|
||||
)
|
||||
# Rotated input image.
|
||||
image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
|
||||
# Performs face detection on the input.
|
||||
detection_result = detector.detect(test_image, image_processing_options)
|
||||
# Comparing results.
|
||||
expected_detection_result = _get_expected_face_detector_result(
|
||||
_PORTRAIT_ROTATED_EXPECTED_DETECTION
|
||||
)
|
||||
self._expect_face_detector_results_correct(
|
||||
detection_result, expected_detection_result
|
||||
)
|
||||
|
||||
def test_empty_detection_outputs(self):
|
||||
# Load a test image with no faces.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
|
||||
)
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path)
|
||||
)
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
# Performs face detection on the input.
|
||||
detection_result = detector.detect(test_image)
|
||||
self.assertEmpty(detection_result.detections)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback must be provided'
|
||||
):
|
||||
with _FaceDetector.create_from_options(options) as unused_detector:
|
||||
pass
|
||||
|
||||
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||
def test_illegal_result_callback(self, running_mode):
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=running_mode,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback should not be provided'
|
||||
):
|
||||
with _FaceDetector.create_from_options(options) as unused_detector:
|
||||
pass
|
||||
|
||||
def test_calling_detect_for_video_in_image_mode(self):
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
detector.detect_for_video(self.test_image, 0)
|
||||
|
||||
def test_calling_detect_async_in_image_mode(self):
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
detector.detect_async(self.test_image, 0)
|
||||
|
||||
def test_calling_detect_in_video_mode(self):
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
detector.detect(self.test_image)
|
||||
|
||||
def test_calling_detect_async_in_video_mode(self):
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
detector.detect_async(self.test_image, 0)
|
||||
|
||||
def test_detect_for_video_with_out_of_order_timestamp(self):
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
unused_result = detector.detect_for_video(self.test_image, 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
detector.detect_for_video(self.test_image, 0)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
_PORTRAIT_IMAGE,
|
||||
0,
|
||||
_get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
_PORTRAIT_IMAGE,
|
||||
0,
|
||||
_get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
_PORTRAIT_ROTATED_IMAGE,
|
||||
-90,
|
||||
_get_expected_face_detector_result(
|
||||
_PORTRAIT_ROTATED_EXPECTED_DETECTION
|
||||
),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
_PORTRAIT_ROTATED_IMAGE,
|
||||
-90,
|
||||
_get_expected_face_detector_result(
|
||||
_PORTRAIT_ROTATED_EXPECTED_DETECTION
|
||||
),
|
||||
),
|
||||
(ModelFileType.FILE_NAME, _CAT_IMAGE, 0, FaceDetectorResult([])),
|
||||
(ModelFileType.FILE_CONTENT, _CAT_IMAGE, 0, FaceDetectorResult([])),
|
||||
)
|
||||
def test_detect_for_video(
|
||||
self,
|
||||
model_file_type,
|
||||
test_image_file_name,
|
||||
rotation_degrees,
|
||||
expected_detection_result,
|
||||
):
|
||||
# Creates detector.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=base_options, running_mode=_RUNNING_MODE.VIDEO
|
||||
)
|
||||
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
for timestamp in range(0, 300, 30):
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, test_image_file_name)
|
||||
)
|
||||
)
|
||||
# Set the image processing options.
|
||||
image_processing_options = _ImageProcessingOptions(
|
||||
rotation_degrees=rotation_degrees
|
||||
)
|
||||
# Performs face detection on the input.
|
||||
detection_result = detector.detect_for_video(
|
||||
test_image, timestamp, image_processing_options
|
||||
)
|
||||
# Comparing results.
|
||||
self._expect_face_detector_results_correct(
|
||||
detection_result, expected_detection_result
|
||||
)
|
||||
|
||||
def test_calling_detect_in_live_stream_mode(self):
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
detector.detect(self.test_image)
|
||||
|
||||
def test_calling_detect_for_video_in_live_stream_mode(self):
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
detector.detect_for_video(self.test_image, 0)
|
||||
|
||||
def test_detect_async_calls_with_illegal_timestamp(self):
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
detector.detect_async(self.test_image, 100)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
detector.detect_async(self.test_image, 0)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
_PORTRAIT_IMAGE,
|
||||
0,
|
||||
_get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
_PORTRAIT_IMAGE,
|
||||
0,
|
||||
_get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
_PORTRAIT_ROTATED_IMAGE,
|
||||
-90,
|
||||
_get_expected_face_detector_result(
|
||||
_PORTRAIT_ROTATED_EXPECTED_DETECTION
|
||||
),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
_PORTRAIT_ROTATED_IMAGE,
|
||||
-90,
|
||||
_get_expected_face_detector_result(
|
||||
_PORTRAIT_ROTATED_EXPECTED_DETECTION
|
||||
),
|
||||
),
|
||||
(ModelFileType.FILE_NAME, _CAT_IMAGE, 0, FaceDetectorResult([])),
|
||||
(ModelFileType.FILE_CONTENT, _CAT_IMAGE, 0, FaceDetectorResult([])),
|
||||
)
|
||||
def test_detect_async_calls(
|
||||
self,
|
||||
model_file_type,
|
||||
test_image_file_name,
|
||||
rotation_degrees,
|
||||
expected_detection_result,
|
||||
):
|
||||
# Creates detector.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(
|
||||
result: FaceDetectorResult,
|
||||
unused_output_image: _Image,
|
||||
timestamp_ms: int,
|
||||
):
|
||||
self._expect_face_detector_results_correct(
|
||||
result, expected_detection_result
|
||||
)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _FaceDetectorOptions(
|
||||
base_options=base_options,
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=check_result,
|
||||
)
|
||||
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, test_image_file_name)
|
||||
)
|
||||
)
|
||||
|
||||
with _FaceDetector.create_from_options(options) as detector:
|
||||
for timestamp in range(0, 300, 30):
|
||||
# Set the image processing options.
|
||||
image_processing_options = _ImageProcessingOptions(
|
||||
rotation_degrees=rotation_degrees
|
||||
)
|
||||
detector.detect_async(test_image, timestamp, image_processing_options)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
+565
@@ -0,0 +1,565 @@
|
||||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for face landmarker."""
|
||||
|
||||
import enum
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from google.protobuf import text_format
|
||||
from mediapipe.framework.formats import classification_pb2
|
||||
from mediapipe.framework.formats import landmark_pb2
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import face_landmarker
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
|
||||
FaceLandmarkerResult = face_landmarker.FaceLandmarkerResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Category = category_module.Category
|
||||
_Rect = rect_module.Rect
|
||||
_Landmark = landmark_module.Landmark
|
||||
_NormalizedLandmark = landmark_module.NormalizedLandmark
|
||||
_Image = image_module.Image
|
||||
_FaceLandmarker = face_landmarker.FaceLandmarker
|
||||
_FaceLandmarkerOptions = face_landmarker.FaceLandmarkerOptions
|
||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_FACE_LANDMARKER_BUNDLE_ASSET_FILE = 'face_landmarker_v2.task'
|
||||
_PORTRAIT_IMAGE = 'portrait.jpg'
|
||||
_CAT_IMAGE = 'cat.jpg'
|
||||
_PORTRAIT_EXPECTED_FACE_LANDMARKS = 'portrait_expected_face_landmarks.pbtxt'
|
||||
_PORTRAIT_EXPECTED_BLENDSHAPES = 'portrait_expected_blendshapes.pbtxt'
|
||||
_LANDMARKS_MARGIN = 0.03
|
||||
_BLENDSHAPES_MARGIN = 0.13
|
||||
_FACIAL_TRANSFORMATION_MATRIX_MARGIN = 0.02
|
||||
|
||||
|
||||
def _get_expected_face_landmarks(file_path: str):
|
||||
proto_file_path = test_utils.get_test_data_path(file_path)
|
||||
face_landmarks_results = []
|
||||
with open(proto_file_path, 'rb') as f:
|
||||
proto = landmark_pb2.NormalizedLandmarkList()
|
||||
text_format.Parse(f.read(), proto)
|
||||
face_landmarks = []
|
||||
for landmark in proto.landmark:
|
||||
face_landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))
|
||||
face_landmarks_results.append(face_landmarks)
|
||||
return face_landmarks_results
|
||||
|
||||
|
||||
def _get_expected_face_blendshapes(file_path: str):
|
||||
proto_file_path = test_utils.get_test_data_path(file_path)
|
||||
face_blendshapes_results = []
|
||||
with open(proto_file_path, 'rb') as f:
|
||||
proto = classification_pb2.ClassificationList()
|
||||
text_format.Parse(f.read(), proto)
|
||||
face_blendshapes_categories = []
|
||||
face_blendshapes_classifications = classification_pb2.ClassificationList()
|
||||
face_blendshapes_classifications.MergeFrom(proto)
|
||||
for face_blendshapes in face_blendshapes_classifications.classification:
|
||||
face_blendshapes_categories.append(
|
||||
category_module.Category(
|
||||
index=face_blendshapes.index,
|
||||
score=face_blendshapes.score,
|
||||
display_name=face_blendshapes.display_name,
|
||||
category_name=face_blendshapes.label,
|
||||
)
|
||||
)
|
||||
face_blendshapes_results.append(face_blendshapes_categories)
|
||||
return face_blendshapes_results
|
||||
|
||||
|
||||
def _get_expected_facial_transformation_matrixes():
|
||||
matrix = np.array([
|
||||
[0.9995292, -0.01294756, 0.038823195, -0.3691378],
|
||||
[0.0072318087, 0.9937692, -0.1101321, 22.75809],
|
||||
[-0.03715533, 0.11070588, 0.99315894, -65.765925],
|
||||
[0, 0, 0, 1],
|
||||
])
|
||||
facial_transformation_matrixes_results = []
|
||||
facial_transformation_matrixes_results.append(matrix)
|
||||
return facial_transformation_matrixes_results
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
FILE_CONTENT = 1
|
||||
FILE_NAME = 2
|
||||
|
||||
|
||||
class FaceLandmarkerTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_PORTRAIT_IMAGE)
|
||||
)
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
_FACE_LANDMARKER_BUNDLE_ASSET_FILE
|
||||
)
|
||||
|
||||
def _expect_landmarks_correct(self, actual_landmarks, expected_landmarks):
|
||||
# Expects to have the same number of faces detected.
|
||||
self.assertLen(actual_landmarks, len(expected_landmarks))
|
||||
|
||||
for i, _ in enumerate(actual_landmarks):
|
||||
for j, elem in enumerate(actual_landmarks[i]):
|
||||
self.assertAlmostEqual(
|
||||
elem.x, expected_landmarks[i][j].x, delta=_LANDMARKS_MARGIN
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
elem.y, expected_landmarks[i][j].y, delta=_LANDMARKS_MARGIN
|
||||
)
|
||||
|
||||
def _expect_blendshapes_correct(
|
||||
self, actual_blendshapes, expected_blendshapes
|
||||
):
|
||||
# Expects to have the same number of blendshapes.
|
||||
self.assertLen(actual_blendshapes, len(expected_blendshapes))
|
||||
|
||||
for i, _ in enumerate(actual_blendshapes):
|
||||
for j, elem in enumerate(actual_blendshapes[i]):
|
||||
self.assertEqual(elem.index, expected_blendshapes[i][j].index)
|
||||
self.assertAlmostEqual(
|
||||
elem.score,
|
||||
expected_blendshapes[i][j].score,
|
||||
delta=_BLENDSHAPES_MARGIN,
|
||||
)
|
||||
|
||||
def _expect_facial_transformation_matrixes_correct(
|
||||
self, actual_matrix_list, expected_matrix_list
|
||||
):
|
||||
self.assertLen(actual_matrix_list, len(expected_matrix_list))
|
||||
|
||||
for i, elem in enumerate(actual_matrix_list):
|
||||
self.assertEqual(elem.shape[0], expected_matrix_list[i].shape[0])
|
||||
self.assertEqual(elem.shape[1], expected_matrix_list[i].shape[1])
|
||||
self.assertSequenceAlmostEqual(
|
||||
elem.flatten(),
|
||||
expected_matrix_list[i].flatten(),
|
||||
delta=_FACIAL_TRANSFORMATION_MATRIX_MARGIN,
|
||||
)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
with _FaceLandmarker.create_from_model_path(self.model_path) as landmarker:
|
||||
self.assertIsInstance(landmarker, _FaceLandmarker)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||
# Creates with options containing model file successfully.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _FaceLandmarkerOptions(base_options=base_options)
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
self.assertIsInstance(landmarker, _FaceLandmarker)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
# Invalid empty model path.
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
|
||||
):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite'
|
||||
)
|
||||
options = _FaceLandmarkerOptions(base_options=base_options)
|
||||
_FaceLandmarker.create_from_options(options)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||
# Creates with options containing model content successfully.
|
||||
with open(self.model_path, 'rb') as f:
|
||||
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||
options = _FaceLandmarkerOptions(base_options=base_options)
|
||||
landmarker = _FaceLandmarker.create_from_options(options)
|
||||
self.assertIsInstance(landmarker, _FaceLandmarker)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
_FACE_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS),
|
||||
None,
|
||||
None,
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
_FACE_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS),
|
||||
None,
|
||||
None,
|
||||
),
|
||||
)
|
||||
def test_detect(
|
||||
self,
|
||||
model_file_type,
|
||||
model_name,
|
||||
expected_face_landmarks,
|
||||
expected_face_blendshapes,
|
||||
expected_facial_transformation_matrixes,
|
||||
):
|
||||
# Creates face landmarker.
|
||||
model_path = test_utils.get_test_data_path(model_name)
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
output_face_blendshapes=True if expected_face_blendshapes else False,
|
||||
output_facial_transformation_matrixes=True
|
||||
if expected_facial_transformation_matrixes
|
||||
else False,
|
||||
)
|
||||
landmarker = _FaceLandmarker.create_from_options(options)
|
||||
|
||||
# Performs face landmarks detection on the input.
|
||||
detection_result = landmarker.detect(self.test_image)
|
||||
# Comparing results.
|
||||
if expected_face_landmarks is not None:
|
||||
self._expect_landmarks_correct(
|
||||
detection_result.face_landmarks, expected_face_landmarks
|
||||
)
|
||||
if expected_face_blendshapes is not None:
|
||||
self._expect_blendshapes_correct(
|
||||
detection_result.face_blendshapes, expected_face_blendshapes
|
||||
)
|
||||
if expected_facial_transformation_matrixes is not None:
|
||||
self._expect_facial_transformation_matrixes_correct(
|
||||
detection_result.facial_transformation_matrixes,
|
||||
expected_facial_transformation_matrixes,
|
||||
)
|
||||
|
||||
# Closes the face landmarker explicitly when the face landmarker is not used
|
||||
# in a context.
|
||||
landmarker.close()
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
_FACE_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS),
|
||||
None,
|
||||
None,
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
_FACE_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS),
|
||||
None,
|
||||
None,
|
||||
),
|
||||
)
|
||||
def test_detect_in_context(
|
||||
self,
|
||||
model_file_type,
|
||||
model_name,
|
||||
expected_face_landmarks,
|
||||
expected_face_blendshapes,
|
||||
expected_facial_transformation_matrixes,
|
||||
):
|
||||
# Creates face landmarker.
|
||||
model_path = test_utils.get_test_data_path(model_name)
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
output_face_blendshapes=True if expected_face_blendshapes else False,
|
||||
output_facial_transformation_matrixes=True
|
||||
if expected_facial_transformation_matrixes
|
||||
else False,
|
||||
)
|
||||
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
# Performs face landmarks detection on the input.
|
||||
detection_result = landmarker.detect(self.test_image)
|
||||
# Comparing results.
|
||||
if expected_face_landmarks is not None:
|
||||
self._expect_landmarks_correct(
|
||||
detection_result.face_landmarks, expected_face_landmarks
|
||||
)
|
||||
if expected_face_blendshapes is not None:
|
||||
self._expect_blendshapes_correct(
|
||||
detection_result.face_blendshapes, expected_face_blendshapes
|
||||
)
|
||||
if expected_facial_transformation_matrixes is not None:
|
||||
self._expect_facial_transformation_matrixes_correct(
|
||||
detection_result.facial_transformation_matrixes,
|
||||
expected_facial_transformation_matrixes,
|
||||
)
|
||||
|
||||
def test_empty_detection_outputs(self):
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path)
|
||||
)
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
# Load the image with no faces.
|
||||
no_faces_test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_CAT_IMAGE)
|
||||
)
|
||||
# Performs face landmarks detection on the input.
|
||||
detection_result = landmarker.detect(no_faces_test_image)
|
||||
self.assertEmpty(detection_result.face_landmarks)
|
||||
self.assertEmpty(detection_result.face_blendshapes)
|
||||
self.assertEmpty(detection_result.facial_transformation_matrixes)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback must be provided'
|
||||
):
|
||||
with _FaceLandmarker.create_from_options(options) as unused_landmarker:
|
||||
pass
|
||||
|
||||
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||
def test_illegal_result_callback(self, running_mode):
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=running_mode,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback should not be provided'
|
||||
):
|
||||
with _FaceLandmarker.create_from_options(options) as unused_landmarker:
|
||||
pass
|
||||
|
||||
def test_calling_detect_for_video_in_image_mode(self):
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
landmarker.detect_for_video(self.test_image, 0)
|
||||
|
||||
def test_calling_detect_async_in_image_mode(self):
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
landmarker.detect_async(self.test_image, 0)
|
||||
|
||||
def test_calling_detect_in_video_mode(self):
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
landmarker.detect(self.test_image)
|
||||
|
||||
def test_calling_detect_async_in_video_mode(self):
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
landmarker.detect_async(self.test_image, 0)
|
||||
|
||||
def test_detect_for_video_with_out_of_order_timestamp(self):
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
unused_result = landmarker.detect_for_video(self.test_image, 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
landmarker.detect_for_video(self.test_image, 0)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
_FACE_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS),
|
||||
None,
|
||||
None,
|
||||
),
|
||||
)
|
||||
def test_detect_for_video(
|
||||
self,
|
||||
model_name,
|
||||
expected_face_landmarks,
|
||||
expected_face_blendshapes,
|
||||
expected_facial_transformation_matrixes,
|
||||
):
|
||||
# Creates face landmarker.
|
||||
model_path = test_utils.get_test_data_path(model_name)
|
||||
base_options = _BaseOptions(model_asset_path=model_path)
|
||||
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
output_face_blendshapes=True if expected_face_blendshapes else False,
|
||||
output_facial_transformation_matrixes=True
|
||||
if expected_facial_transformation_matrixes
|
||||
else False,
|
||||
)
|
||||
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
for timestamp in range(0, 300, 30):
|
||||
# Performs face landmarks detection on the input.
|
||||
detection_result = landmarker.detect_for_video(
|
||||
self.test_image, timestamp
|
||||
)
|
||||
# Comparing results.
|
||||
if expected_face_landmarks is not None:
|
||||
self._expect_landmarks_correct(
|
||||
detection_result.face_landmarks, expected_face_landmarks
|
||||
)
|
||||
if expected_face_blendshapes is not None:
|
||||
self._expect_blendshapes_correct(
|
||||
detection_result.face_blendshapes, expected_face_blendshapes
|
||||
)
|
||||
if expected_facial_transformation_matrixes is not None:
|
||||
self._expect_facial_transformation_matrixes_correct(
|
||||
detection_result.facial_transformation_matrixes,
|
||||
expected_facial_transformation_matrixes,
|
||||
)
|
||||
|
||||
def test_calling_detect_in_live_stream_mode(self):
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
landmarker.detect(self.test_image)
|
||||
|
||||
def test_calling_detect_for_video_in_live_stream_mode(self):
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
landmarker.detect_for_video(self.test_image, 0)
|
||||
|
||||
def test_detect_async_calls_with_illegal_timestamp(self):
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
landmarker.detect_async(self.test_image, 100)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
landmarker.detect_async(self.test_image, 0)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
_PORTRAIT_IMAGE,
|
||||
_FACE_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
_get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS),
|
||||
None,
|
||||
None,
|
||||
),
|
||||
)
|
||||
def test_detect_async_calls(
|
||||
self,
|
||||
image_path,
|
||||
model_name,
|
||||
expected_face_landmarks,
|
||||
expected_face_blendshapes,
|
||||
expected_facial_transformation_matrixes,
|
||||
):
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(image_path)
|
||||
)
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(
|
||||
result: FaceLandmarkerResult, output_image: _Image, timestamp_ms: int
|
||||
):
|
||||
# Comparing results.
|
||||
if expected_face_landmarks is not None:
|
||||
self._expect_landmarks_correct(
|
||||
result.face_landmarks, expected_face_landmarks
|
||||
)
|
||||
if expected_face_blendshapes is not None:
|
||||
self._expect_blendshapes_correct(
|
||||
result.face_blendshapes, expected_face_blendshapes
|
||||
)
|
||||
if expected_facial_transformation_matrixes is not None:
|
||||
self._expect_facial_transformation_matrixes_correct(
|
||||
result.facial_transformation_matrixes,
|
||||
expected_facial_transformation_matrixes,
|
||||
)
|
||||
self.assertTrue(
|
||||
np.array_equal(output_image.numpy_view(), test_image.numpy_view())
|
||||
)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
model_path = test_utils.get_test_data_path(model_name)
|
||||
options = _FaceLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
output_face_blendshapes=True if expected_face_blendshapes else False,
|
||||
output_facial_transformation_matrixes=True
|
||||
if expected_facial_transformation_matrixes
|
||||
else False,
|
||||
result_callback=check_result,
|
||||
)
|
||||
with _FaceLandmarker.create_from_options(options) as landmarker:
|
||||
for timestamp in range(0, 300, 30):
|
||||
landmarker.detect_async(test_image, timestamp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
+191
@@ -0,0 +1,191 @@
|
||||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for face stylizer."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import face_stylizer
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
|
||||
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Rect = rect.Rect
|
||||
_Image = image_module.Image
|
||||
_FaceStylizer = face_stylizer.FaceStylizer
|
||||
_FaceStylizerOptions = face_stylizer.FaceStylizerOptions
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_MODEL = 'face_stylizer_color_ink.task'
|
||||
_LARGE_FACE_IMAGE = 'portrait.jpg'
|
||||
_MODEL_IMAGE_SIZE = 256
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
FILE_CONTENT = 1
|
||||
FILE_NAME = 2
|
||||
|
||||
|
||||
class FaceStylizerTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
|
||||
)
|
||||
)
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _MODEL)
|
||||
)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
with _FaceStylizer.create_from_model_path(self.model_path) as stylizer:
|
||||
self.assertIsInstance(stylizer, _FaceStylizer)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||
# Creates with options containing model file successfully.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _FaceStylizerOptions(base_options=base_options)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
self.assertIsInstance(stylizer, _FaceStylizer)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
|
||||
):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite'
|
||||
)
|
||||
options = _FaceStylizerOptions(base_options=base_options)
|
||||
_FaceStylizer.create_from_options(options)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||
# Creates with options containing model content successfully.
|
||||
with open(self.model_path, 'rb') as f:
|
||||
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||
options = _FaceStylizerOptions(base_options=base_options)
|
||||
stylizer = _FaceStylizer.create_from_options(options)
|
||||
self.assertIsInstance(stylizer, _FaceStylizer)
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
|
||||
(ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
|
||||
)
|
||||
def test_stylize(self, model_file_type, image_file_name):
|
||||
# Load the test image.
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, image_file_name)
|
||||
)
|
||||
)
|
||||
# Creates stylizer.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _FaceStylizerOptions(base_options=base_options)
|
||||
stylizer = _FaceStylizer.create_from_options(options)
|
||||
|
||||
# Performs face stylization on the input.
|
||||
stylized_image = stylizer.stylize(self.test_image)
|
||||
self.assertIsInstance(stylized_image, _Image)
|
||||
# Closes the stylizer explicitly when the stylizer is not used in
|
||||
# a context.
|
||||
stylizer.close()
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, _LARGE_FACE_IMAGE),
|
||||
(ModelFileType.FILE_CONTENT, _LARGE_FACE_IMAGE),
|
||||
)
|
||||
def test_stylize_in_context(self, model_file_type, image_file_name):
|
||||
# Load the test image.
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, image_file_name)
|
||||
)
|
||||
)
|
||||
# Creates stylizer.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _FaceStylizerOptions(base_options=base_options)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
# Performs face stylization on the input.
|
||||
stylized_image = stylizer.stylize(self.test_image)
|
||||
self.assertIsInstance(stylized_image, _Image)
|
||||
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
|
||||
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
|
||||
|
||||
def test_stylize_succeeds_with_region_of_interest(self):
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _FaceStylizerOptions(base_options=base_options)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
|
||||
)
|
||||
)
|
||||
# Region-of-interest around the face.
|
||||
roi = _Rect(left=0.32, top=0.02, right=0.67, bottom=0.32)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
# Performs face stylization on the input.
|
||||
stylized_image = stylizer.stylize(test_image, image_processing_options)
|
||||
self.assertIsInstance(stylized_image, _Image)
|
||||
self.assertEqual(stylized_image.width, _MODEL_IMAGE_SIZE)
|
||||
self.assertEqual(stylized_image.height, _MODEL_IMAGE_SIZE)
|
||||
|
||||
def test_stylize_succeeds_with_no_face_detected(self):
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _FaceStylizerOptions(base_options=base_options)
|
||||
with _FaceStylizer.create_from_options(options) as stylizer:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _LARGE_FACE_IMAGE)
|
||||
)
|
||||
)
|
||||
# Region-of-interest that doesn't contain a human face.
|
||||
roi = _Rect(left=0.1, top=0.1, right=0.2, bottom=0.2)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
# Performs face stylization on the input.
|
||||
stylized_image = stylizer.stylize(test_image, image_processing_options)
|
||||
self.assertIsNone(stylized_image)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
+437
@@ -0,0 +1,437 @@
|
||||
# Copyright 2022 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for hand landmarker."""
|
||||
|
||||
import enum
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from google.protobuf import text_format
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
|
||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
|
||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import hand_landmarker
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
_LandmarksDetectionResultProto = (
|
||||
landmarks_detection_result_pb2.LandmarksDetectionResult)
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Rect = rect_module.Rect
|
||||
_Landmark = landmark_module.Landmark
|
||||
_NormalizedLandmark = landmark_module.NormalizedLandmark
|
||||
_LandmarksDetectionResult = (
|
||||
landmark_detection_result_module.LandmarksDetectionResult)
|
||||
_Image = image_module.Image
|
||||
_HandLandmarker = hand_landmarker.HandLandmarker
|
||||
_HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
|
||||
_HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
|
||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_HAND_LANDMARKER_BUNDLE_ASSET_FILE = 'hand_landmarker.task'
|
||||
_NO_HANDS_IMAGE = 'cats_and_dogs.jpg'
|
||||
_TWO_HANDS_IMAGE = 'right_hands.jpg'
|
||||
_THUMB_UP_IMAGE = 'thumb_up.jpg'
|
||||
_THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt'
|
||||
_POINTING_UP_IMAGE = 'pointing_up.jpg'
|
||||
_POINTING_UP_LANDMARKS = 'pointing_up_landmarks.pbtxt'
|
||||
_POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg'
|
||||
_POINTING_UP_ROTATED_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt'
|
||||
_LANDMARKS_MARGIN = 0.03
|
||||
_HANDEDNESS_MARGIN = 0.05
|
||||
|
||||
|
||||
def _get_expected_hand_landmarker_result(
    file_path: str) -> _HandLandmarkerResult:
  """Loads a golden landmarks pbtxt file and converts it to a task result.

  Args:
    file_path: Name of a text-format LandmarksDetectionResult proto file
      under the test data directory.

  Returns:
    A single-hand _HandLandmarkerResult built from the golden proto.
  """
  landmarks_detection_result_file_path = test_utils.get_test_data_path(
      file_path)
  with open(landmarks_detection_result_file_path, 'rb') as f:
    landmarks_detection_result_proto = _LandmarksDetectionResultProto()
    # Use this if a .pb file is available.
    # landmarks_detection_result_proto.ParseFromString(f.read())
    text_format.Parse(f.read(), landmarks_detection_result_proto)
    landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
        landmarks_detection_result_proto)
  # The golden file describes one hand; wrap each field in a one-element list
  # to match the per-hand layout of _HandLandmarkerResult.
  return _HandLandmarkerResult(
      handedness=[landmarks_detection_result.categories],
      hand_landmarks=[landmarks_detection_result.landmarks],
      hand_world_landmarks=[landmarks_detection_result.world_landmarks])
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
  """How the model asset is supplied to the task in a parameterized case."""
  FILE_CONTENT = 1
  FILE_NAME = 2
|
||||
|
||||
|
||||
class HandLandmarkerTest(parameterized.TestCase):
  """Unit tests for the MediaPipe HandLandmarker task."""

  def setUp(self):
    super().setUp()
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(_THUMB_UP_IMAGE))
    self.model_path = test_utils.get_test_data_path(
        _HAND_LANDMARKER_BUNDLE_ASSET_FILE)

  def _expect_hand_landmarks_correct(
      self, actual_landmarks, expected_landmarks, margin
  ):
    """Checks per-hand landmark x/y coordinates within the given margin."""
    # Expects to have the same number of hands detected.
    self.assertLen(actual_landmarks, len(expected_landmarks))

    for i, _ in enumerate(actual_landmarks):
      for j, elem in enumerate(actual_landmarks[i]):
        self.assertAlmostEqual(elem.x, expected_landmarks[i][j].x, delta=margin)
        self.assertAlmostEqual(elem.y, expected_landmarks[i][j].y, delta=margin)

  def _expect_handedness_correct(
      self, actual_handedness, expected_handedness, margin
  ):
    """Checks the top handedness category of the first hand."""
    # Actual top handedness matches expected top handedness.
    actual_top_handedness = actual_handedness[0][0]
    expected_top_handedness = expected_handedness[0][0]
    self.assertEqual(actual_top_handedness.index, expected_top_handedness.index)
    self.assertEqual(actual_top_handedness.category_name,
                     expected_top_handedness.category_name)
    self.assertAlmostEqual(
        actual_top_handedness.score, expected_top_handedness.score, delta=margin
    )

  def _expect_hand_landmarker_results_correct(
      self,
      actual_result: _HandLandmarkerResult,
      expected_result: _HandLandmarkerResult,
  ):
    """Checks landmarks and handedness of a full detection result."""
    self._expect_hand_landmarks_correct(
        actual_result.hand_landmarks,
        expected_result.hand_landmarks,
        _LANDMARKS_MARGIN,
    )
    self._expect_handedness_correct(
        actual_result.handedness, expected_result.handedness, _HANDEDNESS_MARGIN
    )

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _HandLandmarker.create_from_model_path(self.model_path) as landmarker:
      self.assertIsInstance(landmarker, _HandLandmarker)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _HandLandmarkerOptions(base_options=base_options)
    with _HandLandmarker.create_from_options(options) as landmarker:
      self.assertIsInstance(landmarker, _HandLandmarker)

  def test_create_from_options_fails_with_invalid_model_path(self):
    # Invalid empty model path.
    with self.assertRaisesRegex(
        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
      base_options = _BaseOptions(
          model_asset_path='/path/to/invalid/model.tflite')
      options = _HandLandmarkerOptions(base_options=base_options)
      _HandLandmarker.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _HandLandmarkerOptions(base_options=base_options)
      landmarker = _HandLandmarker.create_from_options(options)
      self.assertIsInstance(landmarker, _HandLandmarker)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
      (ModelFileType.FILE_CONTENT,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)))
  def test_detect(self, model_file_type, expected_detection_result):
    # Creates hand landmarker.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    options = _HandLandmarkerOptions(base_options=base_options)
    landmarker = _HandLandmarker.create_from_options(options)

    # Performs hand landmarks detection on the input.
    detection_result = landmarker.detect(self.test_image)
    # Comparing results.
    self._expect_hand_landmarker_results_correct(
        detection_result, expected_detection_result
    )
    # Closes the hand landmarker explicitly when the hand landmarker is not
    # used in a context.
    landmarker.close()

  @parameterized.parameters(
      (ModelFileType.FILE_NAME,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
      (ModelFileType.FILE_CONTENT,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)))
  def test_detect_in_context(self, model_file_type, expected_detection_result):
    # Creates hand landmarker.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    options = _HandLandmarkerOptions(base_options=base_options)
    with _HandLandmarker.create_from_options(options) as landmarker:
      # Performs hand landmarks detection on the input.
      detection_result = landmarker.detect(self.test_image)
      # Comparing results.
      self._expect_hand_landmarker_results_correct(
          detection_result, expected_detection_result
      )

  def test_detect_succeeds_with_num_hands(self):
    # Creates hand landmarker.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _HandLandmarkerOptions(base_options=base_options, num_hands=2)
    with _HandLandmarker.create_from_options(options) as landmarker:
      # Load the two hands image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(_TWO_HANDS_IMAGE))
      # Performs hand landmarks detection on the input.
      detection_result = landmarker.detect(test_image)
      # Comparing results.
      self.assertLen(detection_result.handedness, 2)

  def test_detect_succeeds_with_rotation(self):
    # Creates hand landmarker.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _HandLandmarkerOptions(base_options=base_options)
    with _HandLandmarker.create_from_options(options) as landmarker:
      # Load the pointing up rotated image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(_POINTING_UP_ROTATED_IMAGE))
      # Set rotation parameters using ImageProcessingOptions.
      image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
      # Performs hand landmarks detection on the input.
      detection_result = landmarker.detect(test_image, image_processing_options)
      expected_detection_result = _get_expected_hand_landmarker_result(
          _POINTING_UP_ROTATED_LANDMARKS)
      # Comparing results.
      self._expect_hand_landmarker_results_correct(
          detection_result, expected_detection_result
      )

  def test_detect_fails_with_region_of_interest(self):
    # Creates hand landmarker.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _HandLandmarkerOptions(base_options=base_options)
    with self.assertRaisesRegex(
        ValueError, "This task doesn't support region-of-interest."):
      with _HandLandmarker.create_from_options(options) as landmarker:
        # Set the `region_of_interest` parameter using `ImageProcessingOptions`.
        image_processing_options = _ImageProcessingOptions(
            region_of_interest=_Rect(0, 0, 1, 1))
        # Attempt to perform hand landmarks detection on the cropped input.
        landmarker.detect(self.test_image, image_processing_options)

  def test_empty_detection_outputs(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path))
    with _HandLandmarker.create_from_options(options) as landmarker:
      # Load the image with no hands.
      no_hands_test_image = _Image.create_from_file(
          test_utils.get_test_data_path(_NO_HANDS_IMAGE))
      # Performs hand landmarks detection on the input.
      detection_result = landmarker.detect(no_hands_test_image)
      self.assertEmpty(detection_result.hand_landmarks)
      self.assertEmpty(detection_result.hand_world_landmarks)
      self.assertEmpty(detection_result.handedness)

  def test_missing_result_callback(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM)
    with self.assertRaisesRegex(ValueError,
                                r'result callback must be provided'):
      with _HandLandmarker.create_from_options(options) as unused_landmarker:
        pass

  @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
  def test_illegal_result_callback(self, running_mode):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=running_mode,
        result_callback=mock.MagicMock())
    with self.assertRaisesRegex(ValueError,
                                r'result callback should not be provided'):
      with _HandLandmarker.create_from_options(options) as unused_landmarker:
        pass

  def test_calling_detect_for_video_in_image_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.IMAGE)
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the video mode'):
        landmarker.detect_for_video(self.test_image, 0)

  def test_calling_detect_async_in_image_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.IMAGE)
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the live stream mode'):
        landmarker.detect_async(self.test_image, 0)

  def test_calling_detect_in_video_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the image mode'):
        landmarker.detect(self.test_image)

  def test_calling_detect_async_in_video_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the live stream mode'):
        landmarker.detect_async(self.test_image, 0)

  def test_detect_for_video_with_out_of_order_timestamp(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _HandLandmarker.create_from_options(options) as landmarker:
      unused_result = landmarker.detect_for_video(self.test_image, 1)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'):
        landmarker.detect_for_video(self.test_image, 0)

  @parameterized.parameters(
      (_THUMB_UP_IMAGE, 0,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
      (_POINTING_UP_IMAGE, 0,
       _get_expected_hand_landmarker_result(_POINTING_UP_LANDMARKS)),
      (_POINTING_UP_ROTATED_IMAGE, -90,
       _get_expected_hand_landmarker_result(_POINTING_UP_ROTATED_LANDMARKS)),
      (_NO_HANDS_IMAGE, 0, _HandLandmarkerResult([], [], [])))
  def test_detect_for_video(self, image_path, rotation, expected_result):
    test_image = _Image.create_from_file(
        test_utils.get_test_data_path(image_path))
    # Set rotation parameters using ImageProcessingOptions.
    image_processing_options = _ImageProcessingOptions(
        rotation_degrees=rotation)
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _HandLandmarker.create_from_options(options) as landmarker:
      for timestamp in range(0, 300, 30):
        result = landmarker.detect_for_video(test_image, timestamp,
                                             image_processing_options)
        if (result.hand_landmarks and result.hand_world_landmarks and
            result.handedness):
          self._expect_hand_landmarker_results_correct(result, expected_result)
        else:
          self.assertEqual(result, expected_result)

  def test_calling_detect_in_live_stream_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock())
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the image mode'):
        landmarker.detect(self.test_image)

  def test_calling_detect_for_video_in_live_stream_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock())
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the video mode'):
        landmarker.detect_for_video(self.test_image, 0)

  def test_detect_async_calls_with_illegal_timestamp(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock())
    with _HandLandmarker.create_from_options(options) as landmarker:
      landmarker.detect_async(self.test_image, 100)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'):
        landmarker.detect_async(self.test_image, 0)

  @parameterized.parameters(
      (_THUMB_UP_IMAGE, 0,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
      (_POINTING_UP_IMAGE, 0,
       _get_expected_hand_landmarker_result(_POINTING_UP_LANDMARKS)),
      (_POINTING_UP_ROTATED_IMAGE, -90,
       _get_expected_hand_landmarker_result(_POINTING_UP_ROTATED_LANDMARKS)),
      (_NO_HANDS_IMAGE, 0, _HandLandmarkerResult([], [], [])))
  def test_detect_async_calls(self, image_path, rotation, expected_result):
    test_image = _Image.create_from_file(
        test_utils.get_test_data_path(image_path))
    # Set rotation parameters using ImageProcessingOptions.
    image_processing_options = _ImageProcessingOptions(
        rotation_degrees=rotation)
    # Track the last timestamp seen by the callback on the test instance so
    # the assertion below compares against the updated value. (Previously a
    # closure-local `observed_timestamp_ms` was read while
    # `self.observed_timestamp_ms` was written, so the check always compared
    # against -1.)
    self.observed_timestamp_ms = -1

    def check_result(result: _HandLandmarkerResult, output_image: _Image,
                     timestamp_ms: int):
      if (result.hand_landmarks and result.hand_world_landmarks and
          result.handedness):
        self._expect_hand_landmarker_results_correct(result, expected_result)
      else:
        self.assertEqual(result, expected_result)
      self.assertTrue(
          np.array_equal(output_image.numpy_view(), test_image.numpy_view()))
      # Timestamps must be strictly increasing across callbacks.
      self.assertLess(self.observed_timestamp_ms, timestamp_ms)
      self.observed_timestamp_ms = timestamp_ms

    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=check_result)
    with _HandLandmarker.create_from_options(options) as landmarker:
      for timestamp in range(0, 300, 30):
        landmarker.detect_async(test_image, timestamp, image_processing_options)
|
||||
|
||||
|
||||
# Script entry point: run all absl tests in this module.
if __name__ == '__main__':
  absltest.main()
|
||||
+544
@@ -0,0 +1,544 @@
|
||||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for holistic landmarker."""
|
||||
|
||||
import enum
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from google.protobuf import text_format
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.cc.vision.holistic_landmarker.proto import holistic_result_pb2
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import holistic_landmarker
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
|
||||
# Short aliases for the MediaPipe task/container types exercised in this test.
HolisticLandmarkerResult = holistic_landmarker.HolisticLandmarkerResult
_HolisticResultProto = holistic_result_pb2.HolisticResult
_BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image
_HolisticLandmarker = holistic_landmarker.HolisticLandmarker
_HolisticLandmarkerOptions = holistic_landmarker.HolisticLandmarkerOptions
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions

# Test asset file names and golden-image dimensions.
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE = 'holistic_landmarker.task'
_POSE_IMAGE = 'male_full_height_hands.jpg'
_CAT_IMAGE = 'cat.jpg'
_EXPECTED_HOLISTIC_RESULT = 'male_full_height_hands_result_cpu.pbtxt'
_IMAGE_WIDTH = 638
_IMAGE_HEIGHT = 1000

# Absolute tolerances used when comparing detection output to golden data,
# per running mode.
_LANDMARKS_MARGIN = 0.03
_BLENDSHAPES_MARGIN = 0.13
_VIDEO_LANDMARKS_MARGIN = 0.03
_VIDEO_BLENDSHAPES_MARGIN = 0.31
_LIVE_STREAM_LANDMARKS_MARGIN = 0.03
_LIVE_STREAM_BLENDSHAPES_MARGIN = 0.31
|
||||
|
||||
|
||||
def _get_expected_holistic_landmarker_result(
    file_path: str,
) -> HolisticLandmarkerResult:
  """Loads a golden HolisticResult pbtxt file and converts it to a task result.

  Args:
    file_path: Name of a text-format HolisticResult proto file under the test
      data directory.

  Returns:
    The HolisticLandmarkerResult built from the golden proto.
  """
  holistic_result_file_path = test_utils.get_test_data_path(file_path)
  with open(holistic_result_file_path, 'rb') as f:
    holistic_result_proto = _HolisticResultProto()
    # Use this if a .pb file is available.
    # holistic_result_proto.ParseFromString(f.read())
    text_format.Parse(f.read(), holistic_result_proto)
    holistic_landmarker_result = HolisticLandmarkerResult.create_from_pb2(
        holistic_result_proto
    )
  return holistic_landmarker_result
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
  """How the model asset is supplied to the task in a parameterized case."""
  FILE_CONTENT = 1
  FILE_NAME = 2
|
||||
|
||||
|
||||
class HolisticLandmarkerTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_POSE_IMAGE)
|
||||
)
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE
|
||||
)
|
||||
|
||||
def _expect_landmarks_correct(
|
||||
self, actual_landmarks, expected_landmarks, margin
|
||||
):
|
||||
# Expects to have the same number of landmarks detected.
|
||||
self.assertLen(actual_landmarks, len(expected_landmarks))
|
||||
|
||||
for i, elem in enumerate(actual_landmarks):
|
||||
self.assertAlmostEqual(elem.x, expected_landmarks[i].x, delta=margin)
|
||||
self.assertAlmostEqual(elem.y, expected_landmarks[i].y, delta=margin)
|
||||
|
||||
def _expect_blendshapes_correct(
|
||||
self, actual_blendshapes, expected_blendshapes, margin
|
||||
):
|
||||
# Expects to have the same number of blendshapes.
|
||||
self.assertLen(actual_blendshapes, len(expected_blendshapes))
|
||||
|
||||
for i, elem in enumerate(actual_blendshapes):
|
||||
self.assertEqual(elem.index, expected_blendshapes[i].index)
|
||||
self.assertEqual(
|
||||
elem.category_name, expected_blendshapes[i].category_name
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
elem.score,
|
||||
expected_blendshapes[i].score,
|
||||
delta=margin,
|
||||
)
|
||||
|
||||
def _expect_holistic_landmarker_results_correct(
|
||||
self,
|
||||
actual_result: HolisticLandmarkerResult,
|
||||
expected_result: HolisticLandmarkerResult,
|
||||
output_segmentation_mask: bool,
|
||||
landmarks_margin: float,
|
||||
blendshapes_margin: float,
|
||||
):
|
||||
self._expect_landmarks_correct(
|
||||
actual_result.pose_landmarks,
|
||||
expected_result.pose_landmarks,
|
||||
landmarks_margin,
|
||||
)
|
||||
self._expect_landmarks_correct(
|
||||
actual_result.face_landmarks,
|
||||
expected_result.face_landmarks,
|
||||
landmarks_margin,
|
||||
)
|
||||
self._expect_blendshapes_correct(
|
||||
actual_result.face_blendshapes,
|
||||
expected_result.face_blendshapes,
|
||||
blendshapes_margin,
|
||||
)
|
||||
if output_segmentation_mask:
|
||||
self.assertIsInstance(actual_result.segmentation_mask, _Image)
|
||||
self.assertEqual(actual_result.segmentation_mask.width, _IMAGE_WIDTH)
|
||||
self.assertEqual(actual_result.segmentation_mask.height, _IMAGE_HEIGHT)
|
||||
else:
|
||||
self.assertIsNone(actual_result.segmentation_mask)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
with _HolisticLandmarker.create_from_model_path(
|
||||
self.model_path
|
||||
) as landmarker:
|
||||
self.assertIsInstance(landmarker, _HolisticLandmarker)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||
# Creates with options containing model file successfully.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _HolisticLandmarkerOptions(base_options=base_options)
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
self.assertIsInstance(landmarker, _HolisticLandmarker)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
# Invalid empty model path.
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
|
||||
):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite'
|
||||
)
|
||||
options = _HolisticLandmarkerOptions(base_options=base_options)
|
||||
_HolisticLandmarker.create_from_options(options)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||
# Creates with options containing model content successfully.
|
||||
with open(self.model_path, 'rb') as f:
|
||||
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||
options = _HolisticLandmarkerOptions(base_options=base_options)
|
||||
landmarker = _HolisticLandmarker.create_from_options(options)
|
||||
self.assertIsInstance(landmarker, _HolisticLandmarker)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
False,
|
||||
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
False,
|
||||
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
True,
|
||||
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
True,
|
||||
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
|
||||
),
|
||||
)
|
||||
def test_detect(
|
||||
self,
|
||||
model_file_type,
|
||||
model_name,
|
||||
output_segmentation_mask,
|
||||
expected_holistic_landmarker_result,
|
||||
):
|
||||
# Creates holistic landmarker.
|
||||
model_path = test_utils.get_test_data_path(model_name)
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
output_face_blendshapes=True
|
||||
if expected_holistic_landmarker_result.face_blendshapes
|
||||
else False,
|
||||
output_segmentation_mask=output_segmentation_mask,
|
||||
)
|
||||
landmarker = _HolisticLandmarker.create_from_options(options)
|
||||
|
||||
# Performs holistic landmarks detection on the input.
|
||||
detection_result = landmarker.detect(self.test_image)
|
||||
self._expect_holistic_landmarker_results_correct(
|
||||
detection_result,
|
||||
expected_holistic_landmarker_result,
|
||||
output_segmentation_mask,
|
||||
_LANDMARKS_MARGIN,
|
||||
_BLENDSHAPES_MARGIN,
|
||||
)
|
||||
# Closes the holistic landmarker explicitly when the holistic landmarker is
|
||||
# not used in a context.
|
||||
landmarker.close()
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
False,
|
||||
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
True,
|
||||
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
|
||||
),
|
||||
)
|
||||
def test_detect_in_context(
|
||||
self,
|
||||
model_file_type,
|
||||
model_name,
|
||||
output_segmentation_mask,
|
||||
expected_holistic_landmarker_result,
|
||||
):
|
||||
# Creates holistic landmarker.
|
||||
model_path = test_utils.get_test_data_path(model_name)
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
output_face_blendshapes=True
|
||||
if expected_holistic_landmarker_result.face_blendshapes
|
||||
else False,
|
||||
output_segmentation_mask=output_segmentation_mask,
|
||||
)
|
||||
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
# Performs holistic landmarks detection on the input.
|
||||
detection_result = landmarker.detect(self.test_image)
|
||||
self._expect_holistic_landmarker_results_correct(
|
||||
detection_result,
|
||||
expected_holistic_landmarker_result,
|
||||
output_segmentation_mask,
|
||||
_LANDMARKS_MARGIN,
|
||||
_BLENDSHAPES_MARGIN,
|
||||
)
|
||||
|
||||
def test_empty_detection_outputs(self):
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path)
|
||||
)
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
# Load the cat image.
|
||||
cat_test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_CAT_IMAGE)
|
||||
)
|
||||
# Performs holistic landmarks detection on the input.
|
||||
detection_result = landmarker.detect(cat_test_image)
|
||||
self.assertEmpty(detection_result.face_landmarks)
|
||||
self.assertEmpty(detection_result.pose_landmarks)
|
||||
self.assertEmpty(detection_result.pose_world_landmarks)
|
||||
self.assertEmpty(detection_result.left_hand_landmarks)
|
||||
self.assertEmpty(detection_result.left_hand_world_landmarks)
|
||||
self.assertEmpty(detection_result.right_hand_landmarks)
|
||||
self.assertEmpty(detection_result.right_hand_world_landmarks)
|
||||
self.assertIsNone(detection_result.face_blendshapes)
|
||||
self.assertIsNone(detection_result.segmentation_mask)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback must be provided'
|
||||
):
|
||||
with _HolisticLandmarker.create_from_options(
|
||||
options
|
||||
) as unused_landmarker:
|
||||
pass
|
||||
|
||||
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||
def test_illegal_result_callback(self, running_mode):
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=running_mode,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback should not be provided'
|
||||
):
|
||||
with _HolisticLandmarker.create_from_options(
|
||||
options
|
||||
) as unused_landmarker:
|
||||
pass
|
||||
|
||||
def test_calling_detect_for_video_in_image_mode(self):
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
landmarker.detect_for_video(self.test_image, 0)
|
||||
|
||||
def test_calling_detect_async_in_image_mode(self):
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
landmarker.detect_async(self.test_image, 0)
|
||||
|
||||
def test_calling_detect_in_video_mode(self):
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
landmarker.detect(self.test_image)
|
||||
|
||||
def test_calling_detect_async_in_video_mode(self):
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
landmarker.detect_async(self.test_image, 0)
|
||||
|
||||
def test_detect_for_video_with_out_of_order_timestamp(self):
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
unused_result = landmarker.detect_for_video(self.test_image, 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
landmarker.detect_for_video(self.test_image, 0)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
False,
|
||||
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
|
||||
),
|
||||
(
|
||||
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
True,
|
||||
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
|
||||
),
|
||||
)
|
||||
def test_detect_for_video(
|
||||
self,
|
||||
model_name,
|
||||
output_segmentation_mask,
|
||||
expected_holistic_landmarker_result,
|
||||
):
|
||||
# Creates holistic landmarker.
|
||||
model_path = test_utils.get_test_data_path(model_name)
|
||||
base_options = _BaseOptions(model_asset_path=model_path)
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
output_face_blendshapes=True
|
||||
if expected_holistic_landmarker_result.face_blendshapes
|
||||
else False,
|
||||
output_segmentation_mask=output_segmentation_mask,
|
||||
)
|
||||
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
for timestamp in range(0, 300, 30):
|
||||
# Performs holistic landmarks detection on the input.
|
||||
detection_result = landmarker.detect_for_video(
|
||||
self.test_image, timestamp
|
||||
)
|
||||
# Comparing results.
|
||||
self._expect_holistic_landmarker_results_correct(
|
||||
detection_result,
|
||||
expected_holistic_landmarker_result,
|
||||
output_segmentation_mask,
|
||||
_VIDEO_LANDMARKS_MARGIN,
|
||||
_VIDEO_BLENDSHAPES_MARGIN,
|
||||
)
|
||||
|
||||
def test_calling_detect_in_live_stream_mode(self):
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
landmarker.detect(self.test_image)
|
||||
|
||||
def test_calling_detect_for_video_in_live_stream_mode(self):
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
landmarker.detect_for_video(self.test_image, 0)
|
||||
|
||||
def test_detect_async_calls_with_illegal_timestamp(self):
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
landmarker.detect_async(self.test_image, 100)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
landmarker.detect_async(self.test_image, 0)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
_POSE_IMAGE,
|
||||
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
False,
|
||||
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
|
||||
),
|
||||
(
|
||||
_POSE_IMAGE,
|
||||
_HOLISTIC_LANDMARKER_BUNDLE_ASSET_FILE,
|
||||
True,
|
||||
_get_expected_holistic_landmarker_result(_EXPECTED_HOLISTIC_RESULT),
|
||||
),
|
||||
)
|
||||
def test_detect_async_calls(
|
||||
self,
|
||||
image_path,
|
||||
model_name,
|
||||
output_segmentation_mask,
|
||||
expected_holistic_landmarker_result,
|
||||
):
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(image_path)
|
||||
)
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(
|
||||
result: HolisticLandmarkerResult,
|
||||
output_image: _Image,
|
||||
timestamp_ms: int,
|
||||
):
|
||||
# Comparing results.
|
||||
self._expect_holistic_landmarker_results_correct(
|
||||
result,
|
||||
expected_holistic_landmarker_result,
|
||||
output_segmentation_mask,
|
||||
_LIVE_STREAM_LANDMARKS_MARGIN,
|
||||
_LIVE_STREAM_BLENDSHAPES_MARGIN,
|
||||
)
|
||||
self.assertTrue(
|
||||
np.array_equal(output_image.numpy_view(), test_image.numpy_view())
|
||||
)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
model_path = test_utils.get_test_data_path(model_name)
|
||||
options = _HolisticLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
output_face_blendshapes=True
|
||||
if expected_holistic_landmarker_result.face_blendshapes
|
||||
else False,
|
||||
output_segmentation_mask=output_segmentation_mask,
|
||||
result_callback=check_result,
|
||||
)
|
||||
with _HolisticLandmarker.create_from_options(options) as landmarker:
|
||||
for timestamp in range(0, 300, 30):
|
||||
landmarker.detect_async(test_image, timestamp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
+657
@@ -0,0 +1,657 @@
|
||||
# Copyright 2022 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for image classifier."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.python._framework_bindings import image
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import classification_result as classification_result_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import image_classifier
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||
|
||||
ImageClassifierResult = classification_result_module.ClassificationResult
|
||||
_Rect = rect.Rect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Category = category_module.Category
|
||||
_Classifications = classification_result_module.Classifications
|
||||
_Image = image.Image
|
||||
_ImageClassifier = image_classifier.ImageClassifier
|
||||
_ImageClassifierOptions = image_classifier.ImageClassifierOptions
|
||||
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
|
||||
_IMAGE_FILE = 'burger.jpg'
|
||||
_ALLOW_LIST = ['cheeseburger', 'guacamole']
|
||||
_DENY_LIST = ['cheeseburger']
|
||||
_SCORE_THRESHOLD = 0.5
|
||||
_MAX_RESULTS = 3
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
||||
|
||||
def _generate_empty_results() -> ImageClassifierResult:
|
||||
return ImageClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(categories=[], head_index=0, head_name='probability')
|
||||
],
|
||||
timestamp_ms=0,
|
||||
)
|
||||
|
||||
|
||||
def _generate_burger_results(timestamp_ms=0) -> ImageClassifierResult:
|
||||
return ImageClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
categories=[
|
||||
_Category(
|
||||
index=934,
|
||||
score=0.793959,
|
||||
display_name='',
|
||||
category_name='cheeseburger',
|
||||
),
|
||||
_Category(
|
||||
index=932,
|
||||
score=0.0273929,
|
||||
display_name='',
|
||||
category_name='bagel',
|
||||
),
|
||||
_Category(
|
||||
index=925,
|
||||
score=0.0193408,
|
||||
display_name='',
|
||||
category_name='guacamole',
|
||||
),
|
||||
_Category(
|
||||
index=963,
|
||||
score=0.00632786,
|
||||
display_name='',
|
||||
category_name='meat loaf',
|
||||
),
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability',
|
||||
)
|
||||
],
|
||||
timestamp_ms=timestamp_ms,
|
||||
)
|
||||
|
||||
|
||||
def _generate_soccer_ball_results(timestamp_ms=0) -> ImageClassifierResult:
|
||||
return ImageClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
categories=[
|
||||
_Category(
|
||||
index=806,
|
||||
score=0.996527,
|
||||
display_name='',
|
||||
category_name='soccer ball',
|
||||
)
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability',
|
||||
)
|
||||
],
|
||||
timestamp_ms=timestamp_ms,
|
||||
)
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
FILE_CONTENT = 1
|
||||
FILE_NAME = 2
|
||||
|
||||
|
||||
class ImageClassifierTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
|
||||
)
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
|
||||
)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
with _ImageClassifier.create_from_model_path(self.model_path) as classifier:
|
||||
self.assertIsInstance(classifier, _ImageClassifier)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||
# Creates with options containing model file successfully.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _ImageClassifierOptions(base_options=base_options)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
self.assertIsInstance(classifier, _ImageClassifier)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
|
||||
):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite'
|
||||
)
|
||||
options = _ImageClassifierOptions(base_options=base_options)
|
||||
_ImageClassifier.create_from_options(options)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||
# Creates with options containing model content successfully.
|
||||
with open(self.model_path, 'rb') as f:
|
||||
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||
options = _ImageClassifierOptions(base_options=base_options)
|
||||
classifier = _ImageClassifier.create_from_options(options)
|
||||
self.assertIsInstance(classifier, _ImageClassifier)
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, 4, _generate_burger_results()),
|
||||
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()),
|
||||
)
|
||||
def test_classify(
|
||||
self, model_file_type, max_results, expected_classification_result
|
||||
):
|
||||
# Creates classifier.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=base_options, max_results=max_results
|
||||
)
|
||||
classifier = _ImageClassifier.create_from_options(options)
|
||||
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
# Comparing results.
|
||||
test_utils.assert_proto_equals(
|
||||
self, image_result.to_pb2(), expected_classification_result.to_pb2()
|
||||
)
|
||||
# Closes the classifier explicitly when the classifier is not used in
|
||||
# a context.
|
||||
classifier.close()
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME, 4, _generate_burger_results()),
|
||||
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results()),
|
||||
)
|
||||
def test_classify_in_context(
|
||||
self, model_file_type, max_results, expected_classification_result
|
||||
):
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=base_options, max_results=max_results
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
# Comparing results.
|
||||
test_utils.assert_proto_equals(
|
||||
self, image_result.to_pb2(), expected_classification_result.to_pb2()
|
||||
)
|
||||
|
||||
def test_classify_succeeds_with_region_of_interest(self):
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _ImageClassifierOptions(base_options=base_options, max_results=1)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
|
||||
)
|
||||
)
|
||||
# Region-of-interest around the soccer ball.
|
||||
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(test_image, image_processing_options)
|
||||
# Comparing results.
|
||||
test_utils.assert_proto_equals(
|
||||
self, image_result.to_pb2(), _generate_soccer_ball_results().to_pb2()
|
||||
)
|
||||
|
||||
def test_classify_succeeds_with_rotation(self):
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _ImageClassifierOptions(base_options=base_options, max_results=3)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, 'burger_rotated.jpg')
|
||||
)
|
||||
)
|
||||
# Specify a 90° anti-clockwise rotation.
|
||||
image_processing_options = _ImageProcessingOptions(None, -90)
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(test_image, image_processing_options)
|
||||
# Comparing results.
|
||||
expected = ImageClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
categories=[
|
||||
_Category(
|
||||
index=934,
|
||||
score=0.754467,
|
||||
display_name='',
|
||||
category_name='cheeseburger',
|
||||
),
|
||||
_Category(
|
||||
index=925,
|
||||
score=0.0288028,
|
||||
display_name='',
|
||||
category_name='guacamole',
|
||||
),
|
||||
_Category(
|
||||
index=932,
|
||||
score=0.0286119,
|
||||
display_name='',
|
||||
category_name='bagel',
|
||||
),
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability',
|
||||
)
|
||||
],
|
||||
timestamp_ms=0,
|
||||
)
|
||||
test_utils.assert_proto_equals(
|
||||
self, image_result.to_pb2(), expected.to_pb2()
|
||||
)
|
||||
|
||||
def test_classify_succeeds_with_region_of_interest_and_rotation(self):
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _ImageClassifierOptions(base_options=base_options, max_results=1)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, 'multi_objects_rotated.jpg')
|
||||
)
|
||||
)
|
||||
# Region-of-interest around the soccer ball, with 90° anti-clockwise
|
||||
# rotation.
|
||||
roi = _Rect(left=0.2655, top=0.45, right=0.6925, bottom=0.614)
|
||||
image_processing_options = _ImageProcessingOptions(roi, -90)
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(test_image, image_processing_options)
|
||||
# Comparing results.
|
||||
expected = ImageClassifierResult(
|
||||
classifications=[
|
||||
_Classifications(
|
||||
categories=[
|
||||
_Category(
|
||||
index=806,
|
||||
score=0.997684,
|
||||
display_name='',
|
||||
category_name='soccer ball',
|
||||
),
|
||||
],
|
||||
head_index=0,
|
||||
head_name='probability',
|
||||
)
|
||||
],
|
||||
timestamp_ms=0,
|
||||
)
|
||||
test_utils.assert_proto_equals(
|
||||
self, image_result.to_pb2(), expected.to_pb2()
|
||||
)
|
||||
|
||||
def test_score_threshold_option(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
score_threshold=_SCORE_THRESHOLD,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
classifications = image_result.classifications
|
||||
|
||||
for classification in classifications:
|
||||
for category in classification.categories:
|
||||
score = category.score
|
||||
self.assertGreaterEqual(
|
||||
score,
|
||||
_SCORE_THRESHOLD,
|
||||
(
|
||||
'Classification with score lower than threshold found. '
|
||||
f'{classification}'
|
||||
),
|
||||
)
|
||||
|
||||
def test_max_results_option(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
score_threshold=_SCORE_THRESHOLD,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
categories = image_result.classifications[0].categories
|
||||
|
||||
self.assertLessEqual(
|
||||
len(categories), _MAX_RESULTS, 'Too many results returned.'
|
||||
)
|
||||
|
||||
def test_allow_list_option(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
category_allowlist=_ALLOW_LIST,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
classifications = image_result.classifications
|
||||
|
||||
for classification in classifications:
|
||||
for category in classification.categories:
|
||||
label = category.category_name
|
||||
self.assertIn(
|
||||
label,
|
||||
_ALLOW_LIST,
|
||||
f'Label {label} found but not in label allow list',
|
||||
)
|
||||
|
||||
def test_deny_list_option(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
category_denylist=_DENY_LIST,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
classifications = image_result.classifications
|
||||
|
||||
for classification in classifications:
|
||||
for category in classification.categories:
|
||||
label = category.category_name
|
||||
self.assertNotIn(
|
||||
label, _DENY_LIST, f'Label {label} found but in deny list.'
|
||||
)
|
||||
|
||||
def test_combined_allowlist_and_denylist(self):
|
||||
# Fails with combined allowlist and denylist
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r'`category_allowlist` and `category_denylist` are mutually '
|
||||
r'exclusive options.',
|
||||
):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
category_allowlist=['foo'],
|
||||
category_denylist=['bar'],
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as unused_classifier:
|
||||
pass
|
||||
|
||||
def test_empty_classification_outputs(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
score_threshold=1,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Performs image classification on the input.
|
||||
image_result = classifier.classify(self.test_image)
|
||||
self.assertEmpty(image_result.classifications[0].categories)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback must be provided'
|
||||
):
|
||||
with _ImageClassifier.create_from_options(options) as unused_classifier:
|
||||
pass
|
||||
|
||||
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||
def test_illegal_result_callback(self, running_mode):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=running_mode,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback should not be provided'
|
||||
):
|
||||
with _ImageClassifier.create_from_options(options) as unused_classifier:
|
||||
pass
|
||||
|
||||
def test_calling_classify_for_video_in_image_mode(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
classifier.classify_for_video(self.test_image, 0)
|
||||
|
||||
def test_calling_classify_async_in_image_mode(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
classifier.classify_async(self.test_image, 0)
|
||||
|
||||
def test_calling_classify_in_video_mode(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
classifier.classify(self.test_image)
|
||||
|
||||
def test_calling_classify_async_in_video_mode(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
classifier.classify_async(self.test_image, 0)
|
||||
|
||||
def test_classify_for_video_with_out_of_order_timestamp(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
unused_result = classifier.classify_for_video(self.test_image, 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
classifier.classify_for_video(self.test_image, 0)
|
||||
|
||||
def test_classify_for_video(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
max_results=4,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
self.test_image, timestamp
|
||||
)
|
||||
test_utils.assert_proto_equals(
|
||||
self,
|
||||
classification_result.to_pb2(),
|
||||
_generate_burger_results(timestamp).to_pb2(),
|
||||
)
|
||||
|
||||
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
max_results=1,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
|
||||
)
|
||||
)
|
||||
# Region-of-interest around the soccer ball.
|
||||
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
for timestamp in range(0, 300, 30):
|
||||
classification_result = classifier.classify_for_video(
|
||||
test_image, timestamp, image_processing_options
|
||||
)
|
||||
test_utils.assert_proto_equals(
|
||||
self,
|
||||
classification_result.to_pb2(),
|
||||
_generate_soccer_ball_results(timestamp).to_pb2(),
|
||||
)
|
||||
|
||||
def test_calling_classify_in_live_stream_mode(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
classifier.classify(self.test_image)
|
||||
|
||||
def test_calling_classify_for_video_in_live_stream_mode(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
classifier.classify_for_video(self.test_image, 0)
|
||||
|
||||
def test_classify_async_calls_with_illegal_timestamp(self):
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
max_results=4,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
classifier.classify_async(self.test_image, 100)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
classifier.classify_async(self.test_image, 0)
|
||||
|
||||
@parameterized.parameters(
|
||||
(0, _generate_burger_results()), (1, _generate_empty_results())
|
||||
)
|
||||
def test_classify_async_calls(self, threshold, expected_result):
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(
|
||||
result: ImageClassifierResult, output_image: _Image, timestamp_ms: int
|
||||
):
|
||||
test_utils.assert_proto_equals(
|
||||
self, result.to_pb2(), expected_result.to_pb2()
|
||||
)
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
output_image.numpy_view(), self.test_image.numpy_view()
|
||||
)
|
||||
)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
max_results=4,
|
||||
score_threshold=threshold,
|
||||
result_callback=check_result,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
classifier.classify_async(self.test_image, 0)
|
||||
|
||||
def test_classify_async_succeeds_with_region_of_interest(self):
|
||||
# Load the test image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, 'multi_objects.jpg')
|
||||
)
|
||||
)
|
||||
# Region-of-interest around the soccer ball.
|
||||
roi = _Rect(left=0.45, top=0.3075, right=0.614, bottom=0.7345)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(
|
||||
result: ImageClassifierResult, output_image: _Image, timestamp_ms: int
|
||||
):
|
||||
test_utils.assert_proto_equals(
|
||||
self, result.to_pb2(), _generate_soccer_ball_results(100).to_pb2()
|
||||
)
|
||||
self.assertEqual(output_image.width, test_image.width)
|
||||
self.assertEqual(output_image.height, test_image.height)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _ImageClassifierOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
max_results=1,
|
||||
result_callback=check_result,
|
||||
)
|
||||
with _ImageClassifier.create_from_options(options) as classifier:
|
||||
classifier.classify_async(test_image, 100, image_processing_options)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
+423
@@ -0,0 +1,423 @@
|
||||
# Copyright 2022 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for image embedder."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.components.containers import embedding_result as embedding_result_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import image_embedder
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
_Rect = rect.Rect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Embedding = embedding_result_module.Embedding
|
||||
_Image = image_module.Image
|
||||
_ImageEmbedder = image_embedder.ImageEmbedder
|
||||
_ImageEmbedderOptions = image_embedder.ImageEmbedderOptions
|
||||
_ImageEmbedderResult = image_embedder.ImageEmbedderResult
|
||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_MODEL_FILE = 'mobilenet_v3_small_100_224_embedder.tflite'
|
||||
_BURGER_IMAGE_FILE = 'burger.jpg'
|
||||
_BURGER_CROPPED_IMAGE_FILE = 'burger_crop.jpg'
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
# Tolerance for embedding vector coordinate values.
|
||||
_EPSILON = 1e-4
|
||||
# Tolerance for cosine similarity evaluation.
|
||||
_SIMILARITY_TOLERANCE = 1e-6
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
FILE_CONTENT = 1
|
||||
FILE_NAME = 2
|
||||
|
||||
|
||||
class ImageEmbedderTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _BURGER_IMAGE_FILE)))
|
||||
self.test_cropped_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _BURGER_CROPPED_IMAGE_FILE)))
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
with _ImageEmbedder.create_from_model_path(self.model_path) as embedder:
|
||||
self.assertIsInstance(embedder, _ImageEmbedder)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||
# Creates with options containing model file successfully.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _ImageEmbedderOptions(base_options=base_options)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
self.assertIsInstance(embedder, _ImageEmbedder)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite')
|
||||
options = _ImageEmbedderOptions(base_options=base_options)
|
||||
_ImageEmbedder.create_from_options(options)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||
# Creates with options containing model content successfully.
|
||||
with open(self.model_path, 'rb') as f:
|
||||
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||
options = _ImageEmbedderOptions(base_options=base_options)
|
||||
embedder = _ImageEmbedder.create_from_options(options)
|
||||
self.assertIsInstance(embedder, _ImageEmbedder)
|
||||
|
||||
def _check_embedding_value(self, result, expected_first_value):
|
||||
# Check embedding first value.
|
||||
self.assertAlmostEqual(
|
||||
result.embeddings[0].embedding[0], expected_first_value, delta=_EPSILON)
|
||||
|
||||
def _check_embedding_size(self, result, quantize, expected_embedding_size):
|
||||
# Check embedding size.
|
||||
self.assertLen(result.embeddings, 1)
|
||||
embedding_result = result.embeddings[0]
|
||||
self.assertLen(embedding_result.embedding, expected_embedding_size)
|
||||
if quantize:
|
||||
self.assertEqual(embedding_result.embedding.dtype, np.uint8)
|
||||
else:
|
||||
self.assertEqual(embedding_result.embedding.dtype, float)
|
||||
|
||||
def _check_cosine_similarity(self, result0, result1, expected_similarity):
|
||||
# Checks cosine similarity.
|
||||
similarity = _ImageEmbedder.cosine_similarity(result0.embeddings[0],
|
||||
result1.embeddings[0])
|
||||
self.assertAlmostEqual(
|
||||
similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
ModelFileType.FILE_NAME,
|
||||
0.925519,
|
||||
1024,
|
||||
(-0.2101883, -0.193027),
|
||||
),
|
||||
(
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
ModelFileType.FILE_NAME,
|
||||
0.925519,
|
||||
1024,
|
||||
(-0.0142344, -0.0131606),
|
||||
),
|
||||
(
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
ModelFileType.FILE_NAME,
|
||||
0.926791,
|
||||
1024,
|
||||
(229, 231),
|
||||
),
|
||||
(
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
ModelFileType.FILE_CONTENT,
|
||||
0.999931,
|
||||
1024,
|
||||
(-0.195062, -0.193027),
|
||||
),
|
||||
)
|
||||
def test_embed(self, l2_normalize, quantize, with_roi, model_file_type,
|
||||
expected_similarity, expected_size, expected_first_values):
|
||||
# Creates embedder.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
|
||||
embedder = _ImageEmbedder.create_from_options(options)
|
||||
|
||||
image_processing_options = None
|
||||
if with_roi:
|
||||
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
|
||||
roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
|
||||
# Extracts both embeddings.
|
||||
image_result = embedder.embed(self.test_image, image_processing_options)
|
||||
crop_result = embedder.embed(self.test_cropped_image)
|
||||
|
||||
# Checks embeddings and cosine similarity.
|
||||
expected_result0_value, expected_result1_value = expected_first_values
|
||||
self._check_embedding_size(image_result, quantize, expected_size)
|
||||
self._check_embedding_size(crop_result, quantize, expected_size)
|
||||
self._check_embedding_value(image_result, expected_result0_value)
|
||||
self._check_embedding_value(crop_result, expected_result1_value)
|
||||
self._check_cosine_similarity(image_result, crop_result,
|
||||
expected_similarity)
|
||||
# Closes the embedder explicitly when the embedder is not used in
|
||||
# a context.
|
||||
embedder.close()
|
||||
|
||||
@parameterized.parameters(
|
||||
(False, False, ModelFileType.FILE_NAME, 0.925519),
|
||||
(False, False, ModelFileType.FILE_CONTENT, 0.925519))
|
||||
def test_embed_in_context(self, l2_normalize, quantize, model_file_type,
|
||||
expected_similarity):
|
||||
# Creates embedder.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=base_options, l2_normalize=l2_normalize, quantize=quantize)
|
||||
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
# Extracts both embeddings.
|
||||
image_result = embedder.embed(self.test_image)
|
||||
crop_result = embedder.embed(self.test_cropped_image)
|
||||
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(image_result, crop_result,
|
||||
expected_similarity)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'result callback must be provided'):
|
||||
with _ImageEmbedder.create_from_options(options) as unused_embedder:
|
||||
pass
|
||||
|
||||
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||
def test_illegal_result_callback(self, running_mode):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=running_mode,
|
||||
result_callback=mock.MagicMock())
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'result callback should not be provided'):
|
||||
with _ImageEmbedder.create_from_options(options) as unused_embedder:
|
||||
pass
|
||||
|
||||
def test_calling_embed_for_video_in_image_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the video mode'):
|
||||
embedder.embed_for_video(self.test_image, 0)
|
||||
|
||||
def test_calling_embed_async_in_image_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the live stream mode'):
|
||||
embedder.embed_async(self.test_image, 0)
|
||||
|
||||
def test_calling_embed_in_video_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the image mode'):
|
||||
embedder.embed(self.test_image)
|
||||
|
||||
def test_calling_embed_async_in_video_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the live stream mode'):
|
||||
embedder.embed_async(self.test_image, 0)
|
||||
|
||||
def test_embed_for_video_with_out_of_order_timestamp(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
unused_result = embedder.embed_for_video(self.test_image, 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||
embedder.embed_for_video(self.test_image, 0)
|
||||
|
||||
def test_embed_for_video(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder0, \
|
||||
_ImageEmbedder.create_from_options(options) as embedder1:
|
||||
for timestamp in range(0, 300, 30):
|
||||
# Extracts both embeddings.
|
||||
image_result = embedder0.embed_for_video(self.test_image, timestamp)
|
||||
crop_result = embedder1.embed_for_video(self.test_cropped_image,
|
||||
timestamp)
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(
|
||||
image_result, crop_result, expected_similarity=0.925519)
|
||||
|
||||
def test_embed_for_video_succeeds_with_region_of_interest(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder0, \
|
||||
_ImageEmbedder.create_from_options(options) as embedder1:
|
||||
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
|
||||
roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
|
||||
for timestamp in range(0, 300, 30):
|
||||
# Extracts both embeddings.
|
||||
image_result = embedder0.embed_for_video(self.test_image, timestamp,
|
||||
image_processing_options)
|
||||
crop_result = embedder1.embed_for_video(self.test_cropped_image,
|
||||
timestamp)
|
||||
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(
|
||||
image_result, crop_result, expected_similarity=0.999931)
|
||||
|
||||
def test_calling_embed_in_live_stream_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock())
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the image mode'):
|
||||
embedder.embed(self.test_image)
|
||||
|
||||
def test_calling_embed_for_video_in_live_stream_mode(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock())
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the video mode'):
|
||||
embedder.embed_for_video(self.test_image, 0)
|
||||
|
||||
def test_embed_async_calls_with_illegal_timestamp(self):
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock())
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
embedder.embed_async(self.test_image, 100)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||
embedder.embed_async(self.test_image, 0)
|
||||
|
||||
def test_embed_async_calls(self):
|
||||
# Get the embedding result for the cropped image.
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
crop_result = embedder.embed(self.test_cropped_image)
|
||||
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(result: _ImageEmbedderResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(
|
||||
result, crop_result, expected_similarity=0.925519)
|
||||
self.assertTrue(
|
||||
np.array_equal(output_image.numpy_view(),
|
||||
self.test_image.numpy_view()))
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=check_result)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
for timestamp in range(0, 300, 30):
|
||||
embedder.embed_async(self.test_image, timestamp)
|
||||
|
||||
def test_embed_async_succeeds_with_region_of_interest(self):
|
||||
# Get the embedding result for the cropped image.
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
crop_result = embedder.embed(self.test_cropped_image)
|
||||
|
||||
# Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
|
||||
roi = _Rect(left=0, top=0, right=0.833333, bottom=1)
|
||||
image_processing_options = _ImageProcessingOptions(roi)
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(result: _ImageEmbedderResult, output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
# Checks cosine similarity.
|
||||
self._check_cosine_similarity(
|
||||
result, crop_result, expected_similarity=0.999931)
|
||||
self.assertTrue(
|
||||
np.array_equal(output_image.numpy_view(),
|
||||
self.test_image.numpy_view()))
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _ImageEmbedderOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=check_result)
|
||||
with _ImageEmbedder.create_from_options(options) as embedder:
|
||||
for timestamp in range(0, 300, 30):
|
||||
embedder.embed_async(self.test_image, timestamp,
|
||||
image_processing_options)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
+512
@@ -0,0 +1,512 @@
|
||||
# Copyright 2022 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for image segmenter."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import image_frame
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import image_segmenter
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||
|
||||
ImageSegmenterResult = image_segmenter.ImageSegmenterResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Image = image_module.Image
|
||||
_ImageFormat = image_frame.ImageFormat
|
||||
_ImageSegmenter = image_segmenter.ImageSegmenter
|
||||
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
|
||||
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
|
||||
|
||||
_MODEL_FILE = 'deeplabv3.tflite'
|
||||
_IMAGE_FILE = 'segmentation_input_rotation0.jpg'
|
||||
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
|
||||
_CAT_IMAGE = 'cat.jpg'
|
||||
_CAT_MASK = 'cat_mask.jpg'
|
||||
_MASK_MAGNIFICATION_FACTOR = 10
|
||||
_MASK_SIMILARITY_THRESHOLD = 0.98
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
_EXPECTED_LABELS = [
|
||||
'background',
|
||||
'aeroplane',
|
||||
'bicycle',
|
||||
'bird',
|
||||
'boat',
|
||||
'bottle',
|
||||
'bus',
|
||||
'car',
|
||||
'cat',
|
||||
'chair',
|
||||
'cow',
|
||||
'dining table',
|
||||
'dog',
|
||||
'horse',
|
||||
'motorbike',
|
||||
'person',
|
||||
'potted plant',
|
||||
'sheep',
|
||||
'sofa',
|
||||
'train',
|
||||
'tv',
|
||||
]
|
||||
|
||||
|
||||
def _calculate_soft_iou(m1, m2):
|
||||
intersection_sum = np.sum(m1 * m2)
|
||||
union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum
|
||||
|
||||
if union_sum > 0:
|
||||
return intersection_sum / union_sum
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold):
|
||||
actual_mask = actual_mask.numpy_view()
|
||||
expected_mask = expected_mask.numpy_view() / 255.0
|
||||
|
||||
return (
|
||||
actual_mask.shape == expected_mask.shape
|
||||
and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold
|
||||
)
|
||||
|
||||
|
||||
def _similar_to_uint8_mask(actual_mask, expected_mask):
|
||||
actual_mask_pixels = actual_mask.numpy_view().flatten()
|
||||
expected_mask_pixels = expected_mask.numpy_view().flatten()
|
||||
|
||||
consistent_pixels = 0
|
||||
num_pixels = len(expected_mask_pixels)
|
||||
|
||||
for index in range(num_pixels):
|
||||
consistent_pixels += (
|
||||
actual_mask_pixels[index] * _MASK_MAGNIFICATION_FACTOR
|
||||
== expected_mask_pixels[index]
|
||||
)
|
||||
|
||||
return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
FILE_CONTENT = 1
|
||||
FILE_NAME = 2
|
||||
|
||||
|
||||
class ImageSegmenterTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Load the test input image.
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
|
||||
)
|
||||
# Loads ground truth segmentation file.
|
||||
gt_segmentation_data = cv2.imread(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)
|
||||
),
|
||||
cv2.IMREAD_GRAYSCALE,
|
||||
)
|
||||
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
|
||||
)
|
||||
|
||||
def _load_segmentation_mask(self, file_path: str):
|
||||
# Loads ground truth segmentation file.
|
||||
gt_segmentation_data = cv2.imread(
|
||||
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)),
|
||||
cv2.IMREAD_GRAYSCALE,
|
||||
)
|
||||
return _Image(_ImageFormat.GRAY8, gt_segmentation_data)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
with _ImageSegmenter.create_from_model_path(self.model_path) as segmenter:
|
||||
self.assertIsInstance(segmenter, _ImageSegmenter)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||
# Creates with options containing model file successfully.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _ImageSegmenterOptions(base_options=base_options)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
self.assertIsInstance(segmenter, _ImageSegmenter)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
|
||||
):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite'
|
||||
)
|
||||
options = _ImageSegmenterOptions(base_options=base_options)
|
||||
_ImageSegmenter.create_from_options(options)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||
# Creates with options containing model content successfully.
|
||||
with open(self.model_path, 'rb') as f:
|
||||
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||
options = _ImageSegmenterOptions(base_options=base_options)
|
||||
segmenter = _ImageSegmenter.create_from_options(options)
|
||||
self.assertIsInstance(segmenter, _ImageSegmenter)
|
||||
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME,), (ModelFileType.FILE_CONTENT,)
|
||||
)
|
||||
def test_segment_succeeds_with_category_mask(self, model_file_type):
|
||||
# Creates segmenter.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=base_options,
|
||||
output_category_mask=True,
|
||||
output_confidence_masks=False,
|
||||
)
|
||||
segmenter = _ImageSegmenter.create_from_options(options)
|
||||
|
||||
# Performs image segmentation on the input.
|
||||
segmentation_result = segmenter.segment(self.test_image)
|
||||
category_mask = segmentation_result.category_mask
|
||||
result_pixels = category_mask.numpy_view().flatten()
|
||||
|
||||
# Check if data type of `category_mask` is correct.
|
||||
self.assertEqual(result_pixels.dtype, np.uint8)
|
||||
|
||||
self.assertTrue(
|
||||
_similar_to_uint8_mask(category_mask, self.test_seg_image),
|
||||
(
|
||||
'Number of pixels in the candidate mask differing from that of the'
|
||||
f' ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
|
||||
),
|
||||
)
|
||||
|
||||
# Closes the segmenter explicitly when the segmenter is not used in
|
||||
# a context.
|
||||
segmenter.close()
|
||||
|
||||
def test_segment_succeeds_with_confidence_mask(self):
|
||||
# Creates segmenter.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
|
||||
# Load the cat image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
|
||||
)
|
||||
|
||||
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=base_options,
|
||||
output_category_mask=False,
|
||||
output_confidence_masks=True,
|
||||
)
|
||||
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
segmentation_result = segmenter.segment(test_image)
|
||||
confidence_masks = segmentation_result.confidence_masks
|
||||
|
||||
# Check if confidence mask shape is correct.
|
||||
self.assertLen(
|
||||
confidence_masks,
|
||||
21,
|
||||
'Number of confidence masks must match with number of categories.',
|
||||
)
|
||||
|
||||
# Loads ground truth segmentation file.
|
||||
expected_mask = self._load_segmentation_mask(_CAT_MASK)
|
||||
|
||||
self.assertTrue(
|
||||
_similar_to_float_mask(
|
||||
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
|
||||
)
|
||||
)
|
||||
|
||||
@parameterized.parameters((True, False), (False, True))
|
||||
def test_labels_succeeds(self, output_category_mask, output_confidence_masks):
|
||||
expected_labels = _EXPECTED_LABELS
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=base_options,
|
||||
output_category_mask=output_category_mask,
|
||||
output_confidence_masks=output_confidence_masks,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
# Performs image segmentation on the input.
|
||||
actual_labels = segmenter.labels
|
||||
self.assertListEqual(actual_labels, expected_labels)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback must be provided'
|
||||
):
|
||||
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
|
||||
pass
|
||||
|
||||
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||
def test_illegal_result_callback(self, running_mode):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=running_mode,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback should not be provided'
|
||||
):
|
||||
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
|
||||
pass
|
||||
|
||||
def test_calling_segment_for_video_in_image_mode(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
segmenter.segment_for_video(self.test_image, 0)
|
||||
|
||||
def test_calling_segment_async_in_image_mode(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
segmenter.segment_async(self.test_image, 0)
|
||||
|
||||
def test_calling_segment_in_video_mode(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
segmenter.segment(self.test_image)
|
||||
|
||||
def test_calling_segment_async_in_video_mode(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
segmenter.segment_async(self.test_image, 0)
|
||||
|
||||
def test_segment_for_video_with_out_of_order_timestamp(self):
  """A non-increasing timestamp in VIDEO mode raises ValueError."""
  video_mode_options = _ImageSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.VIDEO,
  )
  with _ImageSegmenter.create_from_options(video_mode_options) as segmenter:
    unused_result = segmenter.segment_for_video(self.test_image, 1)
    with self.assertRaisesRegex(
        ValueError, r'Input timestamp must be monotonically increasing'
    ):
      # Timestamp 0 is earlier than the previously seen timestamp 1.
      segmenter.segment_for_video(self.test_image, 0)

def test_segment_for_video_in_category_mask_mode(self):
  """Category masks produced in VIDEO mode match the ground-truth mask."""
  video_mode_options = _ImageSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      output_category_mask=True,
      output_confidence_masks=False,
      running_mode=_RUNNING_MODE.VIDEO,
  )
  with _ImageSegmenter.create_from_options(video_mode_options) as segmenter:
    for frame_timestamp_ms in range(0, 300, 30):
      result = segmenter.segment_for_video(self.test_image, frame_timestamp_ms)
      self.assertTrue(
          _similar_to_uint8_mask(result.category_mask, self.test_seg_image),
          (
              'Number of pixels in the candidate mask differing from that of'
              f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
          ),
      )
|
||||
def test_segment_for_video_in_confidence_mask_mode(self):
  """Confidence masks produced in VIDEO mode match the cat ground truth."""
  # Load the cat image.
  cat_image = _Image.create_from_file(
      test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
  )

  video_mode_options = _ImageSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.VIDEO,
      output_category_mask=False,
      output_confidence_masks=True,
  )
  with _ImageSegmenter.create_from_options(video_mode_options) as segmenter:
    for frame_timestamp_ms in range(0, 300, 30):
      result = segmenter.segment_for_video(cat_image, frame_timestamp_ms)
      confidence_masks = result.confidence_masks

      # One confidence mask is expected per model category.
      self.assertLen(
          confidence_masks,
          21,
          'Number of confidence masks must match with number of categories.',
      )

      # Channel 8 is expected to match the cat ground-truth mask.
      expected_mask = self._load_segmentation_mask(_CAT_MASK)
      self.assertTrue(
          _similar_to_float_mask(
              confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
          )
      )
|
||||
def test_calling_segment_in_live_stream_mode(self):
  """segment must be rejected by a segmenter created in LIVE_STREAM mode."""
  live_stream_options = _ImageSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=mock.MagicMock(),
  )
  with _ImageSegmenter.create_from_options(live_stream_options) as segmenter:
    with self.assertRaisesRegex(
        ValueError, r'not initialized with the image mode'
    ):
      segmenter.segment(self.test_image)

def test_calling_segment_for_video_in_live_stream_mode(self):
  """segment_for_video must be rejected in LIVE_STREAM mode."""
  live_stream_options = _ImageSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=mock.MagicMock(),
  )
  with _ImageSegmenter.create_from_options(live_stream_options) as segmenter:
    with self.assertRaisesRegex(
        ValueError, r'not initialized with the video mode'
    ):
      segmenter.segment_for_video(self.test_image, 0)

def test_segment_async_calls_with_illegal_timestamp(self):
  """A non-increasing timestamp passed to segment_async raises ValueError."""
  live_stream_options = _ImageSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=mock.MagicMock(),
  )
  with _ImageSegmenter.create_from_options(live_stream_options) as segmenter:
    segmenter.segment_async(self.test_image, 100)
    with self.assertRaisesRegex(
        ValueError, r'Input timestamp must be monotonically increasing'
    ):
      # 0 is earlier than the previously submitted timestamp 100.
      segmenter.segment_async(self.test_image, 0)
|
||||
def test_segment_async_calls_in_category_mask_mode(self):
  """LIVE_STREAM category-mask results are correct and timestamps increase.

  The result callback verifies the category mask against the ground truth
  and that callback timestamps are strictly monotonically increasing.
  """
  observed_timestamp_ms = -1

  def check_result(
      result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
  ):
    # Fix: update the closed-over timestamp via `nonlocal`. The original
    # wrote `self.observed_timestamp_ms` instead, so the assertLess below
    # always compared against the initial -1 and never actually verified
    # monotonicity between consecutive callbacks.
    nonlocal observed_timestamp_ms
    # Get the output category mask.
    category_mask = result.category_mask
    self.assertEqual(output_image.width, self.test_image.width)
    self.assertEqual(output_image.height, self.test_image.height)
    self.assertEqual(output_image.width, self.test_seg_image.width)
    self.assertEqual(output_image.height, self.test_seg_image.height)
    self.assertTrue(
        _similar_to_uint8_mask(category_mask, self.test_seg_image),
        (
            'Number of pixels in the candidate mask differing from that of'
            f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
        ),
    )
    self.assertLess(observed_timestamp_ms, timestamp_ms)
    observed_timestamp_ms = timestamp_ms

  options = _ImageSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      output_category_mask=True,
      output_confidence_masks=False,
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      result_callback=check_result,
  )
  with _ImageSegmenter.create_from_options(options) as segmenter:
    for timestamp in range(0, 300, 30):
      segmenter.segment_async(self.test_image, timestamp)
|
||||
def test_segment_async_calls_in_confidence_mask_mode(self):
  """LIVE_STREAM confidence-mask results are correct and timestamps increase.

  The result callback verifies mask count, output dimensions, similarity to
  the ground truth, and that callback timestamps strictly increase.
  """
  # Load the cat image.
  test_image = _Image.create_from_file(
      test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
  )

  # Loads ground truth segmentation file.
  expected_mask = self._load_segmentation_mask(_CAT_MASK)
  observed_timestamp_ms = -1

  def check_result(
      result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
  ):
    # Fix: update the closed-over timestamp via `nonlocal`. The original
    # wrote `self.observed_timestamp_ms` instead, so the assertLess below
    # always compared against the initial -1 and never actually verified
    # monotonicity between consecutive callbacks.
    nonlocal observed_timestamp_ms
    confidence_masks = result.confidence_masks

    # Check if confidence mask shape is correct.
    self.assertLen(
        confidence_masks,
        21,
        'Number of confidence masks must match with number of categories.',
    )
    self.assertEqual(output_image.width, test_image.width)
    self.assertEqual(output_image.height, test_image.height)
    # Channel 8 is expected to match the cat ground-truth mask.
    self.assertTrue(
        _similar_to_float_mask(
            confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
        )
    )
    self.assertLess(observed_timestamp_ms, timestamp_ms)
    observed_timestamp_ms = timestamp_ms

  options = _ImageSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM,
      output_category_mask=False,
      output_confidence_masks=True,
      result_callback=check_result,
  )
  with _ImageSegmenter.create_from_options(options) as segmenter:
    for timestamp in range(0, 300, 30):
      segmenter.segment_async(test_image, timestamp)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
+341
@@ -0,0 +1,341 @@
|
||||
# Copyright 2023 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for interactive segmenter."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import image_frame
|
||||
from mediapipe.tasks.python.components.containers import keypoint as keypoint_module
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import interactive_segmenter
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
|
||||
InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Image = image_module.Image
|
||||
_ImageFormat = image_frame.ImageFormat
|
||||
_NormalizedKeypoint = keypoint_module.NormalizedKeypoint
|
||||
_Rect = rect.Rect
|
||||
_InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter
|
||||
_InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions
|
||||
_RegionOfInterest = interactive_segmenter.RegionOfInterest
|
||||
_Format = interactive_segmenter.RegionOfInterest.Format
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_MODEL_FILE = 'ptm_512_hdt_ptm_woid.tflite'
|
||||
_CATS_AND_DOGS = 'cats_and_dogs.jpg'
|
||||
_CATS_AND_DOGS_MASK_DOG_1 = 'cats_and_dogs_mask_dog1.png'
|
||||
_CATS_AND_DOGS_MASK_DOG_2 = 'cats_and_dogs_mask_dog2.png'
|
||||
_MASK_MAGNIFICATION_FACTOR = 255
|
||||
_MASK_SIMILARITY_THRESHOLD = 0.97
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
||||
|
||||
def _calculate_soft_iou(m1, m2):
|
||||
intersection_sum = np.sum(m1 * m2)
|
||||
union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum
|
||||
|
||||
if union_sum > 0:
|
||||
return intersection_sum / union_sum
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold):
  """Returns True when the masks share a shape and their soft IoU beats the threshold.

  The expected mask is stored as uint8 (0..255) and is rescaled to [0, 1]
  before comparison.
  """
  actual = actual_mask.numpy_view()
  expected = expected_mask.numpy_view() / 255.0
  if actual.shape != expected.shape:
    return False
  return _calculate_soft_iou(actual, expected) > similarity_threshold
|
||||
|
||||
def _similar_to_uint8_mask(actual_mask, expected_mask, similarity_threshold):
  """Returns True when enough pixels of the scaled actual mask equal the expected mask.

  A pixel is consistent when `actual * _MASK_MAGNIFICATION_FACTOR` equals the
  expected pixel; the masks match when the consistent fraction is at least
  `similarity_threshold`.
  """
  actual_mask_pixels = actual_mask.numpy_view().flatten()
  expected_mask_pixels = expected_mask.numpy_view().flatten()

  num_pixels = len(expected_mask_pixels)
  # Vectorized replacement for the original per-pixel Python loop. Widen the
  # dtype first so the uint8 multiplication cannot wrap around.
  consistent_pixels = np.sum(
      actual_mask_pixels.astype(np.uint32) * _MASK_MAGNIFICATION_FACTOR
      == expected_mask_pixels
  )

  return consistent_pixels / num_pixels >= similarity_threshold
|
||||
|
||||
class ModelFileType(enum.Enum):
  """How the model is handed to the task: raw bytes or a file path."""

  FILE_CONTENT = 1
  FILE_NAME = 2
|
||||
|
||||
class InteractiveSegmenterTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
  """Loads the test image, the reference mask, and the model path."""
  super().setUp()
  image_path = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, _CATS_AND_DOGS)
  )
  self.test_image = _Image.create_from_file(image_path)
  # Ground-truth segmentation for the first dog.
  self.test_seg_image = self._load_segmentation_mask(_CATS_AND_DOGS_MASK_DOG_1)
  self.model_path = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
  )

def _load_segmentation_mask(self, file_path: str):
  """Reads a grayscale ground-truth mask and wraps it as a GRAY8 image."""
  mask_path = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, file_path)
  )
  gray_pixels = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
  return _Image(_ImageFormat.GRAY8, gray_pixels)
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
  """A valid model path yields a working segmenter via create_from_model_path."""
  with _InteractiveSegmenter.create_from_model_path(
      self.model_path
  ) as segmenter:
    self.assertIsInstance(segmenter, _InteractiveSegmenter)

def test_create_from_options_succeeds_with_valid_model_path(self):
  """Options carrying a model path yield a working segmenter."""
  options = _InteractiveSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path)
  )
  with _InteractiveSegmenter.create_from_options(options) as segmenter:
    self.assertIsInstance(segmenter, _InteractiveSegmenter)

def test_create_from_options_fails_with_invalid_model_path(self):
  """A nonexistent model path raises RuntimeError at creation time."""
  with self.assertRaisesRegex(
      RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
  ):
    bad_base_options = _BaseOptions(
        model_asset_path='/path/to/invalid/model.tflite'
    )
    _InteractiveSegmenter.create_from_options(
        _InteractiveSegmenterOptions(base_options=bad_base_options)
    )
|
||||
def test_create_from_options_succeeds_with_valid_model_content(self):
  """Options carrying raw model bytes yield a working segmenter."""
  with open(self.model_path, 'rb') as f:
    base_options = _BaseOptions(model_asset_buffer=f.read())
    options = _InteractiveSegmenterOptions(base_options=base_options)
    segmenter = _InteractiveSegmenter.create_from_options(options)
    self.assertIsInstance(segmenter, _InteractiveSegmenter)
    # Fix: the segmenter was created outside a `with` block and was never
    # closed, leaking its underlying task resources.
    segmenter.close()
|
||||
@parameterized.parameters(
    (
        ModelFileType.FILE_NAME,
        _RegionOfInterest.Format.KEYPOINT,
        _NormalizedKeypoint(0.44, 0.7),
        _CATS_AND_DOGS_MASK_DOG_1,
        0.84,
    ),
    (
        ModelFileType.FILE_CONTENT,
        _RegionOfInterest.Format.KEYPOINT,
        _NormalizedKeypoint(0.44, 0.7),
        _CATS_AND_DOGS_MASK_DOG_1,
        0.84,
    ),
    (
        ModelFileType.FILE_NAME,
        _RegionOfInterest.Format.KEYPOINT,
        _NormalizedKeypoint(0.66, 0.66),
        _CATS_AND_DOGS_MASK_DOG_2,
        _MASK_SIMILARITY_THRESHOLD,
    ),
    (
        ModelFileType.FILE_CONTENT,
        _RegionOfInterest.Format.KEYPOINT,
        _NormalizedKeypoint(0.66, 0.66),
        _CATS_AND_DOGS_MASK_DOG_2,
        _MASK_SIMILARITY_THRESHOLD,
    ),
)
def test_segment_succeeds_with_category_mask(
    self,
    model_file_type,
    roi_format,
    keypoint,
    output_mask,
    similarity_threshold,
):
  """Category masks around a keypoint match ground truth for both model sources."""
  # Build base options from either a file path or raw model bytes.
  if model_file_type is ModelFileType.FILE_NAME:
    base_options = _BaseOptions(model_asset_path=self.model_path)
  elif model_file_type is ModelFileType.FILE_CONTENT:
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
  else:
    raise ValueError('model_file_type is invalid.')  # Should never happen.

  segmenter = _InteractiveSegmenter.create_from_options(
      _InteractiveSegmenterOptions(
          base_options=base_options,
          output_category_mask=True,
          output_confidence_masks=False,
      )
  )

  # Segment around the user-provided keypoint.
  roi = _RegionOfInterest(format=roi_format, keypoint=keypoint)
  category_mask = segmenter.segment(self.test_image, roi).category_mask

  # The category mask must be uint8-valued.
  self.assertEqual(category_mask.numpy_view().flatten().dtype, np.uint8)

  # Compare against the ground-truth segmentation.
  expected = self._load_segmentation_mask(output_mask)
  self.assertTrue(
      _similar_to_uint8_mask(category_mask, expected, similarity_threshold),
      (
          'Number of pixels in the candidate mask differing from that of the'
          f' ground truth mask exceeds {similarity_threshold}.'
      ),
  )

  # Closes the segmenter explicitly when the segmenter is not used in
  # a context.
  segmenter.close()
|
||||
@parameterized.parameters(
    (
        _RegionOfInterest.Format.KEYPOINT,
        _NormalizedKeypoint(0.44, 0.7),
        _CATS_AND_DOGS_MASK_DOG_1,
        0.84,
    ),
    (
        _RegionOfInterest.Format.KEYPOINT,
        _NormalizedKeypoint(0.66, 0.66),
        _CATS_AND_DOGS_MASK_DOG_2,
        _MASK_SIMILARITY_THRESHOLD,
    ),
)
def test_segment_succeeds_with_confidence_mask(
    self, roi_format, keypoint, output_mask, similarity_threshold
):
  """Confidence masks around a keypoint match the ground-truth mask."""
  roi = _RegionOfInterest(format=roi_format, keypoint=keypoint)

  # Run segmentation on the model in CONFIDENCE_MASK mode.
  options = _InteractiveSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      output_category_mask=False,
      output_confidence_masks=True,
  )

  with _InteractiveSegmenter.create_from_options(options) as segmenter:
    confidence_masks = segmenter.segment(self.test_image, roi).confidence_masks

    # One confidence mask is expected per model category.
    self.assertLen(
        confidence_masks,
        2,
        'Number of confidence masks must match with number of categories.',
    )

    # Compare the foreground channel against the ground-truth segmentation.
    expected_mask = self._load_segmentation_mask(output_mask)
    self.assertTrue(
        _similar_to_float_mask(
            confidence_masks[1], expected_mask, similarity_threshold
        )
    )
|
||||
def test_segment_succeeds_with_rotation(self):
  """Segmentation succeeds when the input is rotated via processing options."""
  roi = _RegionOfInterest(
      format=_RegionOfInterest.Format.KEYPOINT,
      keypoint=_NormalizedKeypoint(0.66, 0.66),
  )

  # Run segmentation on the model in CONFIDENCE_MASK mode.
  options = _InteractiveSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      output_category_mask=False,
      output_confidence_masks=True,
  )

  with _InteractiveSegmenter.create_from_options(options) as segmenter:
    rotation_options = _ImageProcessingOptions(rotation_degrees=-90)
    result = segmenter.segment(self.test_image, roi, rotation_options)

    # One confidence mask is expected per model category.
    self.assertLen(
        result.confidence_masks,
        2,
        'Number of confidence masks must match with number of categories.',
    )
|
||||
def test_segment_fails_with_roi_in_image_processing_options(self):
  """Passing a region-of-interest rect to this task raises ValueError."""
  roi = _RegionOfInterest(
      format=_RegionOfInterest.Format.KEYPOINT,
      keypoint=_NormalizedKeypoint(0.66, 0.66),
  )

  # Run segmentation on the model in CONFIDENCE_MASK mode.
  options = _InteractiveSegmenterOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      output_category_mask=False,
      output_confidence_masks=True,
  )

  with self.assertRaisesRegex(
      ValueError, "This task doesn't support region-of-interest."
  ):
    with _InteractiveSegmenter.create_from_options(options) as segmenter:
      # Cropping rects are not supported by the interactive segmenter.
      crop_options = _ImageProcessingOptions(
          _Rect(left=0.1, top=0, right=0.9, bottom=1)
      )
      segmenter.segment(self.test_image, roi, crop_options)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
+493
@@ -0,0 +1,493 @@
|
||||
# Copyright 2022 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for object detector."""
|
||||
|
||||
import enum
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module
|
||||
from mediapipe.tasks.python.components.containers import category as category_module
|
||||
from mediapipe.tasks.python.components.containers import detections as detections_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import object_detector
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Category = category_module.Category
|
||||
_BoundingBox = bounding_box_module.BoundingBox
|
||||
_Detection = detections_module.Detection
|
||||
_DetectionResult = detections_module.DetectionResult
|
||||
_Image = image_module.Image
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_ObjectDetector = object_detector.ObjectDetector
|
||||
_ObjectDetectorOptions = object_detector.ObjectDetectorOptions
|
||||
|
||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||
|
||||
_MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite'
|
||||
_NO_NMS_MODEL_FILE = 'efficientdet_lite0_fp16_no_nms.tflite'
|
||||
_IMAGE_FILE = 'cats_and_dogs.jpg'
|
||||
_EXPECTED_DETECTION_RESULT = _DetectionResult(
|
||||
detections=[
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=608,
|
||||
origin_y=164,
|
||||
width=381,
|
||||
height=432,
|
||||
),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.69921875,
|
||||
display_name=None,
|
||||
category_name='cat',
|
||||
)
|
||||
],
|
||||
),
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=57,
|
||||
origin_y=398,
|
||||
width=386,
|
||||
height=196,
|
||||
),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.65625,
|
||||
display_name=None,
|
||||
category_name='cat',
|
||||
)
|
||||
],
|
||||
),
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=256,
|
||||
origin_y=394,
|
||||
width=173,
|
||||
height=202,
|
||||
),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.51171875,
|
||||
display_name=None,
|
||||
category_name='cat',
|
||||
)
|
||||
],
|
||||
),
|
||||
_Detection(
|
||||
bounding_box=_BoundingBox(
|
||||
origin_x=360,
|
||||
origin_y=195,
|
||||
width=330,
|
||||
height=412,
|
||||
),
|
||||
categories=[
|
||||
_Category(
|
||||
index=None,
|
||||
score=0.48828125,
|
||||
display_name=None,
|
||||
category_name='cat',
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
_ALLOW_LIST = ['cat', 'dog']
|
||||
_DENY_LIST = ['cat']
|
||||
_SCORE_THRESHOLD = 0.3
|
||||
_MAX_RESULTS = 3
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
  """How the model is handed to the task: raw bytes or a file path."""

  FILE_CONTENT = 1
  FILE_NAME = 2
|
||||
|
||||
class ObjectDetectorTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
  """Loads the shared test image and the model path."""
  super().setUp()
  image_path = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)
  )
  self.test_image = _Image.create_from_file(image_path)
  self.model_path = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
  )

def test_create_from_file_succeeds_with_valid_model_path(self):
  """A valid model path yields a working detector via create_from_model_path."""
  with _ObjectDetector.create_from_model_path(self.model_path) as detector:
    self.assertIsInstance(detector, _ObjectDetector)

def test_create_from_options_succeeds_with_valid_model_path(self):
  """Options carrying a model path yield a working detector."""
  options = _ObjectDetectorOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path)
  )
  with _ObjectDetector.create_from_options(options) as detector:
    self.assertIsInstance(detector, _ObjectDetector)
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
  """A nonexistent model path raises RuntimeError at creation time."""
  with self.assertRaisesRegex(
      RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
  ):
    bad_base_options = _BaseOptions(
        model_asset_path='/path/to/invalid/model.tflite'
    )
    _ObjectDetector.create_from_options(
        _ObjectDetectorOptions(base_options=bad_base_options)
    )
|
||||
def test_create_from_options_succeeds_with_valid_model_content(self):
  """Options carrying raw model bytes yield a working detector."""
  with open(self.model_path, 'rb') as f:
    base_options = _BaseOptions(model_asset_buffer=f.read())
    options = _ObjectDetectorOptions(base_options=base_options)
    detector = _ObjectDetector.create_from_options(options)
    self.assertIsInstance(detector, _ObjectDetector)
    # Fix: the detector was created outside a `with` block and was never
    # closed, leaking its underlying task resources (test_detect closes its
    # detector in the same situation).
    detector.close()
|
||||
@parameterized.parameters(
    (ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTION_RESULT),
    (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT),
)
def test_detect(
    self, model_file_type, max_results, expected_detection_result
):
  """Detection on the test image matches the golden result."""
  # Build base options from either a file path or raw model bytes.
  if model_file_type is ModelFileType.FILE_NAME:
    base_options = _BaseOptions(model_asset_path=self.model_path)
  elif model_file_type is ModelFileType.FILE_CONTENT:
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
  else:
    raise ValueError('model_file_type is invalid.')  # Should never happen.

  detector = _ObjectDetector.create_from_options(
      _ObjectDetectorOptions(
          base_options=base_options, max_results=max_results
      )
  )

  # Performs object detection on the input and compares against the golden.
  self.assertEqual(
      detector.detect(self.test_image), expected_detection_result
  )

  # Closes the detector explicitly when the detector is not used in
  # a context.
  detector.close()
|
||||
@parameterized.parameters(
    (ModelFileType.FILE_NAME, 4, _EXPECTED_DETECTION_RESULT),
    (ModelFileType.FILE_CONTENT, 4, _EXPECTED_DETECTION_RESULT),
)
def test_detect_in_context(
    self, model_file_type, max_results, expected_detection_result
):
  """Detection inside a context manager matches the golden result."""
  # Build base options from either a file path or raw model bytes.
  if model_file_type is ModelFileType.FILE_NAME:
    base_options = _BaseOptions(model_asset_path=self.model_path)
  elif model_file_type is ModelFileType.FILE_CONTENT:
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
  else:
    raise ValueError('model_file_type is invalid.')  # Should never happen.

  options = _ObjectDetectorOptions(
      base_options=base_options, max_results=max_results
  )
  with _ObjectDetector.create_from_options(options) as detector:
    # Performs object detection on the input and compares against the golden.
    self.assertEqual(
        detector.detect(self.test_image), expected_detection_result
    )
|
||||
def test_score_threshold_option(self):
  """Every returned detection scores at least the configured threshold."""
  options = _ObjectDetectorOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      score_threshold=_SCORE_THRESHOLD,
  )
  with _ObjectDetector.create_from_options(options) as detector:
    for detection in detector.detect(self.test_image).detections:
      top_score = detection.categories[0].score
      self.assertGreaterEqual(
          top_score,
          _SCORE_THRESHOLD,
          f'Detection with score lower than threshold found. {detection}',
      )

def test_max_results_option(self):
  """The number of returned detections never exceeds max_results."""
  options = _ObjectDetectorOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      max_results=_MAX_RESULTS,
  )
  with _ObjectDetector.create_from_options(options) as detector:
    detections = detector.detect(self.test_image).detections
    self.assertLessEqual(
        len(detections), _MAX_RESULTS, 'Too many results returned.'
    )
|
||||
def test_allow_list_option(self):
  """Every returned label is drawn from the category allowlist."""
  options = _ObjectDetectorOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      category_allowlist=_ALLOW_LIST,
  )
  with _ObjectDetector.create_from_options(options) as detector:
    for detection in detector.detect(self.test_image).detections:
      label = detection.categories[0].category_name
      self.assertIn(
          label,
          _ALLOW_LIST,
          f'Label {label} found but not in label allow list',
      )

def test_deny_list_option(self):
  """No returned label appears in the category denylist."""
  options = _ObjectDetectorOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      category_denylist=_DENY_LIST,
  )
  with _ObjectDetector.create_from_options(options) as detector:
    for detection in detector.detect(self.test_image).detections:
      label = detection.categories[0].category_name
      self.assertNotIn(
          label, _DENY_LIST, f'Label {label} found but in deny list.'
      )
|
||||
def test_combined_allowlist_and_denylist(self):
  """Supplying both an allowlist and a denylist is rejected."""
  with self.assertRaisesRegex(
      ValueError,
      r'`category_allowlist` and `category_denylist` are mutually '
      r'exclusive options.',
  ):
    conflicting_options = _ObjectDetectorOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        category_allowlist=['foo'],
        category_denylist=['bar'],
    )
    with _ObjectDetector.create_from_options(
        conflicting_options
    ) as unused_detector:
      pass

def test_empty_detection_outputs_with_in_model_nms(self):
  """An unattainable score threshold yields no detections."""
  options = _ObjectDetectorOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      score_threshold=1,
  )
  with _ObjectDetector.create_from_options(options) as detector:
    self.assertEmpty(detector.detect(self.test_image).detections)
|
||||
def test_empty_detection_outputs_without_in_model_nms(self):
  """An unattainable score threshold yields no detections (no-NMS model)."""
  no_nms_model_path = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, _NO_NMS_MODEL_FILE)
  )
  options = _ObjectDetectorOptions(
      base_options=_BaseOptions(model_asset_path=no_nms_model_path),
      score_threshold=1,
  )
  with _ObjectDetector.create_from_options(options) as detector:
    self.assertEmpty(detector.detect(self.test_image).detections)

def test_missing_result_callback(self):
  """LIVE_STREAM mode without a result callback is rejected."""
  options = _ObjectDetectorOptions(
      base_options=_BaseOptions(model_asset_path=self.model_path),
      running_mode=_RUNNING_MODE.LIVE_STREAM,
  )
  with self.assertRaisesRegex(
      ValueError, r'result callback must be provided'
  ):
    with _ObjectDetector.create_from_options(options) as unused_detector:
      pass
|
||||
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||
def test_illegal_result_callback(self, running_mode):
|
||||
options = _ObjectDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=running_mode,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback should not be provided'
|
||||
):
|
||||
with _ObjectDetector.create_from_options(options) as unused_detector:
|
||||
pass
|
||||
|
||||
def test_calling_detect_for_video_in_image_mode(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
detector.detect_for_video(self.test_image, 0)
|
||||
|
||||
def test_calling_detect_async_in_image_mode(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
detector.detect_async(self.test_image, 0)
|
||||
|
||||
def test_calling_detect_in_video_mode(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
detector.detect(self.test_image)
|
||||
|
||||
def test_calling_detect_async_in_video_mode(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
detector.detect_async(self.test_image, 0)
|
||||
|
||||
def test_detect_for_video_with_out_of_order_timestamp(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
unused_result = detector.detect_for_video(self.test_image, 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
detector.detect_for_video(self.test_image, 0)
|
||||
|
||||
# TODO: Tests how `detect_for_video` handles the temporal data
|
||||
# with a real video.
|
||||
def test_detect_for_video(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
max_results=4,
|
||||
)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
for timestamp in range(0, 300, 30):
|
||||
detection_result = detector.detect_for_video(self.test_image, timestamp)
|
||||
self.assertEqual(detection_result, _EXPECTED_DETECTION_RESULT)
|
||||
|
||||
def test_calling_detect_in_live_stream_mode(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
detector.detect(self.test_image)
|
||||
|
||||
def test_calling_detect_for_video_in_live_stream_mode(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
detector.detect_for_video(self.test_image, 0)
|
||||
|
||||
def test_detect_async_calls_with_illegal_timestamp(self):
|
||||
options = _ObjectDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
max_results=4,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _ObjectDetector.create_from_options(options) as detector:
|
||||
detector.detect_async(self.test_image, 100)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
detector.detect_async(self.test_image, 0)
|
||||
|
||||
@parameterized.parameters(
|
||||
(0, _EXPECTED_DETECTION_RESULT), (1, _DetectionResult(detections=[]))
|
||||
)
|
||||
def test_detect_async_calls(self, threshold, expected_result):
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(
|
||||
result: _DetectionResult, output_image: _Image, timestamp_ms: int
|
||||
):
|
||||
self.assertEqual(result, expected_result)
|
||||
self.assertTrue(
|
||||
np.array_equal(
|
||||
output_image.numpy_view(), self.test_image.numpy_view()
|
||||
)
|
||||
)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _ObjectDetectorOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
max_results=4,
|
||||
score_threshold=threshold,
|
||||
result_callback=check_result,
|
||||
)
|
||||
detector = _ObjectDetector.create_from_options(options)
|
||||
for timestamp in range(0, 300, 30):
|
||||
detector.detect_async(self.test_image, timestamp)
|
||||
detector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
+518
@@ -0,0 +1,518 @@
|
||||
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for pose landmarker."""
|
||||
|
||||
import enum
|
||||
from typing import List
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from google.protobuf import text_format
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
|
||||
from mediapipe.tasks.python.components.containers import landmark as landmark_module
|
||||
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
|
||||
from mediapipe.tasks.python.components.containers import rect as rect_module
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.test import test_utils
|
||||
from mediapipe.tasks.python.vision import pose_landmarker
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
|
||||
|
||||
PoseLandmarkerResult = pose_landmarker.PoseLandmarkerResult
|
||||
_LandmarksDetectionResultProto = (
|
||||
landmarks_detection_result_pb2.LandmarksDetectionResult
|
||||
)
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Rect = rect_module.Rect
|
||||
_Landmark = landmark_module.Landmark
|
||||
_NormalizedLandmark = landmark_module.NormalizedLandmark
|
||||
_LandmarksDetectionResult = (
|
||||
landmark_detection_result_module.LandmarksDetectionResult
|
||||
)
|
||||
_Image = image_module.Image
|
||||
_PoseLandmarker = pose_landmarker.PoseLandmarker
|
||||
_PoseLandmarkerOptions = pose_landmarker.PoseLandmarkerOptions
|
||||
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
|
||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
|
||||
_POSE_LANDMARKER_BUNDLE_ASSET_FILE = 'pose_landmarker.task'
|
||||
_BURGER_IMAGE = 'burger.jpg'
|
||||
_POSE_IMAGE = 'pose.jpg'
|
||||
_POSE_LANDMARKS = 'pose_landmarks.pbtxt'
|
||||
_LANDMARKS_MARGIN = 0.03
|
||||
|
||||
|
||||
def _get_expected_pose_landmarker_result(
|
||||
file_path: str,
|
||||
) -> PoseLandmarkerResult:
|
||||
landmarks_detection_result_file_path = test_utils.get_test_data_path(
|
||||
file_path
|
||||
)
|
||||
with open(landmarks_detection_result_file_path, 'rb') as f:
|
||||
landmarks_detection_result_proto = _LandmarksDetectionResultProto()
|
||||
# Use this if a .pb file is available.
|
||||
# landmarks_detection_result_proto.ParseFromString(f.read())
|
||||
text_format.Parse(f.read(), landmarks_detection_result_proto)
|
||||
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
|
||||
landmarks_detection_result_proto
|
||||
)
|
||||
return PoseLandmarkerResult(
|
||||
pose_landmarks=[landmarks_detection_result.landmarks],
|
||||
pose_world_landmarks=[],
|
||||
)
|
||||
|
||||
|
||||
class ModelFileType(enum.Enum):
|
||||
FILE_CONTENT = 1
|
||||
FILE_NAME = 2
|
||||
|
||||
|
||||
class PoseLandmarkerTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_POSE_IMAGE)
|
||||
)
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
_POSE_LANDMARKER_BUNDLE_ASSET_FILE
|
||||
)
|
||||
|
||||
def _expect_pose_landmarks_correct(
|
||||
self, actual_landmarks, expected_landmarks, margin
|
||||
):
|
||||
# Expects to have the same number of poses detected.
|
||||
self.assertLen(actual_landmarks, len(expected_landmarks))
|
||||
|
||||
for i, _ in enumerate(actual_landmarks):
|
||||
for j, elem in enumerate(actual_landmarks[i]):
|
||||
self.assertAlmostEqual(elem.x, expected_landmarks[i][j].x, delta=margin)
|
||||
self.assertAlmostEqual(elem.y, expected_landmarks[i][j].y, delta=margin)
|
||||
|
||||
def _expect_pose_landmarker_results_correct(
|
||||
self,
|
||||
actual_result: PoseLandmarkerResult,
|
||||
expected_result: PoseLandmarkerResult,
|
||||
output_segmentation_masks: bool,
|
||||
margin: float,
|
||||
):
|
||||
self._expect_pose_landmarks_correct(
|
||||
actual_result.pose_landmarks, expected_result.pose_landmarks, margin
|
||||
)
|
||||
if output_segmentation_masks:
|
||||
self.assertIsInstance(actual_result.segmentation_masks, List)
|
||||
for _, mask in enumerate(actual_result.segmentation_masks):
|
||||
self.assertIsInstance(mask, _Image)
|
||||
else:
|
||||
self.assertIsNone(actual_result.segmentation_masks)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
with _PoseLandmarker.create_from_model_path(self.model_path) as landmarker:
|
||||
self.assertIsInstance(landmarker, _PoseLandmarker)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||
# Creates with options containing model file successfully.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _PoseLandmarkerOptions(base_options=base_options)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
self.assertIsInstance(landmarker, _PoseLandmarker)
|
||||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
# Invalid empty model path.
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
|
||||
):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite'
|
||||
)
|
||||
options = _PoseLandmarkerOptions(base_options=base_options)
|
||||
_PoseLandmarker.create_from_options(options)
|
||||
|
||||
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||
# Creates with options containing model content successfully.
|
||||
with open(self.model_path, 'rb') as f:
|
||||
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||
options = _PoseLandmarkerOptions(base_options=base_options)
|
||||
landmarker = _PoseLandmarker.create_from_options(options)
|
||||
self.assertIsInstance(landmarker, _PoseLandmarker)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
False,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
False,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
True,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
True,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
)
|
||||
def test_detect(
|
||||
self,
|
||||
model_file_type,
|
||||
output_segmentation_masks,
|
||||
expected_detection_result,
|
||||
):
|
||||
# Creates pose landmarker.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
output_segmentation_masks=output_segmentation_masks,
|
||||
)
|
||||
landmarker = _PoseLandmarker.create_from_options(options)
|
||||
|
||||
# Performs pose landmarks detection on the input.
|
||||
detection_result = landmarker.detect(self.test_image)
|
||||
|
||||
# Comparing results.
|
||||
self._expect_pose_landmarker_results_correct(
|
||||
detection_result,
|
||||
expected_detection_result,
|
||||
output_segmentation_masks,
|
||||
_LANDMARKS_MARGIN,
|
||||
)
|
||||
# Closes the pose landmarker explicitly when the pose landmarker is not used
|
||||
# in a context.
|
||||
landmarker.close()
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
False,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
False,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_NAME,
|
||||
True,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(
|
||||
ModelFileType.FILE_CONTENT,
|
||||
True,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
)
|
||||
def test_detect_in_context(
|
||||
self,
|
||||
model_file_type,
|
||||
output_segmentation_masks,
|
||||
expected_detection_result,
|
||||
):
|
||||
# Creates pose landmarker.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_content = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
output_segmentation_masks=output_segmentation_masks,
|
||||
)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
# Performs pose landmarks detection on the input.
|
||||
detection_result = landmarker.detect(self.test_image)
|
||||
|
||||
# Comparing results.
|
||||
self._expect_pose_landmarker_results_correct(
|
||||
detection_result,
|
||||
expected_detection_result,
|
||||
output_segmentation_masks,
|
||||
_LANDMARKS_MARGIN,
|
||||
)
|
||||
|
||||
def test_detect_fails_with_region_of_interest(self):
|
||||
# Creates pose landmarker.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _PoseLandmarkerOptions(base_options=base_options)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "This task doesn't support region-of-interest."
|
||||
):
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
# Set the `region_of_interest` parameter using `ImageProcessingOptions`.
|
||||
image_processing_options = _ImageProcessingOptions(
|
||||
region_of_interest=_Rect(0, 0, 1, 1)
|
||||
)
|
||||
# Attempt to perform pose landmarks detection on the cropped input.
|
||||
landmarker.detect(self.test_image, image_processing_options)
|
||||
|
||||
def test_empty_detection_outputs(self):
|
||||
# Creates pose landmarker.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
options = _PoseLandmarkerOptions(base_options=base_options)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
# Load an image with no poses.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(_BURGER_IMAGE)
|
||||
)
|
||||
# Performs pose landmarks detection on the input.
|
||||
detection_result = landmarker.detect(test_image)
|
||||
# Comparing results.
|
||||
self.assertEmpty(detection_result.pose_landmarks)
|
||||
self.assertEmpty(detection_result.pose_world_landmarks)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback must be provided'
|
||||
):
|
||||
with _PoseLandmarker.create_from_options(options) as unused_landmarker:
|
||||
pass
|
||||
|
||||
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||
def test_illegal_result_callback(self, running_mode):
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=running_mode,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback should not be provided'
|
||||
):
|
||||
with _PoseLandmarker.create_from_options(options) as unused_landmarker:
|
||||
pass
|
||||
|
||||
def test_calling_detect_for_video_in_image_mode(self):
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
landmarker.detect_for_video(self.test_image, 0)
|
||||
|
||||
def test_calling_detect_async_in_image_mode(self):
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
landmarker.detect_async(self.test_image, 0)
|
||||
|
||||
def test_calling_detect_in_video_mode(self):
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
landmarker.detect(self.test_image)
|
||||
|
||||
def test_calling_detect_async_in_video_mode(self):
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
landmarker.detect_async(self.test_image, 0)
|
||||
|
||||
def test_detect_for_video_with_out_of_order_timestamp(self):
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
unused_result = landmarker.detect_for_video(self.test_image, 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
landmarker.detect_for_video(self.test_image, 0)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
_POSE_IMAGE,
|
||||
0,
|
||||
False,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(
|
||||
_POSE_IMAGE,
|
||||
0,
|
||||
True,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(_BURGER_IMAGE, 0, False, PoseLandmarkerResult([], [])),
|
||||
)
|
||||
def test_detect_for_video(
|
||||
self, image_path, rotation, output_segmentation_masks, expected_result
|
||||
):
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(image_path)
|
||||
)
|
||||
# Set rotation parameters using ImageProcessingOptions.
|
||||
image_processing_options = _ImageProcessingOptions(
|
||||
rotation_degrees=rotation
|
||||
)
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
output_segmentation_masks=output_segmentation_masks,
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
for timestamp in range(0, 300, 30):
|
||||
result = landmarker.detect_for_video(
|
||||
test_image, timestamp, image_processing_options
|
||||
)
|
||||
if result.pose_landmarks:
|
||||
self._expect_pose_landmarker_results_correct(
|
||||
result,
|
||||
expected_result,
|
||||
output_segmentation_masks,
|
||||
_LANDMARKS_MARGIN,
|
||||
)
|
||||
else:
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_calling_detect_in_live_stream_mode(self):
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
landmarker.detect(self.test_image)
|
||||
|
||||
def test_calling_detect_for_video_in_live_stream_mode(self):
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
landmarker.detect_for_video(self.test_image, 0)
|
||||
|
||||
def test_detect_async_calls_with_illegal_timestamp(self):
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
landmarker.detect_async(self.test_image, 100)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
landmarker.detect_async(self.test_image, 0)
|
||||
|
||||
@parameterized.parameters(
|
||||
(
|
||||
_POSE_IMAGE,
|
||||
0,
|
||||
False,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(
|
||||
_POSE_IMAGE,
|
||||
0,
|
||||
True,
|
||||
_get_expected_pose_landmarker_result(_POSE_LANDMARKS),
|
||||
),
|
||||
(_BURGER_IMAGE, 0, False, PoseLandmarkerResult([], [])),
|
||||
)
|
||||
def test_detect_async_calls(
|
||||
self, image_path, rotation, output_segmentation_masks, expected_result
|
||||
):
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(image_path)
|
||||
)
|
||||
# Set rotation parameters using ImageProcessingOptions.
|
||||
image_processing_options = _ImageProcessingOptions(
|
||||
rotation_degrees=rotation
|
||||
)
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(
|
||||
result: PoseLandmarkerResult, output_image: _Image, timestamp_ms: int
|
||||
):
|
||||
if result.pose_landmarks:
|
||||
self._expect_pose_landmarker_results_correct(
|
||||
result,
|
||||
expected_result,
|
||||
output_segmentation_masks,
|
||||
_LANDMARKS_MARGIN,
|
||||
)
|
||||
else:
|
||||
self.assertEqual(result, expected_result)
|
||||
self.assertTrue(
|
||||
np.array_equal(output_image.numpy_view(), test_image.numpy_view())
|
||||
)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _PoseLandmarkerOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
output_segmentation_masks=output_segmentation_masks,
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=check_result,
|
||||
)
|
||||
with _PoseLandmarker.create_from_options(options) as landmarker:
|
||||
for timestamp in range(0, 300, 30):
|
||||
landmarker.detect_async(test_image, timestamp, image_processing_options)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
Reference in New Issue
Block a user