From fc1ce31504ed11f565e7352bec37d6ff0e24adcf Mon Sep 17 00:00:00 2001 From: Christophe Bedard Date: Thu, 14 Nov 2019 16:16:55 -0800 Subject: [PATCH] Make get_events_with_field_*value take a single field value or a list --- tracetools_test/tracetools_test/case.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tracetools_test/tracetools_test/case.py b/tracetools_test/tracetools_test/case.py index 353edb3..a530ddf 100644 --- a/tracetools_test/tracetools_test/case.py +++ b/tracetools_test/tracetools_test/case.py @@ -320,38 +320,42 @@ class TraceTestCase(unittest.TestCase): def get_events_with_field_value( self, field_name: str, - field_value: Any, + field_values: Any, events: List[DictEvent] = None ) -> List[DictEvent]: """ Get all events with the given field:value. :param field_name: the name of the field to check - :param field_value: the value of the field to check + :param field_values: the value(s) of the field to check :param events: the events to check (or `None` to check all events) :return: the events with the given field:value pair """ + if not isinstance(field_values, list): + field_values = [field_values] if events is None: events = self._events - return [e for e in events if get_field(e, field_name, None) == field_value] + return [e for e in events if get_field(e, field_name, None) in field_values] def get_events_with_field_not_value( self, field_name: str, - field_value: Any, + field_values: Any, events: List[DictEvent] = None ) -> List[DictEvent]: """ Get all events with the given field but not the value. :param field_name: the name of the field to check - :param field_value: the value of the field to check + :param field_values: the value(s) of the field to check :param events: the events to check (or `None` to check all events) :return: the events with the given field:value pair """ + if not isinstance(field_values, list): + field_values = [field_values] if events is None: events = self._events - return [e for e in events if get_field(e, field_name, None) != field_value] + return [e for e in events if get_field(e, field_name, None) not in field_values] def are_events_ordered(self, first_event: DictEvent, second_event: DictEvent): """