diff --git a/tracetools_test/test/test_intra.py b/tracetools_test/test/test_intra.py index 3528bcd..f53f919 100644 --- a/tracetools_test/test/test_intra.py +++ b/tracetools_test/test/test_intra.py @@ -45,12 +45,12 @@ class TestIntra(TraceTestCase): 'topic_name', '/the_topic/_intra', sub_init_events) - self.assertEqual( - len(sub_init_normal_events), + self.assertNumEvents( + sub_init_normal_events, 1, 'none or more than 1 sub init event for normal sub') - self.assertEqual( - len(sub_init_intra_events), + self.assertNumEvents( + sub_init_intra_events, 1, 'none or more than 1 sub init event for intra sub') @@ -67,8 +67,8 @@ class TestIntra(TraceTestCase): sub_handle_intra, self.get_events_with_name( 'ros2:rclcpp_subscription_callback_added')) - self.assertEqual( - len(callback_added_events), + self.assertNumEvents( + callback_added_events, 1, 'none or more than 1 callback added event') callback_added_event = callback_added_events[0] @@ -78,8 +78,8 @@ class TestIntra(TraceTestCase): start_events = self.get_events_with_name('ros2:callback_start') end_events = self.get_events_with_name('ros2:callback_end') # Should still have two start:end pairs (1 normal + 1 intra) - self.assertEqual(len(start_events), 2, 'does not have 2 callback start events') - self.assertEqual(len(end_events), 2, 'does not have 2 callback end events') + self.assertNumEvents(start_events, 2, 'does not have 2 callback start events') + self.assertNumEvents(end_events, 2, 'does not have 2 callback end events') start_events_intra = self.get_events_with_field_value( 'callback', callback_handle_intra, @@ -88,12 +88,12 @@ class TestIntra(TraceTestCase): 'callback', callback_handle_intra, end_events) - self.assertEqual( - len(start_events_intra), + self.assertNumEvents( + start_events_intra, 1, 'none or more than one intra start event') - self.assertEqual( - len(end_events_intra), + self.assertNumEvents( + end_events_intra, 1, 'none or more than one intra end event') @@ -110,8 +110,8 @@ class TestIntra(TraceTestCase): 'callback', callback_handle_intra, start_events) - self.assertEqual( - len(start_events_not_intra), + self.assertNumEvents( + start_events_not_intra, 1, 'none or more than one normal start event') start_event_not_intra = start_events_not_intra[0] diff --git a/tracetools_test/test/test_publisher.py b/tracetools_test/test/test_publisher.py index 163568f..eb0369d 100644 --- a/tracetools_test/test/test_publisher.py +++ b/tracetools_test/test/test_publisher.py @@ -48,16 +48,16 @@ class TestPublisher(TraceTestCase): 'topic_name', '/the_topic', test_pub_init_events) - self.assertEqual( - len(test_pub_init_topic_events), + self.assertNumEvents( + test_pub_init_topic_events, 1, 'none or more than 1 pub_init even for test topic') # Check queue_depth value test_pub_init_topic_event = test_pub_init_topic_events[0] - test_queue_depth = self.get_field(test_pub_init_topic_event, 'queue_depth') - self.assertEqual( - test_queue_depth, + self.assertFieldEquals( + test_pub_init_topic_event, + 'queue_depth', 10, 'pub_init event does not have expected queue depth value') @@ -66,7 +66,7 @@ class TestPublisher(TraceTestCase): test_pub_node_init_events = self.get_events_with_procname( 'test_publisher', node_init_events) - self.assertEqual(len(test_pub_node_init_events), 1, 'none or more than 1 node_init event') + self.assertNumEvents(test_pub_node_init_events, 1, 'none or more than 1 node_init event') test_pub_node_init_event = test_pub_node_init_events[0] self.assertMatchingField( test_pub_node_init_event, diff --git a/tracetools_test/test/test_service.py b/tracetools_test/test/test_service.py index 8a6754a..c170025 100644 --- a/tracetools_test/test/test_service.py +++ b/tracetools_test/test/test_service.py @@ -60,7 +60,7 @@ class TestService(TraceTestCase): test_srv_node_init_events = self.get_events_with_procname( 'test_service', node_init_events) - self.assertEqual(len(test_srv_node_init_events), 1, 'none or more than 1 node_init event') + self.assertNumEvents(test_srv_node_init_events, 1, 'none or more than 1 node_init event') test_srv_node_init_event = test_srv_node_init_events[0] self.assertMatchingField( test_srv_node_init_event, diff --git a/tracetools_test/test/test_service_callback.py b/tracetools_test/test/test_service_callback.py index 5bc6571..1c4507d 100644 --- a/tracetools_test/test/test_service_callback.py +++ b/tracetools_test/test/test_service_callback.py @@ -37,13 +37,12 @@ class TestServiceCallback(TraceTestCase): start_events = self.get_events_with_name('ros2:callback_start') for event in start_events: self.assertValidHandle(event, 'callback') - is_intra_process_value = self.get_field(event, 'is_intra_process') - self.assertIsInstance(is_intra_process_value, int, 'is_intra_process not int') # Should not be 1 for services (yet) - self.assertEqual( - is_intra_process_value, + self.assertFieldEquals( + event, + 'is_intra_process', 0, - f'invalid value for is_intra_process: {is_intra_process_value}') + 'invalid value for is_intra_process') end_events = self.get_events_with_name('ros2:callback_end') for event in end_events: diff --git a/tracetools_test/test/test_subscription.py b/tracetools_test/test/test_subscription.py index 2c13ae9..c54e59f 100644 --- a/tracetools_test/test/test_subscription.py +++ b/tracetools_test/test/test_subscription.py @@ -53,13 +53,13 @@ class TestSubscription(TraceTestCase): 'topic_name', '/the_topic', sub_init_events) - self.assertEqual(len(test_sub_init_events), 1, 'cannot find test topic name') + self.assertNumEvents(test_sub_init_events, 1, 'cannot find test topic name') test_sub_init_event = test_sub_init_events[0] # Check queue_depth value - test_queue_depth = self.get_field(test_sub_init_event, 'queue_depth') - self.assertEqual( - test_queue_depth, + self.assertFieldEquals( + test_sub_init_event, + 'queue_depth', 10, 'sub_init event does not have expected queue depth value') @@ -68,8 +68,8 @@ class TestSubscription(TraceTestCase): test_sub_node_init_events = self.get_events_with_procname( 'test_subscription', node_init_events) - self.assertEqual( - len(test_sub_node_init_events), + self.assertNumEvents( + test_sub_node_init_events, 1, 'none or more than 1 node_init event') test_sub_node_init_event = test_sub_node_init_events[0] diff --git a/tracetools_test/test/test_timer.py b/tracetools_test/test/test_timer.py index cc339fc..74c6f69 100644 --- a/tracetools_test/test/test_timer.py +++ b/tracetools_test/test/test_timer.py @@ -50,13 +50,12 @@ class TestTimer(TraceTestCase): start_events = self.get_events_with_name('ros2:callback_start') for event in start_events: self.assertValidHandle(event, 'callback') - is_intra_process_value = self.get_field(event, 'is_intra_process') - self.assertIsInstance(is_intra_process_value, int, 'is_intra_process not int') # Should not be 1 for timer self.assertEqual( - is_intra_process_value, + event, + 'is_intra_process', 0, - f'invalid value for is_intra_process: {is_intra_process_value}') + 'invalid value for is_intra_process') end_events = self.get_events_with_name('ros2:callback_end') for event in end_events: @@ -66,9 +65,7 @@ class TestTimer(TraceTestCase): test_timer_init_event = self.get_events_with_procname('test_timer', init_events) self.assertEqual(len(test_timer_init_event), 1) test_init_event = test_timer_init_event[0] - test_period = self.get_field(test_init_event, 'period') - self.assertIsInstance(test_period, int) - self.assertEqual(test_period, 1000000, f'invalid period: {test_period}') + self.assertFieldEquals(test_init_event, 'period', 1000000, 'invalid period') # Check that the timer_init:callback_added pair exists and has a common timer handle self.assertMatchingField( diff --git a/tracetools_test/tracetools_test/case.py b/tracetools_test/tracetools_test/case.py index e7620da..3e53c49 100644 --- a/tracetools_test/tracetools_test/case.py +++ b/tracetools_test/tracetools_test/case.py @@ -134,6 +134,12 @@ class TraceTestCase(unittest.TestCase): self.assertGreater(handle_value, 0, f'invalid handle value: {field_name}') def assertValidQueueDepth(self, event: DictEvent, queue_depth_field_name: str = 'queue_depth'): + """ + Check that the queue depth value is valid. + + :param event: the event with the queue depth field + :param queue_depth_field_name: the field name for queue depth + """ queue_depth_value = self.get_field(event, 'queue_depth') self.assertIsInstance(queue_depth_value, int, 'invalid queue depth type') self.assertGreater(queue_depth_value, 0, 'invalid queue depth') @@ -149,6 +155,12 @@ class TraceTestCase(unittest.TestCase): self.assertGreater(len(string_field), 0, 'empty string') def assertEventAfterTimestamp(self, event: DictEvent, timestamp: int): + """ + Check that the event happens after the given timestamp. + + :param event: the event to check + :param timestamp: the reference timestamp + """ self.assertGreater(get_event_timestamp(event), timestamp, 'event not after timestamp') def assertEventOrder(self, first_event: DictEvent, second_event: DictEvent): @@ -160,6 +172,21 @@ class TraceTestCase(unittest.TestCase): """ self.assertTrue(self.are_events_ordered(first_event, second_event)) + def assertNumEvents( + self, + events: List[DictEvent], + expected_number: int, + msg: str = 'wrong number of events' + ): + """ + Check number of events. + + :param events: the events to check + :param expected_number: the expected number of events + :param msg: the message to display on failure + """ + self.assertEqual(len(events), expected_number, msg) + def assertMatchingField( self, initial_event: DictEvent, @@ -200,16 +227,23 @@ class TraceTestCase(unittest.TestCase): 1, 'matching field event not after initial event') - def assertFieldEquals(self, event: DictEvent, field_name: str, value: Any): + def assertFieldEquals( + self, + event: DictEvent, + field_name: str, + value: Any, + msg: str = 'wrong field value' + ): """ Check the value of a field. :param event: the event :param field_name: the name of the field to check :param value: to value to compare the field value to + :param msg: the message to display on failure """ actual_value = self.get_field(event, field_name) - self.assertEqual(actual_value, value, 'invalid field value') + self.assertEqual(actual_value, value, msg) def get_field(self, event: DictEvent, field_name: str) -> Any: """