diff --git a/tracetools_test/test/test_intra.py b/tracetools_test/test/test_intra.py index b74a302..e8f18ea 100644 --- a/tracetools_test/test/test_intra.py +++ b/tracetools_test/test/test_intra.py @@ -57,83 +57,70 @@ class TestIntra(TraceTestCase): # Get subscription handles for normal & intra subscriptions sub_init_normal_event = sub_init_normal_events[0] - sub_init_intra_event = sub_init_intra_events[0] - sub_handle_normal = self.get_field(sub_init_normal_event, 'subscription_handle') - sub_handle_intra = self.get_field(sub_init_intra_event, 'subscription_handle') + # Note: sub handle linked to "normal" topic is the one actually linked to intra callback + sub_handle_intra = self.get_field(sub_init_normal_event, 'subscription_handle') + print(f'sub_handle_intra: {sub_handle_intra}') - # Get corresponding callback handles - callback_added_events = self.get_events_with_name( - 'ros2:rclcpp_subscription_callback_added') - callback_added_events_normal = self.get_events_with_field_value( - 'subscription_handle', - sub_handle_normal, - callback_added_events) - callback_added_events_intra = self.get_events_with_field_value( + # Get corresponding callback handle + # Callback handle + callback_added_events = self.get_events_with_field_value( 'subscription_handle', sub_handle_intra, - callback_added_events) + self.get_events_with_name( + 'ros2:rclcpp_subscription_callback_added')) self.assertEqual( - len(callback_added_events_normal), + len(callback_added_events), 1, - 'none or more than 1 callback added event for normal sub') - self.assertEqual( - len(callback_added_events_intra), - 1, - 'none or more than 1 callback added event for intra sub') - callback_added_event_normal = callback_added_events_normal[0] - callback_added_event_intra = callback_added_events_intra[0] - callback_handle_normal = self.get_field(callback_added_event_normal, 'callback') - callback_handle_intra = self.get_field(callback_added_event_intra, 'callback') + 'none or more than 1 callback added event') + callback_added_event = callback_added_events[0] + callback_handle_intra = self.get_field(callback_added_event, 'callback') # Get corresponding callback start/end pairs start_events = self.get_events_with_name('ros2:callback_start') end_events = self.get_events_with_name('ros2:callback_end') + # Should still have two start:end pairs (1 normal + 1 intra) self.assertEqual(len(start_events), 2, 'does not have 2 callback start events') self.assertEqual(len(end_events), 2, 'does not have 2 callback end events') - start_events_normal = self.get_events_with_field_value( - 'callback', - callback_handle_normal, - start_events) start_events_intra = self.get_events_with_field_value( 'callback', callback_handle_intra, start_events) - end_events_normal = self.get_events_with_field_value( - 'callback', - callback_handle_normal, - end_events) end_events_intra = self.get_events_with_field_value( 'callback', callback_handle_intra, end_events) - self.assertEqual( - len(start_events_normal), - 1, - 'none or more than one normal start event') self.assertEqual( len(start_events_intra), 1, 'none or more than one intra start event') - self.assertEqual( - len(end_events_normal), - 1, - 'none or more than one normal end event') self.assertEqual( len(end_events_intra), 1, 'none or more than one intra end event') - start_event_normal = start_events_normal[0] + + # Check is_intra_process field value start_event_intra = start_events_intra[0] - is_intra_value_normal = self.get_field(start_event_normal, 'is_intra_process') is_intra_value_intra = self.get_field(start_event_intra, 'is_intra_process') - self.assertEqual( - is_intra_value_normal, - 0, - 'is_intra_process field value not valid for normal sub') self.assertEqual( is_intra_value_intra, 1, - 'is_intra_process field value not valid for intra sub') + 'is_intra_process field value not valid for intra callback') + + # Also check that the other callback_start event (normal one) has the right field value + start_events_not_intra = self.get_events_with_field_not_value( + 'callback', + callback_handle_intra, + start_events) + self.assertEqual( + len(start_events_not_intra), + 1, + 'none or more than one normal start event') + start_event_not_intra = start_events_not_intra[0] + is_intra_value_not_intra = self.get_field(start_event_not_intra, 'is_intra_process') + self.assertEqual( + is_intra_value_not_intra, + 0, + 'is_intra_process field value not valid for normal callback') if __name__ == '__main__': diff --git a/tracetools_test/tracetools_test/case.py b/tracetools_test/tracetools_test/case.py index a087040..ef6f0c3 100644 --- a/tracetools_test/tracetools_test/case.py +++ b/tracetools_test/tracetools_test/case.py @@ -287,6 +287,24 @@ class TraceTestCase(unittest.TestCase): events = self._events return [e for e in events if get_field(e, field_name, None) == field_value] + def get_events_with_field_not_value( + self, + field_name: str, + field_value: Any, + events: List[DictEvent] = None + ) -> List[DictEvent]: + """ + Get all events with the given field but not the value. + + :param field_name: the name of the field to check + :param field_value: the value of the field to check + :param events: the events to check (or None to check all events) + :return: the events with the given field:value pair + """ + if events is None: + events = self._events + return [e for e in events if get_field(e, field_name, None) != field_value] + def are_events_ordered(self, first_event: DictEvent, second_event: DictEvent): """ Check that the first event was generated before the second event.