diff --git a/tracetools_analysis/test/tracetools_analysis/test_dependency_solver.py b/tracetools_analysis/test/tracetools_analysis/test_dependency_solver.py index d5fff84..2bf86c0 100644 --- a/tracetools_analysis/test/tracetools_analysis/test_dependency_solver.py +++ b/tracetools_analysis/test/tracetools_analysis/test_dependency_solver.py @@ -19,7 +19,9 @@ from tracetools_analysis.processor import DependencySolver class DepEmtpy(Dependant): - pass + + def __init__(self, **kwargs) -> None: + self.myparam = kwargs.get('myparam', None) class DepOne(Dependant): @@ -102,6 +104,14 @@ class TestDependencySolver(unittest.TestCase): self.assertIsInstance(solution[2], DepOne2) self.assertIs(solution[3], deptwo_instance) + def test_kwargs(self) -> None: + depone_instance = DepOne() + + # Pass parameter and check that the new instance has it + solution = DependencySolver([depone_instance], myparam='myvalue').solve() + self.assertEqual(len(solution), 2, 'solution length invalid') + self.assertEqual(solution[0].myparam, 'myvalue', 'parameter not passed on') + if __name__ == '__main__': unittest.main() diff --git a/tracetools_analysis/tracetools_analysis/processor/__init__.py b/tracetools_analysis/tracetools_analysis/processor/__init__.py index 015763a..d38aa45 100644 --- a/tracetools_analysis/tracetools_analysis/processor/__init__.py +++ b/tracetools_analysis/tracetools_analysis/processor/__init__.py @@ -15,6 +15,7 @@ """Base processor module.""" from collections import defaultdict +from typing import Any from typing import Callable from typing import Dict from typing import List @@ -156,13 +157,16 @@ class DependencySolver(): def __init__( self, initial_dependants: List[Dependant], + **kwargs, ) -> None: """ Constructor. :param initial_dependants: the initial dependant instances, in order + :param kwargs: the parameters to pass on to new instances """ self._initial_deps = initial_dependants + self._kwargs = kwargs def solve(self) -> List[Dependant]: """ @@ -191,7 +195,7 @@ class DependencySolver(): ) -> None: if type(dep_instance) not in visited: for dependency_type in type(dep_instance).dependencies(): - DependencySolver.__solve_type( + self.__solve_type( dependency_type, visited, initial_map, @@ -200,8 +204,8 @@ class DependencySolver(): solution.append(dep_instance) visited.add(type(dep_instance)) - @staticmethod def __solve_type( + self, dep_type: Type[Dependant], visited: Set[Type[Dependant]], initial_map: Dict[Type[Dependant], Dependant], @@ -209,7 +213,7 @@ class DependencySolver(): ) -> None: if dep_type not in visited: for dependency_type in dep_type.dependencies(): - DependencySolver.__solve_type( + self.__solve_type( dependency_type, visited, initial_map, @@ -220,7 +224,7 @@ class DependencySolver(): if dep_type in initial_map: new_instance = initial_map.get(dep_type) else: - new_instance = dep_type() + new_instance = dep_type(**self._kwargs) solution.append(new_instance) visited.add(dep_type) @@ -260,8 +264,7 @@ class Processor(): :param handlers: the list of primary `EventHandler`s """ - # TODO pass on **kwargs - return DependencySolver(handlers).solve() + return DependencySolver(handlers, **kwargs).solve() def _get_handler_maps(self) -> HandlerMultimap: """