diff --git a/tracetools_test/src/test_service.cpp b/tracetools_test/src/test_service.cpp index f050319..aac018e 100644 --- a/tracetools_test/src/test_service.cpp +++ b/tracetools_test/src/test_service.cpp @@ -24,7 +24,7 @@ public: : Node("test_service", options) { srv_ = this->create_service( - "service", + "the_service", std::bind( &ServiceNode::service_callback, this, diff --git a/tracetools_test/test/test_node.py b/tracetools_test/test/test_node.py index 146abed..f2c4b59 100644 --- a/tracetools_test/test/test_node.py +++ b/tracetools_test/test/test_node.py @@ -34,7 +34,7 @@ class TestNode(TraceTestCase): ) def test_all(self): - # Check events order as set + # Check events order as set (e.g. init before node_init) self.assertEventsOrderSet(self._events_ros) # Check fields @@ -47,8 +47,7 @@ class TestNode(TraceTestCase): rcl_node_init_events = self.get_events_with_name('ros2:rcl_node_init') for event in rcl_node_init_events: - self.assertValidHandle(event, 'node_handle') - self.assertValidHandle(event, 'rmw_handle') + self.assertValidHandle(event, ['node_handle', 'rmw_handle']) self.assertStringFieldNotEmpty(event, 'node_name') self.assertStringFieldNotEmpty(event, 'namespace') diff --git a/tracetools_test/test/test_publisher.py b/tracetools_test/test/test_publisher.py index 739adf3..1abe87b 100644 --- a/tracetools_test/test/test_publisher.py +++ b/tracetools_test/test/test_publisher.py @@ -24,13 +24,42 @@ class TestPublisher(TraceTestCase): *args, session_name_prefix='session-test-publisher-creation', events_ros=[ + 'ros2:rcl_node_init', 'ros2:rcl_publisher_init', ], nodes=['test_publisher'] ) - def test_creation(self): - pass + def test_all(self): + # Check events order as set (e.g. node_init before pub_init) + self.assertEventsOrderSet(self._events_ros) + + # Check fields + pub_init_events = self.get_events_with_name('ros2:rcl_publisher_init') + for event in pub_init_events: + self.assertValidHandle( + event, + ['publisher_handle', 'node_handle', 'rmw_publisher_handle']) + self.assertValidQueueDepth(event, 'queue_depth') + self.assertStringFieldNotEmpty(event, 'topic_name') + + # Check that the test topic name exists + test_pub_init_events = self.get_events_with_procname('test_publisher', pub_init_events) + event_topic_names = [self.get_field(e, 'topic_name') for e in test_pub_init_events] + self.assertTrue('/the_topic' in event_topic_names, 'cannot find test topic name') + + # Check that the node handle matches with the node_init event + node_init_events = self.get_events_with_name('ros2:rcl_node_init') + test_pub_node_init_events = self.get_events_with_procname( + 'test_publisher', + node_init_events) + self.assertEqual(len(test_pub_node_init_events), 1, 'none or more than 1 node_init event') + test_pub_node_init_event = test_pub_node_init_events[0] + self.assertMatchingField( + test_pub_node_init_event, + 'node_handle', + None, + test_pub_init_events) if __name__ == '__main__': diff --git a/tracetools_test/test/test_service.py b/tracetools_test/test/test_service.py index d3667f1..7767ea9 100644 --- a/tracetools_test/test/test_service.py +++ b/tracetools_test/test/test_service.py @@ -24,14 +24,58 @@ class TestService(TraceTestCase): *args, session_name_prefix='session-test-service-creation', events_ros=[ + 'ros2:rcl_node_init', 'ros2:rcl_service_init', 'ros2:rclcpp_service_callback_added', ], nodes=['test_service'] ) - def test_creation(self): - pass + def test_all(self): + # Check events order as set (e.g. service_init before callback_added) + self.assertEventsOrderSet(self._events_ros) + + # Check fields + srv_init_events = self.get_events_with_name('ros2:rcl_service_init') + for event in srv_init_events: + self.assertValidHandle(event, ['service_handle', 'node_handle', 'rmw_service_handle']) + self.assertStringFieldNotEmpty(event, 'service_name') + + callback_added_events = self.get_events_with_name('ros2:rclcpp_service_callback_added') + for event in callback_added_events: + self.assertValidHandle(event, ['service_handle', 'callback']) + + # Check that the test service name exists + test_srv_init_events = self.get_events_with_procname('test_service', srv_init_events) + event_service_names = self.get_events_with_field_value( + 'service_name', + '/the_service', + test_srv_init_events) + self.assertGreaterEqual( + len(event_service_names), + 1, + 'cannot find test service name') + + # Check that the node handle matches the node_init event + node_init_events = self.get_events_with_name('ros2:rcl_node_init') + test_srv_node_init_events = self.get_events_with_procname( + 'test_service', + node_init_events) + self.assertEqual(len(test_srv_node_init_events), 1, 'none or more than 1 node_init event') + test_srv_node_init_event = test_srv_node_init_events[0] + self.assertMatchingField( + test_srv_node_init_event, + 'node_handle', + 'ros2:rcl_service_init', + test_srv_init_events) + + # Check that the service handles match + test_event_srv_init = event_service_names[0] + self.assertMatchingField( + test_event_srv_init, + 'service_handle', + None, + callback_added_events) if __name__ == '__main__': diff --git a/tracetools_test/test/test_service_callback.py b/tracetools_test/test/test_service_callback.py index 2fa22bd..f8de93c 100644 --- a/tracetools_test/test/test_service_callback.py +++ b/tracetools_test/test/test_service_callback.py @@ -30,8 +30,32 @@ class TestServiceCallback(TraceTestCase): nodes=['test_service_ping', 'test_service_pong'] ) - def test_callback(self): - pass + def test_all(self): + # Check events order as set (e.g. start before end) + self.assertEventsOrderSet(self._events_ros) + + # Check fields + start_events = self.get_events_with_name('ros2:callback_start') + for event in start_events: + self.assertValidHandle(event, 'callback') + is_intra_process_value = self.get_field(event, 'is_intra_process') + self.assertIsInstance(is_intra_process_value, int, 'is_intra_process not int') + # Should not be 1 for services (yet) + self.assertEqual( + is_intra_process_value, + 0, + f'invalid value for is_intra_process: {is_intra_process_value}') + + end_events = self.get_events_with_name('ros2:callback_end') + for event in end_events: + self.assertValidHandle(event, 'callback') + + # Check that there is at least a start/end pair for each node + for node in self._nodes: + test_start_events = self.get_events_with_procname(node, start_events) + test_end_events = self.get_events_with_procname(node, end_events) + self.assertGreater(len(test_start_events), 0, f'no start_callback events for node: {node}') + self.assertGreater(len(test_end_events), 0, f'no end_callback events for node: {node}') if __name__ == '__main__': diff --git a/tracetools_test/test/test_subscription.py b/tracetools_test/test/test_subscription.py index fc5210f..1d85a99 100644 --- a/tracetools_test/test/test_subscription.py +++ b/tracetools_test/test/test_subscription.py @@ -24,14 +24,60 @@ class TestSubscription(TraceTestCase): *args, session_name_prefix='session-test-subscription-creation', events_ros=[ + 'ros2:rcl_node_init', 'ros2:rcl_subscription_init', 'ros2:rclcpp_subscription_callback_added', ], nodes=['test_subscription'] ) - def test_creation(self): - pass + def test_all(self): + # Check events order as set (e.g. sub_init before callback_added) + self.assertEventsOrderSet(self._events_ros) + + # Check fields + sub_init_events = self.get_events_with_name('ros2:rcl_subscription_init') + for event in sub_init_events: + self.assertValidHandle( + event, + ['subscription_handle', 'node_handle', 'rmw_subscription_handle']) + self.assertValidQueueDepth(event, 'queue_depth') + self.assertStringFieldNotEmpty(event, 'topic_name') + + callback_added_events = self.get_events_with_name('ros2:rclcpp_subscription_callback_added') + for event in callback_added_events: + self.assertValidHandle(event, ['subscription_handle', 'callback']) + + # Check that the test topic name exists + test_sub_init_events = self.get_events_with_field_value( + 'topic_name', + '/the_topic', + sub_init_events) + self.assertEqual(len(test_sub_init_events), 1, 'cannot find test topic name') + test_sub_init_event = test_sub_init_events[0] + + # Check that the node handle matches the node_init event + node_init_events = self.get_events_with_name('ros2:rcl_node_init') + test_sub_node_init_events = self.get_events_with_procname( + 'test_subscription', + node_init_events) + self.assertEqual( + len(test_sub_node_init_events), + 1, + 'none or more than 1 node_init event') + test_sub_node_init_event = test_sub_node_init_events[0] + self.assertMatchingField( + test_sub_node_init_event, + 'node_handle', + 'ros2:rcl_subscription_init', + sub_init_events) + + # Check that subscription handle matches with callback_added event + self.assertMatchingField( + test_sub_init_event, + 'subscription_handle', + None, + callback_added_events) if __name__ == '__main__': diff --git a/tracetools_test/test/test_subscription_callback.py b/tracetools_test/test/test_subscription_callback.py index 1fa37aa..1f057da 100644 --- a/tracetools_test/test/test_subscription_callback.py +++ b/tracetools_test/test/test_subscription_callback.py @@ -30,8 +30,42 @@ class TestSubscriptionCallback(TraceTestCase): nodes=['test_ping', 'test_pong'] ) - def test_callback(self): - pass + def test_all(self): + # Check events order as set (e.g. start before end) + self.assertEventsOrderSet(self._events_ros) + + # Check fields + start_events = self.get_events_with_name('ros2:callback_start') + for event in start_events: + self.assertValidHandle(event, 'callback') + is_intra_process_value = self.get_field(event, 'is_intra_process') + self.assertIsInstance(is_intra_process_value, int, 'is_intra_process not int') + self.assertTrue( + is_intra_process_value in [0, 1], + f'invalid value for is_intra_process: {is_intra_process_value}') + + end_events = self.get_events_with_name('ros2:callback_end') + for event in end_events: + self.assertValidHandle(event, 'callback') + + # Check that a start:end pair has a common callback handle + # Note: might be unstable if tracing is disabled too early + ping_events = self.get_events_with_procname('test_ping') + pong_events = self.get_events_with_procname('test_pong') + ping_events_start = self.get_events_with_name('ros2:callback_start', ping_events) + pong_events_start = self.get_events_with_name('ros2:callback_start', pong_events) + for ping_start in ping_events_start: + self.assertMatchingField( + ping_start, + 'callback', + 'ros2:callback_end', + ping_events) + for pong_start in pong_events_start: + self.assertMatchingField( + pong_start, + 'callback', + 'ros2:callback_end', + pong_events) if __name__ == '__main__': diff --git a/tracetools_test/test/test_timer.py b/tracetools_test/test/test_timer.py index ea83a07..a86f4d3 100644 --- a/tracetools_test/test/test_timer.py +++ b/tracetools_test/test/test_timer.py @@ -33,7 +33,58 @@ class TestTimer(TraceTestCase): ) def test_all(self): - pass + # Check events order as set (e.g. init, callback added, start, end) + self.assertEventsOrderSet(self._events_ros) + + # Check fields + init_events = self.get_events_with_name('ros2:rcl_timer_init') + for event in init_events: + self.assertValidHandle(event, 'timer_handle') + period_value = self.get_field(event, 'period') + self.assertIsInstance(period_value, int) + self.assertGreaterEqual(period_value, 0, f'invalid period value: {period_value}') + + callback_added_events = self.get_events_with_name('ros2:rclcpp_timer_callback_added') + for event in callback_added_events: + self.assertValidHandle(event, ['timer_handle', 'callback']) + + start_events = self.get_events_with_name('ros2:callback_start') + for event in start_events: + self.assertValidHandle(event, 'callback') + is_intra_process_value = self.get_field(event, 'is_intra_process') + self.assertIsInstance(is_intra_process_value, int, 'is_intra_process not int') + # Should not be 1 for timer + self.assertEqual( + is_intra_process_value, + 0, + f'invalid value for is_intra_process: {is_intra_process_value}') + + end_events = self.get_events_with_name('ros2:callback_end') + for event in end_events: + self.assertValidHandle(event, 'callback') + + # Find and check given timer period + test_timer_init_event = self.get_events_with_procname('test_timer', init_events) + self.assertEqual(len(test_timer_init_event), 1) + test_init_event = test_timer_init_event[0] + test_period = self.get_field(test_init_event, 'period') + self.assertIsInstance(test_period, int) + self.assertEqual(test_period, 1000000, f'invalid period: {test_period}') + + # Check that the timer_init:callback_added pair exists and has a common timer handle + self.assertMatchingField( + test_init_event, + 'timer_handle', + 'ros2:rclcpp_timer_callback_added', + callback_added_events) + + # Check that a callback start:end pair has a common callback handle + for start_event in start_events: + self.assertMatchingField( + start_event, + 'callback', + None, + end_events) if __name__ == '__main__': diff --git a/tracetools_test/tracetools_test/case.py b/tracetools_test/tracetools_test/case.py index 6ff92e5..947d85a 100644 --- a/tracetools_test/tracetools_test/case.py +++ b/tracetools_test/tracetools_test/case.py @@ -18,6 +18,7 @@ import time from typing import Any from typing import Dict from typing import List +from typing import Union import unittest # from .utils import cleanup_trace @@ -26,6 +27,7 @@ from .utils import get_event_name from .utils import get_event_names from .utils import get_event_timestamp from .utils import get_field +from .utils import get_procname from .utils import get_trace_events from .utils import run_and_trace @@ -113,21 +115,30 @@ class TraceTestCase(unittest.TestCase): :param names: the node names to look for """ - procnames = [e['procname'] for e in self._events] + procnames = [get_procname(e) for e in self._events] for name in names: # Procnames have a max length of 15 name_trimmed = name[:15] self.assertTrue(name_trimmed in procnames, 'node name not found in tracepoints') - def assertValidHandle(self, event: DictEvent, handle_field_name: str): + def assertValidHandle(self, event: DictEvent, handle_field_name: Union[str, List[str]]): """ - Check that the handle associated to a field name is valid. + Check that the handle associated with a field name is valid. :param event: the event which has a handle field - :param handle_field_name: the field name of the handle to check + :param handle_field_name: the field name(s) of the handle to check """ - handle_field = self.get_field(event, handle_field_name) - self.assertGreater(handle_field, 0, f'invalid handle: {handle_field_name}') + is_list = isinstance(handle_field_name, list) + handle_field_names = handle_field_name if is_list else [handle_field_name] + for field_name in handle_field_names: + handle_value = self.get_field(event, field_name) + self.assertIsInstance(handle_value, int, 'handle value not int') + self.assertGreater(handle_value, 0, f'invalid handle value: {field_name}') + + def assertValidQueueDepth(self, event: DictEvent, queue_depth_field_name: str = 'queue_depth'): + queue_depth_value = self.get_field(event, 'queue_depth') + self.assertIsInstance(queue_depth_value, int, 'invalid queue depth type') + self.assertGreater(queue_depth_value, 0, 'invalid queue depth') def assertStringFieldNotEmpty(self, event: DictEvent, string_field_name: str): """ @@ -142,6 +153,67 @@ class TraceTestCase(unittest.TestCase): def assertEventAfterTimestamp(self, event: DictEvent, timestamp: int): self.assertGreater(get_event_timestamp(event), timestamp, 'event not after timestamp') + def assertEventOrder(self, first_event: DictEvent, second_event: DictEvent): + """ + Check that the first event was generated before the second event. + + :param first_event: the first event + :param second_event: the second event + """ + self.assertTrue(self.are_events_ordered(first_event, second_event)) + + def assertMatchingField( + self, + initial_event: DictEvent, + field_name: str, + matching_event_name: str = None, + events: List[DictEvent] = None + ): + """ + Check that the value of a field for a given event has a matching event that follows. + + :param initial_event: the first event, which is the origin of the common field value + :param field_name: the name of the common field to check + :param matching_event_name: the name of the event to check (or None to check all) + :param events: the events to check (or None to check all in trace) + """ + if events is None: + events = self._events + if matching_event_name is not None: + events = self.get_events_with_name(matching_event_name, events) + field_value = self.get_field(initial_event, field_name) + + # Get events with that handle + matches = self.get_events_with_field_value( + field_name, + field_value, + events) + # Check that there is at least one + self.assertGreaterEqual( + len(matches), + 1, + f'no corresponding {field_name}') + # Check order + # Since matching pairs might repeat, we need to check + # that there is at least one match that comes after + matches_ordered = [e for e in matches if self.are_events_ordered(initial_event, e)] + self.assertGreaterEqual( + len(matches_ordered), + 1, + 'matching field event not after initial event') + + + def assertFieldEquals(self, event: DictEvent, field_name: str, value: Any): + """ + Check the value of a field. + + :param event: the event + :param field_name: the name of the field to check + :param value: to value to compare the field value to + """ + actual_value = self.get_field(event, field_name) + self.assertEqual(actual_value, value, 'invalid field value') + def get_field(self, event: DictEvent, field_name: str) -> Any: """ Get field value; will fail test if not found. @@ -157,12 +229,73 @@ class TraceTestCase(unittest.TestCase): self.fail(str(e)) else: return value + + def get_procname(self, event: DictEvent) -> str: + """ + Get procname. - def get_events_with_name(self, event_name: str) -> List[DictEvent]: + :param event: the event + :return: the procname of the event + """ + return get_procname(event) + + def get_events_with_name( + self, + event_name: str, + events: List[DictEvent] = None + ) -> List[DictEvent]: """ Get all events with the given name. :param event_name: the event name + :param events: the events to check (or None to check all events) :return: the list of events with the given name """ - return [e for e in self._events if get_event_name(e) == event_name] + if events is None: + events = self._events + return [e for e in events if get_event_name(e) == event_name] + + def get_events_with_procname( + self, + procname: str, + events: List[DictEvent] = None + ) -> List[DictEvent]: + """ + Get all events with the given procname. + + :param procname: the procname + :param events: the events to check (or None to check all events) + :return: the events with the given procname + """ + if events is None: + events = self._events + return [e for e in events if self.get_procname(e) == procname[:15]] + + def get_events_with_field_value( + self, + field_name: str, + field_value: Any, + events: List[DictEvent] = None + ) -> List[DictEvent]: + """ + Get all events with the given field:value + + :param field_name: the name of the field to check + :param field_value: the value of the field to check + :param events: the events to check (or None to check all events) + :return: the events with the given field:value pair + """ + if events is None: + events = self._events + return [e for e in events if get_field(e, field_name, None) == field_value] + + def are_events_ordered(self, first_event: DictEvent, second_event: DictEvent): + """ + Check that the first event was generated before the second event. + + :param first_event: the first event + :param second_event: the second event + """ + first_timestamp = get_event_timestamp(first_event) + second_timestamp = get_event_timestamp(second_event) + return first_timestamp < second_timestamp diff --git a/tracetools_test/tracetools_test/utils.py b/tracetools_test/tracetools_test/utils.py index 0a83b88..776ea9b 100644 --- a/tracetools_test/tracetools_test/utils.py +++ b/tracetools_test/tracetools_test/utils.py @@ -143,8 +143,14 @@ def get_field(event: DictEvent, field_name: str, default=None, raise_if_not_foun raise AttributeError(f'event field "{field_name}" not found!') return field_value + def get_event_name(event: DictEvent) -> str: return event['_name'] + def get_event_timestamp(event: DictEvent) -> int: return event['_timestamp'] + + +def get_procname(event: DictEvent) -> str: + return event['procname']