diff --git a/tracetools_test/test/test_node.py b/tracetools_test/test/test_node.py index a3d0c9f..67ef276 100644 --- a/tracetools_test/test/test_node.py +++ b/tracetools_test/test/test_node.py @@ -14,40 +14,24 @@ import unittest -from tracetools_test.utils import ( - cleanup_trace, - get_trace_event_names, - run_and_trace, -) - -BASE_PATH = '/tmp' -PKG = 'tracetools_test' -node_creation_events = [ - 'ros2:rcl_init', - 'ros2:rcl_node_init', -] +from tracetools_test.case import TraceTestCase -class TestNode(unittest.TestCase): +class TestNode(TraceTestCase): + + def __init__(self, *args) -> None: + super().__init__( + *args, + session_name_prefix='session-test-node-creation', + events_ros=[ + 'ros2:rcl_init', + 'ros2:rcl_node_init', + ], + nodes=['test_publisher'] + ) def test_creation(self): - session_name_prefix = 'session-test-node-creation' - test_node = ['test_publisher'] - - exit_code, full_path = run_and_trace( - BASE_PATH, - session_name_prefix, - node_creation_events, - None, - PKG, - test_node) - self.assertEqual(exit_code, 0) - - trace_events = get_trace_event_names(full_path) - print(f'trace_events: {trace_events}') - self.assertSetEqual(set(node_creation_events), trace_events) - - cleanup_trace(full_path) + pass if __name__ == '__main__': diff --git a/tracetools_test/test/test_publisher.py b/tracetools_test/test/test_publisher.py index decbbd2..739adf3 100644 --- a/tracetools_test/test/test_publisher.py +++ b/tracetools_test/test/test_publisher.py @@ -14,39 +14,23 @@ import unittest -from tracetools_test.utils import ( - cleanup_trace, - get_trace_event_names, - run_and_trace, -) - -BASE_PATH = '/tmp' -PKG = 'tracetools_test' -publisher_creation_events = [ - 'ros2:rcl_publisher_init', -] +from tracetools_test.case import TraceTestCase -class TestPublisher(unittest.TestCase): +class TestPublisher(TraceTestCase): + + def __init__(self, *args) -> None: + super().__init__( + *args, + session_name_prefix='session-test-publisher-creation', + events_ros=[ + 'ros2:rcl_publisher_init', + ], + nodes=['test_publisher'] + ) def test_creation(self): - session_name_prefix = 'session-test-publisher-creation' - test_node = ['test_publisher'] - - exit_code, full_path = run_and_trace( - BASE_PATH, - session_name_prefix, - publisher_creation_events, - None, - PKG, - test_node) - self.assertEqual(exit_code, 0) - - trace_events = get_trace_event_names(full_path) - print(f'trace_events: {trace_events}') - self.assertSetEqual(set(publisher_creation_events), trace_events) - - cleanup_trace(full_path) + pass if __name__ == '__main__': diff --git a/tracetools_test/test/test_service.py b/tracetools_test/test/test_service.py index 7bcf697..d3667f1 100644 --- a/tracetools_test/test/test_service.py +++ b/tracetools_test/test/test_service.py @@ -14,40 +14,24 @@ import unittest -from tracetools_test.utils import ( - cleanup_trace, - get_trace_event_names, - run_and_trace, -) - -BASE_PATH = '/tmp' -PKG = 'tracetools_test' -service_creation_events = [ - 'ros2:rcl_service_init', - 'ros2:rclcpp_service_callback_added', -] +from tracetools_test.case import TraceTestCase -class TestService(unittest.TestCase): +class TestService(TraceTestCase): + + def __init__(self, *args) -> None: + super().__init__( + *args, + session_name_prefix='session-test-service-creation', + events_ros=[ + 'ros2:rcl_service_init', + 'ros2:rclcpp_service_callback_added', + ], + nodes=['test_service'] + ) def test_creation(self): - session_name_prefix = 'session-test-service-creation' - test_nodes = ['test_service'] - - exit_code, full_path = run_and_trace( - BASE_PATH, - session_name_prefix, - service_creation_events, - None, - PKG, - test_nodes) - self.assertEqual(exit_code, 0) - - trace_events = get_trace_event_names(full_path) - print(f'trace_events: {trace_events}') - self.assertSetEqual(set(service_creation_events), trace_events) - - cleanup_trace(full_path) + pass if __name__ == '__main__': diff --git a/tracetools_test/test/test_service_callback.py b/tracetools_test/test/test_service_callback.py index 8b6a1e0..2fa22bd 100644 --- a/tracetools_test/test/test_service_callback.py +++ b/tracetools_test/test/test_service_callback.py @@ -14,40 +14,24 @@ import unittest -from tracetools_test.utils import ( - cleanup_trace, - get_trace_event_names, - run_and_trace, -) - -BASE_PATH = '/tmp' -PKG = 'tracetools_test' -service_callback_events = [ - 'ros2:callback_start', - 'ros2:callback_end', -] +from tracetools_test.case import TraceTestCase -class TestServiceCallback(unittest.TestCase): +class TestServiceCallback(TraceTestCase): + + def __init__(self, *args) -> None: + super().__init__( + *args, + session_name_prefix='session-test-service-callback', + events_ros=[ + 'ros2:callback_start', + 'ros2:callback_end', + ], + nodes=['test_service_ping', 'test_service_pong'] + ) def test_callback(self): - session_name_prefix = 'session-test-service-callback' - test_nodes = ['test_service_ping', 'test_service_pong'] - - exit_code, full_path = run_and_trace( - BASE_PATH, - session_name_prefix, - service_callback_events, - None, - PKG, - test_nodes) - self.assertEqual(exit_code, 0) - - trace_events = get_trace_event_names(full_path) - print(f'trace_events: {trace_events}') - self.assertSetEqual(set(service_callback_events), trace_events) - - cleanup_trace(full_path) + pass if __name__ == '__main__': diff --git a/tracetools_test/test/test_subscription.py b/tracetools_test/test/test_subscription.py index 1562934..fc5210f 100644 --- a/tracetools_test/test/test_subscription.py +++ b/tracetools_test/test/test_subscription.py @@ -14,40 +14,24 @@ import unittest -from tracetools_test.utils import ( - cleanup_trace, - get_trace_event_names, - run_and_trace, -) - -BASE_PATH = '/tmp' -PKG = 'tracetools_test' -subscription_creation_events = [ - 'ros2:rcl_subscription_init', - 'ros2:rclcpp_subscription_callback_added', -] +from tracetools_test.case import TraceTestCase -class TestSubscription(unittest.TestCase): +class TestSubscription(TraceTestCase): + + def __init__(self, *args) -> None: + super().__init__( + *args, + session_name_prefix='session-test-subscription-creation', + events_ros=[ + 'ros2:rcl_subscription_init', + 'ros2:rclcpp_subscription_callback_added', + ], + nodes=['test_subscription'] + ) def test_creation(self): - session_name_prefix = 'session-test-subscription-creation' - test_node = ['test_subscription'] - - exit_code, full_path = run_and_trace( - BASE_PATH, - session_name_prefix, - subscription_creation_events, - None, - PKG, - test_node) - self.assertEqual(exit_code, 0) - - trace_events = get_trace_event_names(full_path) - print(f'trace_events: {trace_events}') - self.assertSetEqual(set(subscription_creation_events), trace_events) - - cleanup_trace(full_path) + pass if __name__ == '__main__': diff --git a/tracetools_test/test/test_subscription_callback.py b/tracetools_test/test/test_subscription_callback.py index 251eb4f..1fa37aa 100644 --- a/tracetools_test/test/test_subscription_callback.py +++ b/tracetools_test/test/test_subscription_callback.py @@ -14,40 +14,24 @@ import unittest -from tracetools_test.utils import ( - cleanup_trace, - get_trace_event_names, - run_and_trace, -) - -BASE_PATH = '/tmp' -PKG = 'tracetools_test' -subscription_callback_events = [ - 'ros2:callback_start', - 'ros2:callback_end', -] +from tracetools_test.case import TraceTestCase -class TestSubscriptionCallback(unittest.TestCase): +class TestSubscriptionCallback(TraceTestCase): + + def __init__(self, *args) -> None: + super().__init__( + *args, + session_name_prefix='session-test-subscription-callback', + events_ros=[ + 'ros2:callback_start', + 'ros2:callback_end', + ], + nodes=['test_ping', 'test_pong'] + ) def test_callback(self): - session_name_prefix = 'session-test-subscription-callback' - test_nodes = ['test_ping', 'test_pong'] - - exit_code, full_path = run_and_trace( - BASE_PATH, - session_name_prefix, - subscription_callback_events, - None, - PKG, - test_nodes) - self.assertEqual(exit_code, 0) - - trace_events = get_trace_event_names(full_path) - print(f'trace_events: {trace_events}') - self.assertSetEqual(set(subscription_callback_events), trace_events) - - cleanup_trace(full_path) + pass if __name__ == '__main__': diff --git a/tracetools_test/test/test_timer.py b/tracetools_test/test/test_timer.py index 9b1b3e8..ea83a07 100644 --- a/tracetools_test/test/test_timer.py +++ b/tracetools_test/test/test_timer.py @@ -14,42 +14,26 @@ import unittest -from tracetools_test.utils import ( - cleanup_trace, - get_trace_event_names, - run_and_trace, -) - -BASE_PATH = '/tmp' -PKG = 'tracetools_test' -timer_events = [ - 'ros2:rcl_timer_init', - 'ros2:rclcpp_timer_callback_added', - 'ros2:callback_start', - 'ros2:callback_end', -] +from tracetools_test.case import TraceTestCase -class TestTimer(unittest.TestCase): +class TestTimer(TraceTestCase): + + def __init__(self, *args) -> None: + super().__init__( + *args, + session_name_prefix='session-test-timer-all', + events_ros=[ + 'ros2:rcl_timer_init', + 'ros2:rclcpp_timer_callback_added', + 'ros2:callback_start', + 'ros2:callback_end', + ], + nodes=['test_timer'] + ) def test_all(self): - session_name_prefix = 'session-test-timer-all' - test_nodes = ['test_timer'] - - exit_code, full_path = run_and_trace( - BASE_PATH, - session_name_prefix, - timer_events, - None, - PKG, - test_nodes) - self.assertEqual(exit_code, 0) - - trace_events = get_trace_event_names(full_path) - print(f'trace_events: {trace_events}') - self.assertSetEqual(set(timer_events), trace_events) - - cleanup_trace(full_path) + pass if __name__ == '__main__': diff --git a/tracetools_test/tracetools_test/case.py b/tracetools_test/tracetools_test/case.py new file mode 100644 index 0000000..bff6d72 --- /dev/null +++ b/tracetools_test/tracetools_test/case.py @@ -0,0 +1,89 @@ +# Copyright 2019 Robert Bosch GmbH +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for a tracing-specific unittest.TestCase extension.""" + +from typing import List +from typing import Set +import unittest + +from .utils import cleanup_trace +from .utils import get_trace_events +from .utils import run_and_trace + + +class TraceTestCase(unittest.TestCase): + """ + TestCase extension for tests on a trace. + + Sets up tracing, traces given nodes, and provides + the resulting events for an extending class to test on. + It also does some basic checks on the resulting trace. + """ + + def __init__( + self, + *args, + session_name_prefix: str, + events_ros: List[str], + nodes: List[str], + base_path: str = '/tmp', + events_kernel: List[str] = None, + package: str = 'tracetools_test' + ) -> None: + """Constructor.""" + print(f'methodName={args[0]}') + super().__init__(methodName=args[0]) + self._base_path = base_path + self._session_name_prefix = session_name_prefix + self._events_ros = events_ros + self._events_kernel = events_kernel + self._package = package + self._nodes = nodes + + def setUp(self): + exit_code, full_path = run_and_trace( + self._base_path, + self._session_name_prefix, + self._events_ros, + self._events_kernel, + self._package, + self._nodes) + + print(f'trace directory: {full_path}') + self._exit_code = exit_code + self._full_path = full_path + + # Check that setUp() ran fine + self.assertEqual(self._exit_code, 0) + + # Read events once + self._events = get_trace_events(self._full_path) + self._event_names = self._get_event_names() + + # Check that the enabled events are present + ros = set(self._events_ros) if self._events_ros is not None else set() + kernel = set(self._events_kernel) if self._events_kernel is not None else set() + all_event_names = ros | kernel + self.assertSetEqual(all_event_names, self._event_names) + + def tearDown(self): + cleanup_trace(self._full_path) + + def _get_event_names(self) -> Set[str]: + """Get a set of names of the events in the trace.""" + events_names = set() + for event in self._events: + events_names.add(event.name) + return events_names diff --git a/tracetools_test/tracetools_test/utils.py b/tracetools_test/tracetools_test/utils.py index 94f8e31..f0a5303 100644 --- a/tracetools_test/tracetools_test/utils.py +++ b/tracetools_test/tracetools_test/utils.py @@ -17,8 +17,8 @@ import os import shutil import time +from typing import Dict from typing import List -from typing import Set from typing import Tuple import babeltrace @@ -41,7 +41,7 @@ def run_and_trace( kernel_events: List[str], package_name: str, node_names: List[str] - ) -> Tuple[int, str]: +) -> Tuple[int, str]: """ Run a node while tracing. @@ -55,7 +55,6 @@ def run_and_trace( """ session_name = f'{session_name_prefix}-{time.strftime("%Y%m%d%H%M%S")}' full_path = os.path.join(base_path, session_name) - print(f'trace directory: {full_path}') lttng_setup(session_name, full_path, ros_events=ros_events, kernel_events=kernel_events) lttng_start(session_name) @@ -89,19 +88,14 @@ def cleanup_trace(full_path: str) -> None: shutil.rmtree(full_path) -def get_trace_event_names(trace_directory: str) -> Set[str]: +def get_trace_events(trace_directory: str) -> List[Dict[str, str]]: """ - Get a set of event names in a trace. + Get the events of a trace. :param trace_directory: the path to the main/top trace directory - :return: event names + :return: events """ tc = babeltrace.TraceCollection() tc.add_traces_recursive(trace_directory, 'ctf') - event_names = set() - - for event in tc.events: - event_names.add(event.name) - - return event_names + return tc.events