diff --git a/tracetools_analysis/test/tracetools_analysis/test_processor.py b/tracetools_analysis/test/tracetools_analysis/test_processor.py index 2798f73..af7f6c9 100644 --- a/tracetools_analysis/test/tracetools_analysis/test_processor.py +++ b/tracetools_analysis/test/tracetools_analysis/test_processor.py @@ -113,6 +113,11 @@ class TestProcessor(unittest.TestCase): *args, ) + def test_event_handler_process(self) -> None: + # Should not be called directly + with self.assertRaises(AssertionError): + EventHandler.process([]) + def test_handler_wrong_signature(self) -> None: handler = WrongHandler() mock_event = { @@ -159,6 +164,13 @@ class TestProcessor(unittest.TestCase): # Passes check Processor(EventHandlerWithRequiredEvent()).process([required_mock_event, mock_event]) + def test_get_handler_by_type(self) -> None: + handler1 = StubHandler1() + handler2 = StubHandler2() + processor = Processor(handler1, handler2) + result = processor.get_handler_by_type(StubHandler1) + self.assertTrue(result is handler1) + if __name__ == '__main__': unittest.main() diff --git a/tracetools_analysis/tracetools_analysis/processor/__init__.py b/tracetools_analysis/tracetools_analysis/processor/__init__.py index 4605132..a0fedda 100644 --- a/tracetools_analysis/tracetools_analysis/processor/__init__.py +++ b/tracetools_analysis/tracetools_analysis/processor/__init__.py @@ -21,6 +21,7 @@ from typing import Dict from typing import List from typing import Set from typing import Type +from typing import Union from tracetools_read import DictEvent from tracetools_read import get_event_name @@ -125,7 +126,7 @@ class EventHandler(Dependant): f'empty map: {self.__class__.__name__}' assert all(required_name in handler_map.keys() for required_name in self.required_events()) self._handler_map = handler_map - self.processor = None + self._processor = None @property def handler_map(self) -> HandlerMap: @@ -137,6 +138,10 @@ class EventHandler(Dependant): """Get the data model.""" return None + @property + def processor(self) -> 'Processor': + return self._processor + @staticmethod def required_events() -> Set[str]: """ @@ -147,9 +152,12 @@ class EventHandler(Dependant): """ return {} - def register_processor(self, processor: 'Processor') -> None: + def register_processor( + self, + processor: 'Processor', + ) -> None: """Register processor with this `EventHandler` so that it can query other handlers.""" - self.processor = processor + self._processor = processor @staticmethod def int_to_hex_str(addr: int) -> str: @@ -157,14 +165,19 @@ class EventHandler(Dependant): return f'0x{addr:X}' @classmethod - def process(cls, events: List[DictEvent], **kwargs) -> 'EventHandler': + def process( + cls, + events: List[DictEvent], + **kwargs, + ) -> 'EventHandler': """ Create a `Processor` and process an instance of the class. :param events: the list of events :return: the processor object after processing """ - assert cls != EventHandler, 'only call process() from inheriting classes' + if cls == EventHandler: + raise AssertionError('only call EventHandler.process() from inheriting classes') handler_object = cls(**kwargs) # pylint: disable=all processor = Processor(handler_object, **kwargs) processor.process(events) @@ -324,6 +337,21 @@ class Processor(): for handler in handlers: handler.register_processor(self) + def get_handler_by_type( + self, + handler_type: Type, + ) -> Union[EventHandler, None]: + """ + Get an existing EventHandler instance from its type. + + :param handler_type: the type of EventHandler subclass to find + :return: the EventHandler instance if found, otherwise `None` + """ + return next( + (handler for handler in self._expanded_handlers if type(handler) is handler_type), + None, + ) + @staticmethod def get_event_names( events: List[DictEvent],