hand
This commit is contained in:
@@ -0,0 +1,396 @@
|
||||
# Copyright 2020 The MediaPipe Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for mediapipe.python.solution_base."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from google.protobuf import text_format
|
||||
from mediapipe.framework import calculator_pb2
|
||||
from mediapipe.framework.formats import detection_pb2
|
||||
from mediapipe.python import solution_base
|
||||
from mediapipe.python.solution_base import PacketDataType
|
||||
|
||||
CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG = """
|
||||
input_stream: 'image_in'
|
||||
output_stream: 'image_out'
|
||||
node {
|
||||
name: 'ImageTransformation'
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_in'
|
||||
output_stream: 'IMAGE:image_out'
|
||||
options: {
|
||||
[mediapipe.ImageTransformationCalculatorOptions.ext] {
|
||||
output_width: 10
|
||||
output_height: 10
|
||||
}
|
||||
}
|
||||
node_options: {
|
||||
[type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
|
||||
output_width: 10
|
||||
output_height: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class SolutionBaseTest(parameterized.TestCase):
|
||||
|
||||
def test_invalid_initialization_arguments(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Must provide exactly one of \'binary_graph_path\' or \'graph_config\'.'
|
||||
):
|
||||
solution_base.SolutionBase()
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Must provide exactly one of \'binary_graph_path\' or \'graph_config\'.'
|
||||
):
|
||||
solution_base.SolutionBase(
|
||||
graph_config=calculator_pb2.CalculatorGraphConfig(),
|
||||
binary_graph_path='/tmp/no_such.binarypb')
|
||||
|
||||
@parameterized.named_parameters(('no_graph_input_output_stream', """
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'in'
|
||||
output_stream: 'out'
|
||||
}
|
||||
""", RuntimeError, 'does not have a corresponding output stream.'),
|
||||
('calcualtor_io_mismatch', """
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'in'
|
||||
input_stream: 'in2'
|
||||
output_stream: 'out'
|
||||
}
|
||||
""", ValueError, 'must use matching tags and indexes.'),
|
||||
('unkown_registered_stream_type_name', """
|
||||
input_stream: 'in'
|
||||
output_stream: 'out'
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'in'
|
||||
output_stream: 'out'
|
||||
}
|
||||
""", RuntimeError, 'Unable to find the type for stream \"in\".'))
|
||||
def test_invalid_config(self, text_config, error_type, error_message):
|
||||
config_proto = text_format.Parse(text_config,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
with self.assertRaisesRegex(error_type, error_message):
|
||||
solution_base.SolutionBase(graph_config=config_proto)
|
||||
|
||||
def test_valid_input_data_type_proto(self):
|
||||
text_config = """
|
||||
input_stream: 'input_detections'
|
||||
output_stream: 'output_detections'
|
||||
node {
|
||||
calculator: 'DetectionUniqueIdCalculator'
|
||||
input_stream: 'DETECTION_LIST:input_detections'
|
||||
output_stream: 'DETECTION_LIST:output_detections'
|
||||
}
|
||||
"""
|
||||
config_proto = text_format.Parse(text_config,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
with solution_base.SolutionBase(graph_config=config_proto) as solution:
|
||||
input_detections = detection_pb2.DetectionList()
|
||||
detection_1 = input_detections.detection.add()
|
||||
text_format.Parse('score: 0.5', detection_1)
|
||||
detection_2 = input_detections.detection.add()
|
||||
text_format.Parse('score: 0.8', detection_2)
|
||||
results = solution.process({'input_detections': input_detections})
|
||||
self.assertTrue(hasattr(results, 'output_detections'))
|
||||
self.assertLen(results.output_detections.detection, 2)
|
||||
expected_detection_1 = detection_pb2.Detection()
|
||||
text_format.Parse('score: 0.5, detection_id: 1', expected_detection_1)
|
||||
expected_detection_2 = detection_pb2.Detection()
|
||||
text_format.Parse('score: 0.8, detection_id: 2', expected_detection_2)
|
||||
self.assertEqual(results.output_detections.detection[0],
|
||||
expected_detection_1)
|
||||
self.assertEqual(results.output_detections.detection[1],
|
||||
expected_detection_2)
|
||||
|
||||
def test_invalid_input_data_type_proto_vector(self):
|
||||
text_config = """
|
||||
input_stream: 'input_detections'
|
||||
output_stream: 'output_detections'
|
||||
node {
|
||||
calculator: 'DetectionUniqueIdCalculator'
|
||||
input_stream: 'DETECTIONS:input_detections'
|
||||
output_stream: 'DETECTIONS:output_detections'
|
||||
}
|
||||
"""
|
||||
config_proto = text_format.Parse(text_config,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
with solution_base.SolutionBase(graph_config=config_proto) as solution:
|
||||
detection = detection_pb2.Detection()
|
||||
text_format.Parse('score: 0.5', detection)
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'SolutionBase can only process non-audio and non-proto-list data. '
|
||||
+ 'PROTO_LIST type is not supported.'
|
||||
):
|
||||
solution.process({'input_detections': detection})
|
||||
|
||||
def test_invalid_input_image_data(self):
|
||||
text_config = """
|
||||
input_stream: 'image_in'
|
||||
output_stream: 'image_out'
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_in'
|
||||
output_stream: 'IMAGE:transformed_image_in'
|
||||
}
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:transformed_image_in'
|
||||
output_stream: 'IMAGE:image_out'
|
||||
}
|
||||
"""
|
||||
config_proto = text_format.Parse(text_config,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
with solution_base.SolutionBase(graph_config=config_proto) as solution:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'Input image must contain three channel rgb data.'):
|
||||
solution.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4))
|
||||
|
||||
@parameterized.named_parameters(('graph_without_side_packets', """
|
||||
input_stream: 'image_in'
|
||||
output_stream: 'image_out'
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_in'
|
||||
output_stream: 'IMAGE:transformed_image_in'
|
||||
}
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:transformed_image_in'
|
||||
output_stream: 'IMAGE:image_out'
|
||||
}
|
||||
""", None), ('graph_with_side_packets', """
|
||||
input_stream: 'image_in'
|
||||
input_side_packet: 'allow_signal'
|
||||
input_side_packet: 'rotation_degrees'
|
||||
output_stream: 'image_out'
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_in'
|
||||
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
|
||||
output_stream: 'IMAGE:transformed_image_in'
|
||||
}
|
||||
node {
|
||||
calculator: 'GateCalculator'
|
||||
input_stream: 'transformed_image_in'
|
||||
input_side_packet: 'ALLOW:allow_signal'
|
||||
output_stream: 'image_out_to_transform'
|
||||
}
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_out_to_transform'
|
||||
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
|
||||
output_stream: 'IMAGE:image_out'
|
||||
}""", {
|
||||
'allow_signal': True,
|
||||
'rotation_degrees': 0
|
||||
}))
|
||||
def test_solution_process(self, text_config, side_inputs):
|
||||
self._process_and_verify(
|
||||
config_proto=text_format.Parse(text_config,
|
||||
calculator_pb2.CalculatorGraphConfig()),
|
||||
side_inputs=side_inputs)
|
||||
|
||||
def test_invalid_calculator_options(self):
|
||||
text_config = """
|
||||
input_stream: 'image_in'
|
||||
output_stream: 'image_out'
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_in'
|
||||
output_stream: 'IMAGE:transformed_image_in'
|
||||
}
|
||||
node {
|
||||
name: 'SignalGate'
|
||||
calculator: 'GateCalculator'
|
||||
input_stream: 'transformed_image_in'
|
||||
input_side_packet: 'ALLOW:allow_signal'
|
||||
output_stream: 'image_out_to_transform'
|
||||
}
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_out_to_transform'
|
||||
output_stream: 'IMAGE:image_out'
|
||||
}
|
||||
"""
|
||||
config_proto = text_format.Parse(text_config,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Modifying the calculator options of SignalGate is not supported.'):
|
||||
solution_base.SolutionBase(
|
||||
graph_config=config_proto,
|
||||
calculator_params={'SignalGate.invalid_field': 'I am invalid'})
|
||||
|
||||
def test_calculator_has_both_options_and_node_options(self):
|
||||
config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'has both options and node_options fields.'):
|
||||
solution_base.SolutionBase(
|
||||
graph_config=config_proto,
|
||||
calculator_params={
|
||||
'ImageTransformation.output_width': 0,
|
||||
'ImageTransformation.output_height': 0
|
||||
})
|
||||
|
||||
def test_modifying_calculator_proto2_options(self):
|
||||
config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
# To test proto2 options only, remove the proto3 node_options field from the
|
||||
# graph config.
|
||||
self.assertEqual('ImageTransformation', config_proto.node[0].name)
|
||||
config_proto.node[0].ClearField('node_options')
|
||||
self._process_and_verify(
|
||||
config_proto=config_proto,
|
||||
calculator_params={
|
||||
'ImageTransformation.output_width': 0,
|
||||
'ImageTransformation.output_height': 0
|
||||
})
|
||||
|
||||
def test_modifying_calculator_proto3_node_options(self):
|
||||
config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
# To test proto3 node options only, remove the proto2 options field from the
|
||||
# graph config.
|
||||
self.assertEqual('ImageTransformation', config_proto.node[0].name)
|
||||
config_proto.node[0].ClearField('options')
|
||||
self._process_and_verify(
|
||||
config_proto=config_proto,
|
||||
calculator_params={
|
||||
'ImageTransformation.output_width': 0,
|
||||
'ImageTransformation.output_height': 0
|
||||
})
|
||||
|
||||
def test_adding_calculator_options(self):
|
||||
config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
# To test a calculator with no options field, remove both proto2 options and
|
||||
# proto3 node_options fields from the graph config.
|
||||
self.assertEqual('ImageTransformation', config_proto.node[0].name)
|
||||
config_proto.node[0].ClearField('options')
|
||||
config_proto.node[0].ClearField('node_options')
|
||||
self._process_and_verify(
|
||||
config_proto=config_proto,
|
||||
calculator_params={
|
||||
'ImageTransformation.output_width': 0,
|
||||
'ImageTransformation.output_height': 0
|
||||
})
|
||||
|
||||
@parameterized.named_parameters(('graph_without_side_packets', """
|
||||
input_stream: 'image_in'
|
||||
output_stream: 'image_out'
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_in'
|
||||
output_stream: 'IMAGE:transformed_image_in'
|
||||
}
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:transformed_image_in'
|
||||
output_stream: 'IMAGE:image_out'
|
||||
}
|
||||
""", None), ('graph_with_side_packets', """
|
||||
input_stream: 'image_in'
|
||||
input_side_packet: 'allow_signal'
|
||||
input_side_packet: 'rotation_degrees'
|
||||
output_stream: 'image_out'
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_in'
|
||||
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
|
||||
output_stream: 'IMAGE:transformed_image_in'
|
||||
}
|
||||
node {
|
||||
calculator: 'GateCalculator'
|
||||
input_stream: 'transformed_image_in'
|
||||
input_side_packet: 'ALLOW:allow_signal'
|
||||
output_stream: 'image_out_to_transform'
|
||||
}
|
||||
node {
|
||||
calculator: 'ImageTransformationCalculator'
|
||||
input_stream: 'IMAGE:image_out_to_transform'
|
||||
input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
|
||||
output_stream: 'IMAGE:image_out'
|
||||
}""", {
|
||||
'allow_signal': True,
|
||||
'rotation_degrees': 0
|
||||
}))
|
||||
def test_solution_reset(self, text_config, side_inputs):
|
||||
config_proto = text_format.Parse(text_config,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
|
||||
with solution_base.SolutionBase(
|
||||
graph_config=config_proto, side_inputs=side_inputs) as solution:
|
||||
for _ in range(20):
|
||||
outputs = solution.process(input_image)
|
||||
self.assertTrue(np.array_equal(input_image, outputs.image_out))
|
||||
solution.reset()
|
||||
|
||||
def test_solution_stream_type_hints(self):
|
||||
text_config = """
|
||||
input_stream: 'union_type_image_in'
|
||||
output_stream: 'image_type_out'
|
||||
node {
|
||||
calculator: 'ToImageCalculator'
|
||||
input_stream: 'IMAGE:union_type_image_in'
|
||||
output_stream: 'IMAGE:image_type_out'
|
||||
}
|
||||
"""
|
||||
config_proto = text_format.Parse(text_config,
|
||||
calculator_pb2.CalculatorGraphConfig())
|
||||
input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
|
||||
with solution_base.SolutionBase(
|
||||
graph_config=config_proto,
|
||||
stream_type_hints={'union_type_image_in': PacketDataType.IMAGE
|
||||
}) as solution:
|
||||
for _ in range(20):
|
||||
outputs = solution.process(input_image)
|
||||
self.assertTrue(np.array_equal(input_image, outputs.image_type_out))
|
||||
with solution_base.SolutionBase(
|
||||
graph_config=config_proto,
|
||||
stream_type_hints={'union_type_image_in': PacketDataType.IMAGE_FRAME
|
||||
}) as solution2:
|
||||
for _ in range(20):
|
||||
outputs = solution2.process(input_image)
|
||||
self.assertTrue(np.array_equal(input_image, outputs.image_type_out))
|
||||
|
||||
def _process_and_verify(self,
|
||||
config_proto,
|
||||
side_inputs=None,
|
||||
calculator_params=None):
|
||||
input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
|
||||
with solution_base.SolutionBase(
|
||||
graph_config=config_proto,
|
||||
side_inputs=side_inputs,
|
||||
calculator_params=calculator_params) as solution:
|
||||
outputs = solution.process(input_image)
|
||||
outputs2 = solution.process({'image_in': input_image})
|
||||
self.assertTrue(np.array_equal(input_image, outputs.image_out))
|
||||
self.assertTrue(np.array_equal(input_image, outputs2.image_out))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
Reference in New Issue
Block a user