diff --git a/tracetools_test/test/test_node.py b/tracetools_test/test/test_node.py index 67ef276..146abed 100644 --- a/tracetools_test/test/test_node.py +++ b/tracetools_test/test/test_node.py @@ -17,6 +17,9 @@ import unittest from tracetools_test.case import TraceTestCase +VERSION_REGEX = r'^[0-9]\.[0-9]\.[0-9]$' + + class TestNode(TraceTestCase): def __init__(self, *args) -> None: @@ -30,8 +33,31 @@ class TestNode(TraceTestCase): nodes=['test_publisher'] ) - def test_creation(self): - pass + def test_all(self): + # Check events order as set + self.assertEventsOrderSet(self._events_ros) + + # Check fields + rcl_init_events = self.get_events_with_name('ros2:rcl_init') + for event in rcl_init_events: + self.assertValidHandle(event, 'context_handle') + # TODO actually compare to version fetched from the tracetools package? + version_field = self.get_field(event, 'version') + self.assertRegex(version_field, VERSION_REGEX, 'invalid version number') + + rcl_node_init_events = self.get_events_with_name('ros2:rcl_node_init') + for event in rcl_node_init_events: + self.assertValidHandle(event, 'node_handle') + self.assertValidHandle(event, 'rmw_handle') + self.assertStringFieldNotEmpty(event, 'node_name') + self.assertStringFieldNotEmpty(event, 'namespace') + + # Check that the launched nodes have a corresponding rcl_node_init event + node_name_fields = [self.get_field(e, 'node_name') for e in rcl_node_init_events] + for node_name in self._nodes: + self.assertTrue( + node_name in node_name_fields, + f'cannot find node_init event for node name: {node_name} ({node_name_fields})') if __name__ == '__main__': diff --git a/tracetools_test/tracetools_test/case.py b/tracetools_test/tracetools_test/case.py index bff6d72..54fb7b7 100644 --- a/tracetools_test/tracetools_test/case.py +++ b/tracetools_test/tracetools_test/case.py @@ -14,11 +14,18 @@ """Module for a tracing-specific unittest.TestCase extension.""" +import time +from typing import Any +from typing import Dict from typing import List -from typing import Set import unittest -from .utils import cleanup_trace +# from .utils import cleanup_trace +from .utils import DictEvent +from .utils import get_event_name +from .utils import get_event_names +from .utils import get_event_timestamp +from .utils import get_field from .utils import get_trace_events from .utils import run_and_trace @@ -53,6 +60,9 @@ class TraceTestCase(unittest.TestCase): self._nodes = nodes def setUp(self): + # Get timestamp before trace (ns) + timestamp_before = int(time.time() * 1000000000.0) + exit_code, full_path = run_and_trace( self._base_path, self._session_name_prefix, @@ -61,7 +71,7 @@ class TraceTestCase(unittest.TestCase): self._package, self._nodes) - print(f'trace directory: {full_path}') + print(f'TRACE DIRECTORY: {full_path}') self._exit_code = exit_code self._full_path = full_path @@ -70,20 +80,90 @@ class TraceTestCase(unittest.TestCase): # Read events once self._events = get_trace_events(self._full_path) - self._event_names = self._get_event_names() + self._event_names = get_event_names(self._events) + self.assertGreater(len(self._events), 0, 'no events found in trace') + + # Check the timestamp of the first event + self.assertEventAfterTimestamp(self._events[0], timestamp_before) # Check that the enabled events are present ros = set(self._events_ros) if self._events_ros is not None else set() kernel = set(self._events_kernel) if self._events_kernel is not None else set() all_event_names = ros | kernel - self.assertSetEqual(all_event_names, self._event_names) + self.assertSetEqual(all_event_names, set(self._event_names)) + + # Check that the launched nodes are present as processes + self.assertProcessNamesExist(self._nodes) def tearDown(self): - cleanup_trace(self._full_path) + pass + # cleanup_trace(self._full_path) - def _get_event_names(self) -> Set[str]: - """Get a set of names of the events in the trace.""" - events_names = set() - for event in self._events: - events_names.add(event.name) - return events_names + def assertEventsOrderSet(self, event_names: List[str]): + """ + Compare given event names to trace events names as sets. + + :param event_names: the list of event names to compare to (as a set) + """ + self.assertSetEqual(set(self._event_names), set(event_names), 'wrong events order') + + def assertProcessNamesExist(self, names: List[str]): + """ + Check that the given processes exist. + + :param names: the node names to look for + """ + procnames = [e['procname'] for e in self._events] + for name in names: + # Procnames have a max length of 15 + name_trimmed = name[:15] + self.assertTrue(name_trimmed in procnames, 'node name not found in tracepoints') + + def assertValidHandle(self, event: DictEvent, handle_field_name: str): + """ + Check that the handle associated to a field name is valid. + + :param event: the event which has a handle field + :param handle_field_name: the field name of the handle to check + """ + handle_field = self.get_field(event, handle_field_name) + print(f'handle_field: {handle_field}') + self.assertGreater(handle_field, 0, f'invalid handle: {handle_field_name}') + + def assertStringFieldNotEmpty(self, event: DictEvent, string_field_name: str): + """ + Check that a string field is not empty. + + :param event: the event which has a string field + :param string_field_name: the field name of the string field + """ + string_field = self.get_field(event, string_field_name) + self.assertGreater(len(string_field), 0, 'empty string') + + def assertEventAfterTimestamp(self, event: DictEvent, timestamp: int): + self.assertGreater(get_event_timestamp(event), timestamp, 'event not after timestamp') + + def get_field(self, event: DictEvent, field_name: str) -> Any: + """ + Get field value; will fail test if not found. + + :param event: the event from which to get the value + :param field_name: the field name + :return: the value + """ + try: + value = get_field(event, field_name, default=None, raise_if_not_found=True) + except AttributeError as e: + # Explicitly failing here + self.fail(str(e)) + else: + return value + + def get_events_with_name(self, event_name: str) -> List[DictEvent]: + """ + Get all events with the given name. + + :param event_name: the event name + :return: the list of events with the given name + """ + return [e for e in self._events if get_event_name(e) == event_name] diff --git a/tracetools_test/tracetools_test/utils.py b/tracetools_test/tracetools_test/utils.py index f0a5303..0a83b88 100644 --- a/tracetools_test/tracetools_test/utils.py +++ b/tracetools_test/tracetools_test/utils.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utils for tracetools_test.""" +"""Utils for tracetools_test that are not strictly test-related.""" import os import shutil import time +from typing import Any from typing import Dict from typing import List from typing import Tuple @@ -87,8 +88,9 @@ def cleanup_trace(full_path: str) -> None: """ shutil.rmtree(full_path) +DictEvent = Dict[str, Any] -def get_trace_events(trace_directory: str) -> List[Dict[str, str]]: +def get_trace_events(trace_directory: str) -> List[DictEvent]: """ Get the events of a trace. @@ -98,4 +100,51 @@ def get_trace_events(trace_directory: str) -> List[Dict[str, str]]: tc = babeltrace.TraceCollection() tc.add_traces_recursive(trace_directory, 'ctf') - return tc.events + return [_event_to_dict(event) for event in tc.events] + + +# List of ignored CTF fields +_IGNORED_FIELDS = [ + 'content_size', 'cpu_id', 'events_discarded', 'id', 'packet_size', 'packet_seq_num', + 'stream_id', 'stream_instance_id', 'timestamp_end', 'timestamp_begin', 'magic', 'uuid', 'v' +] +_DISCARD = 'events_discarded' + + +def _event_to_dict(event: babeltrace.babeltrace.Event) -> DictEvent: + """ + Convert name, timestamp, and all other keys except those in IGNORED_FIELDS into a dictionary. + + :param event: the event to convert + :return: the event as a dictionary + """ + d = {'_name': event.name, '_timestamp': event.timestamp} + if hasattr(event, _DISCARD) and event[_DISCARD] > 0: + print(event[_DISCARD]) + for key in [key for key in event.keys() if key not in _IGNORED_FIELDS]: + d[key] = event[key] + return d + + +def get_event_names(events: List[DictEvent]) -> List[str]: + """ + Get a list of names of the events in the trace. + + :param events: the events of the trace + :return: the list of event names + """ + return [get_event_name(e) for e in events] + + +def get_field(event: DictEvent, field_name: str, default=None, raise_if_not_found=True) -> Any: + field_value = event.get(field_name, default) + # If enabled, raise exception as soon as possible to avoid headaches + if raise_if_not_found and field_value is None: + raise AttributeError(f'event field "{field_name}" not found!') + return field_value + +def get_event_name(event: DictEvent) -> str: + return event['_name'] + +def get_event_timestamp(event: DictEvent) -> int: + return event['_timestamp']