diff --git a/tracetools_test/tracetools_test/case.py b/tracetools_test/tracetools_test/case.py index 375e92f..49d6722 100644 --- a/tracetools_test/tracetools_test/case.py +++ b/tracetools_test/tracetools_test/case.py @@ -48,7 +48,7 @@ class TraceTestCase(unittest.TestCase): events_ros: List[str], nodes: List[str], base_path: str = '/tmp', - events_kernel: List[str] = None, + events_kernel: List[str] = [], package: str = 'tracetools_test', ) -> None: """Create a TraceTestCase.""" @@ -70,7 +70,8 @@ class TraceTestCase(unittest.TestCase): self._events_ros, self._events_kernel, self._package, - self._nodes) + self._nodes, + ) print(f'TRACE DIRECTORY: {full_path}') self._exit_code = exit_code diff --git a/tracetools_test/tracetools_test/utils.py b/tracetools_test/tracetools_test/utils.py index 4e713b9..ea70478 100644 --- a/tracetools_test/tracetools_test/utils.py +++ b/tracetools_test/tracetools_test/utils.py @@ -53,13 +53,15 @@ def run_and_trace( launch_actions = [] # Add trace action - launch_actions.append(Trace( - session_name=session_name, - append_timestamp=False, - base_path=base_path, - events_ust=ros_events, - events_kernel=kernel_events - )) + launch_actions.append( + Trace( + session_name=session_name, + append_timestamp=False, + base_path=base_path, + events_ust=ros_events, + events_kernel=kernel_events, + ) + ) # Add nodes for node_name in node_names: n = Node( diff --git a/tracetools_trace/tracetools_trace/tools/lttng_impl.py b/tracetools_trace/tracetools_trace/tools/lttng_impl.py index 7c593ce..578c4a6 100644 --- a/tracetools_trace/tracetools_trace/tools/lttng_impl.py +++ b/tracetools_trace/tracetools_trace/tools/lttng_impl.py @@ -18,6 +18,7 @@ from distutils.version import StrictVersion import re from typing import List from typing import Optional +from typing import Set from typing import Union import lttng @@ -50,9 +51,9 @@ def get_version() -> Union[StrictVersion, None]: def setup( session_name: str, base_path: str = DEFAULT_BASE_PATH, - ros_events: List[str] = DEFAULT_EVENTS_ROS, - kernel_events: List[str] = DEFAULT_EVENTS_KERNEL, - context_names: List[str] = DEFAULT_CONTEXT, + ros_events: Union[List[str], Set[str]] = DEFAULT_EVENTS_ROS, + kernel_events: Union[List[str], Set[str]] = DEFAULT_EVENTS_KERNEL, + context_names: Union[List[str], Set[str]] = DEFAULT_CONTEXT, channel_name_ust: str = 'ros2', channel_name_kernel: str = 'kchan', ) -> Optional[str]: @@ -70,6 +71,14 @@ def setup( :param channel_name_kernel: the kernel channel name :return: the full path to the trace directory """ + # Convert lists to sets + if not isinstance(ros_events, set): + ros_events = set(ros_events) + if not isinstance(kernel_events, set): + kernel_events = set(kernel_events) + if not isinstance(context_names, set): + context_names = set(context_names) + # Resolve full tracing directory path full_path = get_full_session_path(session_name, base_path=base_path) @@ -181,16 +190,16 @@ def destroy( def _create_events( - event_names_list: List[str], + event_names: Set[str], ) -> List[lttng.Event]: """ Create events list from names. - :param event_names_list: a list of names to create events for + :param event_names: a set of names to create events for :return: the list of events """ events_list = [] - for event_name in event_names_list: + for event_name in event_names: e = lttng.Event() e.name = event_name e.type = lttng.EVENT_TRACEPOINT @@ -290,16 +299,16 @@ def _context_name_to_type( def _create_context_list( - context_names_list: List[str], + context_names: Set[str], ) -> List[lttng.EventContext]: """ Create context list from names, and check for errors. - :param context_names_list: the list of context names + :param context_names: the set of context names :return: the event context list """ context_list = [] - for context_name in context_names_list: + for context_name in context_names: ec = lttng.EventContext() context_type = _context_name_to_type(context_name) if context_type is not None: