diff --git a/latency_graph/latency_graph.py b/latency_graph/latency_graph.py index 8ddae26..bc05506 100644 --- a/latency_graph/latency_graph.py +++ b/latency_graph/latency_graph.py @@ -1,58 +1,34 @@ +from bisect import bisect_left, bisect from dataclasses import dataclass from itertools import combinations from multiprocessing import Pool from typing import Optional, Set, List, Iterable, Dict, Tuple +import numpy as np from tqdm.notebook import tqdm from tqdm.contrib import concurrent from matching.subscriptions import sanitize from tracing_interop.tr_types import TrContext, TrCallbackObject, TrCallbackSymbol, TrNode, TrPublisher, TrSubscription, \ - TrTimer, TrPublishInstance, TrSubscriptionObject, TrTopic, TrCallbackInstance + TrTimer, TrPublishInstance, TrSubscriptionObject, TrTopic, TrCallbackInstance, Timestamp - -TOPIC_FILTERS = ["/parameter_events", "/tf_static", "/robot_description", "diagnostics"] - - -def _map_cb_times(args): - cb_id, inst_times, pub_timestamps = args - pub_cb_overlaps = {i: set() for i in range(len(pub_timestamps))} - - inst_times.sort(key=lambda tup: tup[0]) # tup[0] is start time - - inst_iter = iter(inst_times) - pub_iter = iter(enumerate(pub_timestamps)) - - inst_start, inst_end = next(inst_iter, (None, None)) - i, t = next(pub_iter, (None, None)) - while inst_start is not None and i is not None: - if inst_start <= t <= inst_end: - pub_cb_overlaps[i].add(cb_id) - - if t <= inst_end: - i, t = next(pub_iter, (None, None)) - else: - inst_start, inst_end = next(inst_iter, (None, None)) - - return pub_cb_overlaps +TOPIC_FILTERS = ["/parameter_events", "/tf_static", "/robot_description", "diagnostics", "/rosout"] def _get_cb_owner_node(cb: TrCallbackObject) -> TrNode | None: match cb.owner: - case TrTimer(nodes=nodes): - owner_nodes = nodes + case TrTimer(node=node): + owner_node = node case TrSubscriptionObject(subscription=sub): - owner_nodes = [sub.node] + owner_node = sub.node case _: - owner_nodes = [] + owner_node = None - if len(owner_nodes) > 1: - raise RuntimeError(f"CB has owners {', '.join(map(lambda n: n.path, owner_nodes))}") - elif not owner_nodes: + if not owner_node: print("[WARN] CB has no owners") return None - return owner_nodes[0] + return owner_node def _hierarchize(lg_nodes: Iterable['LGHierarchyLevel']): @@ -79,33 +55,25 @@ def _hierarchize(lg_nodes: Iterable['LGHierarchyLevel']): def inst_runtime_interval(cb_inst: TrCallbackInstance): - inst_t_min = cb_inst.timestamp.timestamp() - inst_t_max = inst_t_min + cb_inst.duration.total_seconds() - return inst_t_min, inst_t_max + start_time = cb_inst.timestamp + end_time = start_time + cb_inst.duration + return start_time, end_time def _get_publishing_cbs(cbs: Set[TrCallbackObject], pub: TrPublisher): """ Counts number of publication instances that lie within one of the cb_intervals. """ - pub_timestamps = [inst.timestamp * 1e-9 for inst in pub.instances] + pub_insts = pub.instances + pub_cb_overlaps = {i: set() for i in range(len(pub_insts))} - # Algorithm: Two-pointer method - # With both the pub_timestamps and cb_intervals sorted ascending, - # we can cut down the O(m*n) comparisons to O(m+n). 
- pub_timestamps.sort() - - cb_id_to_cb = {cb.id: cb for cb in cbs} - _map_args = [(cb.id, [inst_runtime_interval(inst) for inst in cb.callback_instances], pub_timestamps) for cb in cbs] - - with Pool() as p: - cb_wise_overlaps = p.map(_map_cb_times, _map_args) - - pub_cb_overlaps = {i: set() for i in range(len(pub_timestamps))} - for overlap_dict in cb_wise_overlaps: - for i, cb_ids in overlap_dict.items(): - cbs = [cb_id_to_cb[cb_id] for cb_id in cb_ids] - pub_cb_overlaps[i].update(cbs) + for cb in cbs: + cb_intervals = map(inst_runtime_interval, cb.callback_instances) + for t_start, t_end in cb_intervals: + i_overlap_begin = bisect_left(pub_insts, t_start, key=lambda x: x.timestamp) + i_overlap_end = bisect(pub_insts, t_end, key=lambda x: x.timestamp) + for i in range(i_overlap_begin, i_overlap_end): + pub_cb_overlaps[i].add(cb) pub_cbs = set() cb_cb_overlaps = set() @@ -168,21 +136,24 @@ def _get_cb_topic_deps(nodes_to_cbs: Dict[TrNode, Set[TrCallbackObject]]): # For topics published to during the runtime of the callback's instances, # assume that they are published by the callback cbs_publishing_topic: Dict[TrTopic, Set[TrCallbackObject]] = {} - p = tqdm(desc="Processing node publications", total=len(nodes_to_cbs)) - for node, cbs in nodes_to_cbs.items(): - p.update() + cb_publishers: Dict[TrCallbackObject, Set[TrPublisher]] = {} + for node, cbs in tqdm(nodes_to_cbs.items(), desc="Processing node publications"): if node is None: continue for pub in node.publishers: if any(f in pub.topic_name for f in TOPIC_FILTERS): continue pub_cbs = _get_publishing_cbs(cbs, pub) + for cb in pub_cbs: + if cb not in cb_publishers: + cb_publishers[cb] = set() + cb_publishers[cb].add(pub) if pub.topic not in cbs_publishing_topic: cbs_publishing_topic[pub.topic] = set() cbs_publishing_topic[pub.topic].update(pub_cbs) - return cbs_subbed_to_topic, cbs_publishing_topic + return cbs_subbed_to_topic, cbs_publishing_topic, cb_publishers @dataclass @@ -224,6 +195,7 @@ class LGHierarchyLevel: class LGEdge: start: LGCallback end: LGCallback + topic: TrTopic @dataclass @@ -231,6 +203,9 @@ class LatencyGraph: top_node: LGHierarchyLevel edges: List[LGEdge] + cb_pubs: Dict[TrCallbackObject, Set[TrPublisher]] + pub_cbs: Dict[TrPublisher, Set[TrCallbackObject]] + def __init__(self, tr: TrContext): ################################################## # Annotate nodes with their callbacks @@ -238,9 +213,7 @@ class LatencyGraph: # Note that nodes can also be None! 
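+        # Callbacks whose owner is neither a timer nor a subscription end up under the None key.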
nodes_to_cbs = {} - p = tqdm(desc="Finding CB nodes", total=len(tr.callback_objects)) - for cb in tr.callback_objects.values(): - p.update() + for cb in tqdm(tr.callback_objects, desc="Finding CB nodes"): node = _get_cb_owner_node(cb) if node not in nodes_to_cbs: @@ -251,31 +224,23 @@ class LatencyGraph: # Find in/out topics for each callback ################################################## - cbs_subbed_to_topic, cbs_publishing_topic = _get_cb_topic_deps(nodes_to_cbs) + cbs_subbed_to_topic, cbs_publishing_topic, cb_pubs = _get_cb_topic_deps(nodes_to_cbs) + pub_cbs = {} + for cb, pubs in cb_pubs.items(): + for pub in pubs: + if pub not in pub_cbs: + pub_cbs[pub] = set() + pub_cbs[pub].add(cb) - ################################################## - # Map topics to their messages - ################################################## - - topics_to_messages = {} - p = tqdm(desc="Mapping messages to topics", total=len(tr.publish_instances)) - for pub_inst in tr.publish_instances: - p.update() - try: - topic = pub_inst.publisher.topic - except KeyError: - continue - - if topic not in topics_to_messages: - topics_to_messages[topic] = [] - topics_to_messages[topic].append(pub_inst) + self.cb_pubs = cb_pubs + self.pub_cbs = pub_cbs ################################################## # Define nodes and edges on lowest level ################################################## - input = LGCallback("INPUT", [], [topic for topic in tr.topics.values() if not topic.publishers]) - output = LGCallback("OUTPUT", [topic for topic in tr.topics.values() if not topic.subscriptions], []) + input = LGCallback("INPUT", [], [topic for topic in tr.topics if not topic.publishers]) + output = LGCallback("OUTPUT", [topic for topic in tr.topics if not topic.subscriptions], []) in_node = LGHierarchyLevel(None, [], "INPUT", [input]) out_node = LGHierarchyLevel(None, [], "OUTPUT", [output]) @@ -284,17 +249,17 @@ class LatencyGraph: tr_to_lg_cb = {} - p = tqdm("Building graph nodes", total=sum(map(len, nodes_to_cbs.values()))) + p = tqdm(desc="Building graph nodes", total=sum(map(len, nodes_to_cbs.values()))) for node, cbs in nodes_to_cbs.items(): node_callbacks = [] for cb in cbs: p.update() - try: - sym = cb.callback_symbol + + sym = cb.callback_symbol + if sym is not None: pretty_sym = sanitize(sym.symbol) - except KeyError: - sym = None + else: pretty_sym = cb.id in_topics = [topic for topic, cbs in cbs_subbed_to_topic.items() if cb in cbs] out_topics = [topic for topic, cbs in cbs_publishing_topic.items() if cb in cbs] @@ -306,15 +271,13 @@ class LatencyGraph: lg_nodes.append(lg_node) edges = [] - p = tqdm("Building graph edges", total=len(tr.topics)) - for topic in tr.topics.values(): - p.update() + for topic in tqdm(tr.topics, desc="Building graph edges"): sub_cbs = cbs_subbed_to_topic[topic] if topic in cbs_subbed_to_topic else [] pub_cbs = cbs_publishing_topic[topic] if topic in cbs_publishing_topic else [] for sub_cb in sub_cbs: for pub_cb in pub_cbs: - lg_edge = LGEdge(tr_to_lg_cb[pub_cb], tr_to_lg_cb[sub_cb]) + lg_edge = LGEdge(tr_to_lg_cb[pub_cb], tr_to_lg_cb[sub_cb], topic) edges.append(lg_edge) self.edges = edges @@ -324,6 +287,3 @@ class LatencyGraph: ################################################## self.top_node = _hierarchize(lg_nodes) - - def to_gv(self): - pass diff --git a/latency_graph/message_tree.py b/latency_graph/message_tree.py new file mode 100644 index 0000000..9544bd7 --- /dev/null +++ b/latency_graph/message_tree.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from typing import List + 
+from tracing_interop.tr_types import TrPublishInstance, TrCallbackInstance + + +@dataclass +class DepTree: + head: TrCallbackInstance | TrPublishInstance + deps: List['DepTree'] + + def depth(self): + return 1 + max(map(DepTree.depth, self.deps), default=0) + + def size(self): + return 1 + sum(map(DepTree.size, self.deps)) + + def fanout(self): + if not self.deps: + return 1 + + return sum(map(DepTree.fanout, self.deps)) + + def e2e_lat(self): + return self.head.timestamp - self.critical_path()[-1].timestamp + + def critical_path(self): + if not self.deps: + return [self.head] + + return [self.head, *min(map(DepTree.critical_path, self.deps), key=lambda ls: ls[-1].timestamp)] diff --git a/misc/utils.py b/misc/utils.py index 34eb34e..cd74968 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -73,7 +73,7 @@ def cached(name, function, file_deps: List[str]): if pkl_time > dep_time: with open(pkl_filename, "rb") as f: - print(f"[CACHE] Found up-to-date cache entry for {name}, loading.") + print(f"[CACHE] Found up-to-date cache entry ({pkl_filename}) for {name}, loading.") return pickle.load(f) if os.path.exists(pkl_filename): diff --git a/requirements.txt b/requirements.txt index 76e0bb5..90e2676 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ matplotlib pyvis graphviz ruamel.yaml -fuzzywuzzy +blist @ git+https://github.com/mojomex/blist.git@47724cbc4137ddfb685f9711e950fb82587bf971 + diff --git a/trace-analysis.ipynb b/trace-analysis.ipynb index 459b994..7fd7fb9 100644 --- a/trace-analysis.ipynb +++ b/trace-analysis.ipynb @@ -10,9 +10,11 @@ }, "outputs": [], "source": [ + "import glob\n", "import json\n", "import os\n", "import pickle\n", + "import re\n", "import sys\n", "\n", "import numpy as np\n", @@ -29,12 +31,12 @@ "from tracetools_analysis.processor.ros2 import Ros2Handler\n", "from tracetools_analysis.utils.ros2 import Ros2DataModelUtil\n", "\n", - "from dataclasses import dataclass\n", - "from typing import List, Dict, Set, Tuple\n", - "\n", "from tracing_interop.tr_types import TrTimer, TrTopic, TrPublisher, TrPublishInstance, TrCallbackInstance, \\\n", "TrCallbackSymbol, TrCallbackObject, TrSubscriptionObject, TrContext\n", - "from misc.utils import ProgressPrinter, cached" + "from misc.utils import ProgressPrinter, cached\n", + "\n", + "%load_ext pyinstrument\n", + "%matplotlib inline" ] }, { @@ -47,7 +49,7 @@ }, "outputs": [], "source": [ - "TR_PATH = os.path.expanduser(\"data/trace-awsim-x86/ust\")\n", + "TR_PATH = os.path.expanduser(\"data/awsim-trace/ust\")\n", "CL_PATH = os.path.expanduser(\"~/Projects/llvm-project/clang-tools-extra/ros2-internal-dependency-checker/output\")" ] }, @@ -71,8 +73,7 @@ "def _load_traces():\n", " file = load_file(TR_PATH)\n", " handler = Ros2Handler.process(file)\n", - " util = Ros2DataModelUtil(handler)\n", - " return TrContext(util, handler)\n", + " return TrContext(handler)\n", "\n", "\n", "_tracing_context = cached(\"tr_objects\", _load_traces, [TR_PATH])\n", @@ -94,56 +95,6 @@ } } }, - { - "cell_type": "markdown", - "source": [ - "# ROS2 Tracing & Clang Matching" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "def _load_cl_objects():\n", - " return process_clang_output(CL_PATH)\n", - "\n", - "\n", - "_cl_context: ClContext = cached(\"cl_objects\", _load_cl_objects, [CL_PATH])" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - 
"execution_count": null, - "outputs": [], - "source": [ - "from matching.subscriptions import cl_deps_to_tr_deps, match\n", - "\n", - "matches, tr_unmatched, cl_unmatched = match(_tracing_context, _cl_context)\n", - "tr_internal_deps = cl_deps_to_tr_deps(matches, _tracing_context, _cl_context)\n", - "tr_cl_matches = {tup[1]: tup[0] for tup in matches}\n", - "\n", - "print(len(tr_internal_deps), sum(map(len, tr_internal_deps.values())))" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, { "cell_type": "markdown", "source": [ @@ -163,14 +114,28 @@ "source": [ "from latency_graph import latency_graph as lg\n", "\n", - "#lat_graph = lg.LatencyGraph(_tracing_context)\n", + "lat_graph = lg.LatencyGraph(_tracing_context)\n", "\n", "import pickle\n", "\n", - "#with open(\"lat_graph.pkl\", \"wb\") as f:\n", - "# pickle.dump(lat_graph, f)\n", - "with open(\"lat_graph.pkl\", \"rb\") as f:\n", - " lat_graph = pickle.load(f)" + "with open(\"lat_graph.pkl\", \"wb\") as f:\n", + " pickle.dump(lat_graph, f)\n", + "#with open(\"lat_graph.pkl\", \"rb\") as f:\n", + "# lat_graph = pickle.load(f)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "len(lat_graph.edges)" ], "metadata": { "collapsed": false, @@ -237,14 +202,17 @@ " if isinstance(cb, lg.LGTrCallback):\n", " tr_cb = cb.cb\n", " try:\n", - " sym = _tracing_context.callback_symbols.get(tr_cb.callback_object)\n", - " pretty_sym = repr(sanitize(sym.symbol)).replace(\"&\", \"&\").replace(\"<\", \"<\").replace(\">\", \">\")\n", + " sym = _tracing_context.callback_symbols.by_id.get(tr_cb.callback_object)\n", + " pretty_sym = repr(sanitize(sym.symbol))\n", " except KeyError:\n", " pretty_sym = cb.name\n", " except TypeError:\n", " pretty_sym = cb.name\n", " else:\n", " pretty_sym = cb.name\n", + "\n", + " pretty_sym = pretty_sym.replace(\"&\", \"&\").replace(\"<\", \"<\").replace(\">\", \">\")\n", + "\n", " c.node(cb.id(),\n", " f'<
{pretty_sym}
>')\n", "\n", @@ -278,11 +246,19 @@ "execution_count": null, "outputs": [], "source": [ + "import re\n", + "import math\n", + "\n", "##################################################\n", "# Compute in/out topics for hierarchy level X\n", "##################################################\n", "\n", - "HIER_LEVEL = 2\n", + "HIER_LEVEL = 100\n", + "\n", + "input_node_patterns = [r\"^/sensing\"]\n", + "output_node_patterns = [r\"^/awapi\", r\"^/control/external_cmd_converter\"]\n", + "\n", + "node_excluded_patterns = [r\"^/rviz2\", r\"transform_listener_impl\"]\n", "\n", "def get_nodes_on_level(lat_graph: lg.LatencyGraph):\n", " def _traverse_node(node: lg.LGHierarchyLevel, cur_lvl=0):\n", @@ -300,8 +276,13 @@ " return _traverse_node(lat_graph.top_node)\n", "\n", "lvl_nodes = get_nodes_on_level(lat_graph)\n", - "lvl_nodes = [n for n in lvl_nodes if \"transform_listener_impl\" not in n.full_name]\n", + "lvl_nodes = [n for n in lvl_nodes if not any(re.search(p, n.full_name) for p in node_excluded_patterns)]\n", "\n", + "input_nodes = [n.full_name for n in lvl_nodes if any(re.search(p, n.full_name) for p in input_node_patterns)]\n", + "output_nodes = [n.full_name for n in lvl_nodes if any(re.search(p, n.full_name) for p in output_node_patterns)]\n", + "\n", + "print(', '.join(map(lambda n: n, input_nodes)))\n", + "print(', '.join(map(lambda n: n, output_nodes)))\n", "print(', '.join(map(lambda n: n.full_name, lvl_nodes)))\n", "\n", "def _collect_callbacks(n: lg.LGHierarchyLevel):\n", @@ -333,9 +314,9 @@ " k = (from_node.full_name, to_node.full_name)\n", "\n", " if k not in edges_between_nodes:\n", - " edges_between_nodes[k] = 0\n", + " edges_between_nodes[k] = []\n", "\n", - " edges_between_nodes[k] += 1\n", + " edges_between_nodes[k].append(edge)\n", "\n", "g = gv.Digraph('G', filename=\"latency_graph.gv\",\n", " node_attr={'shape': 'plain'},\n", @@ -344,11 +325,45 @@ "\n", "for n in lvl_nodes:\n", " colors = node_colors[node_namespace_mapping.get(n.full_name.strip(\"/\").split(\"/\")[0])]\n", - " g.node(n.full_name, label=n.full_name, fillcolor=colors[\"fill\"], color=colors[\"stroke\"], shape=\"box\", style=\"filled\")\n", + " peripheries = \"1\" if n.full_name not in output_nodes else \"2\"\n", + " g.node(n.full_name, label=n.full_name, fillcolor=colors[\"fill\"], color=colors[\"stroke\"],\n", + " shape=\"box\", style=\"filled\", peripheries=peripheries)\n", "\n", - "for (src_name, dst_name), cnt in edges_between_nodes.items():\n", - " print(src_name, dst_name, cnt)\n", - " g.edge(src_name, dst_name, weight=str(cnt))\n", + " if n.full_name in input_nodes:\n", + " helper_node_name = f\"{n.full_name}__before\"\n", + " g.node(helper_node_name, label=\"\", shape=\"none\", height=\"0\", width=\"0\")\n", + " g.edge(helper_node_name, n.full_name)\n", + "\n", + "def compute_e2e_paths(start_nodes, end_nodes, edges):\n", + " frontier_paths = [[n] for n in start_nodes]\n", + " final_paths = []\n", + "\n", + " while frontier_paths:\n", + " frontier_paths_new = []\n", + "\n", + " for path in frontier_paths:\n", + " head = path[-1]\n", + " if head in end_nodes:\n", + " final_paths.append(path)\n", + " continue\n", + "\n", + " out_nodes = [n_to for n_from, n_to in edges if n_from == head if n_to not in path]\n", + " new_paths = [path + [n] for n in out_nodes]\n", + " frontier_paths_new += new_paths\n", + "\n", + " frontier_paths = frontier_paths_new\n", + "\n", + " final_paths = [[(n_from, n_to)\n", + " for n_from, n_to in zip(path[:-1], path[1:])]\n", + " for path in final_paths]\n", + " return 
final_paths\n", + "\n", + "e2e_paths = compute_e2e_paths(input_nodes, output_nodes, edges_between_nodes)\n", + "\n", + "for (src_name, dst_name), edges in edges_between_nodes.items():\n", + " print(src_name, dst_name, len(edges))\n", + " color = \"black\" if any((src_name, dst_name) in path for path in e2e_paths) else \"tomato\"\n", + " g.edge(src_name, dst_name, penwidth=str(math.log(len(edges)) * 2 + .2), color=color)\n", "\n", "g.save(\"level_graph.gv\")\n", "g.render(\"level_graph.svg\")\n", @@ -362,6 +377,838 @@ } } }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from latency_graph.message_tree import DepTree\n", + "from tqdm.notebook import tqdm\n", + "from bisect import bisect\n", + "\n", + "topic_name_filter = [\"/control/trajectory_follower/control_cmd\"]\n", + "\n", + "\n", + "def inst_get_dep_msg(inst: TrCallbackInstance):\n", + " if inst.callback_object not in _tracing_context.callback_objects.by_callback_object:\n", + " # print(\"Callback not found (2)\")\n", + " return None\n", + "\n", + " if not isinstance(inst.callback_obj.owner, TrSubscriptionObject):\n", + " # print(f\"Wrong type: {type(inst.callback_obj.owner)}\")\n", + " return None\n", + "\n", + " sub_obj: TrSubscriptionObject = inst.callback_obj.owner\n", + " if sub_obj and sub_obj.subscription and sub_obj.subscription.topic:\n", + " # print(f\"Subscription has no topic\")\n", + " pubs = sub_obj.subscription.topic.publishers\n", + " else:\n", + " pubs = []\n", + "\n", + " def _pub_latest_msg_before(pub: TrPublisher, inst):\n", + " i_latest_msg = bisect(pub.instances, inst.timestamp, key=lambda x: x.timestamp) - 1\n", + " if i_latest_msg < 0 or i_latest_msg >= len(pub.instances):\n", + " return None\n", + " latest_msg = pub.instances[i_latest_msg]\n", + " if latest_msg.timestamp >= inst.timestamp:\n", + " return None\n", + "\n", + " return latest_msg\n", + "\n", + " msgs = [_pub_latest_msg_before(pub, inst) for pub in pubs]\n", + " msgs = [msg for msg in msgs if msg is not None]\n", + " msgs.sort(key=lambda i: i.timestamp, reverse=True)\n", + " if msgs:\n", + " msg = msgs[0]\n", + " return msg\n", + "\n", + " # print(f\"No messages found for topic {sub_obj.subscription.topic}\")\n", + " return None\n", + "\n", + "def inst_get_dep_insts(inst: TrCallbackInstance):\n", + " if inst.callback_object not in _tracing_context.callback_objects.by_callback_object:\n", + " # print(\"Callback not found\")\n", + " return []\n", + " dep_cbs = get_cb_dep_cbs(inst.callback_obj)\n", + "\n", + " def _cb_to_chronological_inst(cb: TrCallbackObject, inst):\n", + " i_inst_latest = bisect(cb.callback_instances, inst.timestamp, key=lambda x: x.timestamp)\n", + "\n", + " for inst_before in cb.callback_instances[i_inst_latest::-1]:\n", + " if lg.inst_runtime_interval(inst_before)[-1] < inst.timestamp:\n", + " return inst_before\n", + "\n", + " return None\n", + "\n", + " insts = [_cb_to_chronological_inst(cb, inst) for cb in dep_cbs]\n", + " insts = [inst for inst in insts if inst is not None]\n", + " return insts\n", + "\n", + "def get_cb_dep_cbs(cb: TrCallbackObject):\n", + " match cb.owner:\n", + " case TrSubscriptionObject() as sub_obj:\n", + " sub_obj: TrSubscriptionObject\n", + " owner = sub_obj.subscription.node\n", + " case TrTimer() as tmr:\n", + " tmr: TrTimer\n", + " owner = tmr.node\n", + " case _:\n", + " raise RuntimeError(f\"Encountered {cb.owner} as callback owner\")\n", + "\n", + " owner: TrNode\n", + " dep_sub_objs = {sub.subscription_object for sub in owner.subscriptions}\n", + " 
dep_cbs = {callback_objects.by_id.get(sub_obj.id) for sub_obj in dep_sub_objs if sub_obj is not None}\n", + " dep_cbs |= {callback_objects.by_id.get(tmr.id) for tmr in owner.timers}\n", + " dep_cbs.discard(cb)\n", + " dep_cbs.discard(None)\n", + "\n", + " return dep_cbs\n", + "\n", + "\n", + "def get_msg_dep_cb(msg: TrPublishInstance):\n", + " \"\"\"\n", + " For a given message instance `msg`, find the publishing callback,\n", + " as well as the message instances that callback depends on (transitively within its TrNode).\n", + " \"\"\"\n", + "\n", + " # Find CB instance that published msg\n", + " # print(f\"PUB {msg.publisher.node.path if msg.publisher.node is not None else '??'} ==> {msg.publisher.topic_name}\")\n", + " pub_cbs = lat_graph.pub_cbs.get(msg.publisher)\n", + " if pub_cbs is None:\n", + " # print(\"Publisher unknown to lat graph. Skipping.\")\n", + " return None\n", + "\n", + " # print(f\"Found {len(pub_cbs)} pub cbs\")\n", + " cb_inst_candidates = []\n", + " for cb in pub_cbs:\n", + " # print(f\" > CB ({len(cb.callback_instances)} instances): {cb.callback_symbol.symbol if cb.callback_symbol else cb.id}\")\n", + " i_inst_after = bisect(cb.callback_instances, msg.timestamp, key=lambda x: x.timestamp)\n", + "\n", + " for inst in cb.callback_instances[:i_inst_after]:\n", + " inst_start, inst_end = lg.inst_runtime_interval(inst)\n", + " if msg.timestamp > inst_end:\n", + " continue\n", + "\n", + " assert inst_start <= msg.timestamp <= inst_end\n", + "\n", + " cb_inst_candidates.append(inst)\n", + "\n", + " if len(cb_inst_candidates) > 1:\n", + " # print(\"Found multiple possible callbacks\")\n", + " return None\n", + " if not cb_inst_candidates:\n", + " # print(\"Found no possible callbacks\")\n", + " return None\n", + "\n", + " dep_inst = cb_inst_candidates[0]\n", + " return dep_inst\n", + "\n", + "\n", + "def get_dep_tree(inst: TrPublishInstance | TrCallbackInstance, lvl=0, visited_topics=None, is_dep_cb=False):\n", + " if visited_topics is None:\n", + " visited_topics = set()\n", + "\n", + " children_are_dep_cbs = False\n", + "\n", + " match inst:\n", + " case TrPublishInstance(publisher=pub):\n", + " if pub.topic_name in visited_topics:\n", + " return None\n", + "\n", + " visited_topics.add(pub.topic_name)\n", + " deps = [get_msg_dep_cb(inst)]\n", + " case TrCallbackInstance() as cb_inst:\n", + " deps = [inst_get_dep_msg(cb_inst)]\n", + " if not is_dep_cb:\n", + " deps += inst_get_dep_insts(cb_inst)\n", + " children_are_dep_cbs = True\n", + " case _:\n", + " raise TypeError(f\"Expected inst to be of type TrPublishInstance or TrCallbackInstance, got {type(inst).__name__}\")\n", + "\n", + " # print(\"Rec level\", lvl)\n", + " deps = [dep for dep in deps if dep is not None]\n", + " deps = [get_dep_tree(dep, lvl + 1, set(visited_topics), is_dep_cb=children_are_dep_cbs) for dep in deps]\n", + " deps = [dep for dep in deps if dep is not None]\n", + " return DepTree(inst, deps)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "\n", + "#for path in e2e_topic_paths:\n", + "# end_topics = path[-1]\n", + "end_topics = [t for t in _tracing_context.topics if any(f in t.name for f in topic_name_filter)]\n", + "for end_topic in end_topics:\n", + " end_topic: TrTopic\n", + "\n", + " pubs = end_topic.publishers\n", + " for pub in pubs:\n", + " depths = []\n", + " sizes = []\n", + " e2e_lats = []\n", + " trees = []\n", + " msgs = pub.instances\n", + " for msg in 
tqdm(msgs, desc=f\"Building message chains for topic {end_topic.name}\"):\n", + " msg: TrPublishInstance\n", + " tree = get_dep_tree(msg)\n", + " depths.append(tree.depth())\n", + " sizes.append(tree.size())\n", + " e2e_lats.append(tree.e2e_lat())\n", + " trees.append(tree)\n", + " if depths:\n", + " print(f\"Depth: min={min(depths)} avg={sum(depths) / len(depths)} max={max(depths)}\")\n", + " print(f\"Size: min={min(sizes)} avg={sum(sizes) / len(sizes)} max={max(sizes)}\")\n", + " print(f\"E2E Lat: min={min(e2e_lats)*1000:.3f}ms avg={sum(e2e_lats) / len(sizes)*1000:.3f}ms max={max(e2e_lats)*1000:.3f}ms\")\n", + "\n", + "with open(\"trees.pkl\", \"wb\") as f:\n", + " pickle.dump(trees, f)\n" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "with open(\"trees.pkl\", \"rb\") as f:\n", + " trees = pickle.load(f)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import glob\n", + "\n", + "\n", + "def parse_bytes(string):\n", + " match string[-1]:\n", + " case 'K':\n", + " exponent = 1e3\n", + " case 'M':\n", + " exponent = 1e6\n", + " case _:\n", + " exponent = 1\n", + "\n", + " num = float(string.split(\" \")[0])\n", + " return num * exponent\n", + "\n", + "\n", + "def bytes_str(bytes):\n", + " if bytes >= 1024**2:\n", + " return f\"{bytes/(1024**2):.2f} MiB\"\n", + " if bytes >= 1024:\n", + " return f\"{bytes/1024:.2f} KiB\"\n", + " return f\"{bytes:.0f} B\"\n", + "\n", + "\n", + "BW_PATH = \"../ma-hw-perf-tools/data/results\"\n", + "bw_files = glob.glob(os.path.join(BW_PATH, \"*.log\"))\n", + "msg_sizes = {}\n", + "for bw_file in bw_files:\n", + " with open(bw_file) as f:\n", + " lines = f.readlines()\n", + " topic = os.path.splitext(os.path.split(bw_file)[1])[0].replace(\"__\", \"/\")\n", + "\n", + " if not lines or re.match(f\"^\\s*$\", lines[-1]):\n", + " #print(f\"No data for {topic}\")\n", + " continue\n", + "\n", + " line_pattern = re.compile(r\"(?P[0-9.]+ [KM]?)B/s from (?P[0-9.]+) messages --- Message size mean: (?P[0-9.]+ [KM]?)B min: (?P[0-9.]+ [KM]?)B max: (?P[0-9.]+ [KM]?)B\\n\")\n", + " m = re.fullmatch(line_pattern, lines[-1])\n", + " if m is None:\n", + " print(f\"Line could not be parsed in {topic}: '{lines[-1]}'\")\n", + " continue\n", + "\n", + " msg_sizes[topic] = {'bw': parse_bytes(m.group(\"bw\")),\n", + " 'min': parse_bytes(m.group(\"min\")),\n", + " 'mean': parse_bytes(m.group(\"mean\")),\n", + " 'max': parse_bytes(m.group(\"max\"))}\n" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from typing import List\n", + "from latency_graph.message_tree import DepTree\n", + "\n", + "start_topic_filters = [\"/vehicle/status/\", \"/sensing/imu\"]\n", + "\n", + "def leaf_topics(tree: DepTree, lvl=0):\n", + " ret_list = []\n", + " match tree.head:\n", + " case TrPublishInstance(publisher=pub):\n", + " if pub:\n", + " ret_list += [(lvl, pub.topic_name)]\n", + " ret_list += [(lvl, None)]\n", + "\n", + " for dep in tree.deps:\n", + " ret_list += leaf_topics(dep, lvl+1)\n", + " return ret_list\n", + "\n", + "#all_topics = set()\n", + "\n", + "#for tree in trees:\n", + "# for d, t in leaf_topics(tree):\n", + "# if t in [\"/parameter_events\", \"/clock\"]:\n", + "# continue\n", + "# all_topics.add(t)\n", + 
"\n", + "#def critical_path(self: DepTree, start_topic_filters: List[str]):\n", + "# if not self.deps:\n", + "# return [self.head]\n", + "#\n", + "# return [self.head, *max(map(DepTree.critical_path, self.deps), key=lambda ls: ls[-1].timestamp)]\n", + "\n", + "E2E_TIME_LIMIT_S = 2\n", + "\n", + "def all_e2es(tree: DepTree, t_start=None):\n", + " if t_start is None:\n", + " t_start = tree.head.timestamp\n", + "\n", + " if not tree.deps:\n", + " return [t_start - tree.head.timestamp]\n", + "\n", + " ret_list = []\n", + " for dep in tree.deps:\n", + " ret_list += all_e2es(dep, t_start)\n", + " return ret_list\n", + "\n", + "def relevant_e2es(tree: DepTree, start_topic_filters, t_start=None, path=None):\n", + " if t_start is None:\n", + " t_start = tree.head.timestamp\n", + "\n", + " if path is None:\n", + " path = []\n", + "\n", + " latency = t_start - tree.head.timestamp\n", + " if latency > E2E_TIME_LIMIT_S:\n", + " return []\n", + "\n", + " new_path = [tree.head] + path\n", + "\n", + " if not tree.deps:\n", + " match tree.head:\n", + " case TrPublishInstance(publisher=pub):\n", + " if pub and any(f in pub.topic_name for f in start_topic_filters):\n", + " return [(latency,new_path)]\n", + "\n", + " ret_list = []\n", + " for dep in tree.deps:\n", + " ret_list += relevant_e2es(dep, start_topic_filters, t_start, new_path)\n", + " return ret_list\n", + "\n", + "\n", + "e2ess = []\n", + "e2e_pathss = []\n", + "for tree in trees:\n", + " e2es, e2e_paths = zip(*relevant_e2es(tree, start_topic_filters))\n", + " e2ess.append(e2es)\n", + " e2e_pathss.append(e2e_paths)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from matplotlib.animation import FuncAnimation\n", + "from IPython import display\n", + "\n", + "fig, ax = plt.subplots(figsize=(16, 9))\n", + "ax: plt.Axes\n", + "ax.set_xlim(0, 4)\n", + "\n", + "ax.hist([], bins=200, range=(0, 4), histtype='stepfilled')\n", + "ax.set_title(\"Time: 0.000000s\")\n", + "\n", + "def anim(frame):\n", + " print(frame, end='\\r')\n", + " ax.clear()\n", + " ax.hist(e2es[frame], bins=200, range=(0, 4), histtype='stepfilled')\n", + " ax.set_title(f\"Time: {(trees[frame].head.timestamp - trees[0].head.timestamp):.6}s\")\n", + "\n", + "\n", + "anim_created = FuncAnimation(fig, anim, min(len(trees), 10000), interval=16, repeat_delay=200)\n", + "\n", + "video = anim_created.save(\"anim.mp4\", dpi=120)\n", + "\n", + "#for tree in trees:\n", + "# path = tree.critical_path(start_topic_filters)\n", + "# for i, inst in enumerate(path[::-1]):\n", + "# match inst:\n", + "# case TrPublishInstance(publisher=pub):\n", + "# print(f\" {i:>3d}: T\", pub.topic_name)\n", + "# case TrCallbackInstance(callback_obj=cb):\n", + "# match cb.owner:\n", + "# case TrSubscriptionObject(subscription=sub):\n", + "# node = sub.node\n", + "# case TrTimer() as tmr:\n", + "# node = tmr.node\n", + "# case _:\n", + "# raise ValueError(f\"Callback owner type not recognized: {type(cb.owner).__name__}\")\n", + "#\n", + "# print(f\" {i:>3d}: N\", node.path)\n", + "# print(\"==================\")\n" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(figsize=(18, 9), num=\"e2e_plot\")\n", + "DS=1\n", + "times = [tree.head.timestamp - trees[0].head.timestamp for tree in trees[::DS]]\n", + "\n", + "ax2 = ax.twinx()\n", + 
"ax2.plot(times, list(map(lambda paths: sum(map(len, paths)) / len(paths), e2e_pathss[::DS])), color=\"orange\")\n", + "ax2.fill_between(times,\n", + " list(map(lambda paths: min(map(len, paths)), e2e_pathss[::DS])),\n", + " list(map(lambda paths: max(map(len, paths)), e2e_pathss[::DS])),\n", + " alpha=.3, color=\"orange\")\n", + "\n", + "ax.plot(times, [np.mean(e2es) for e2es in e2ess[::DS]])\n", + "ax.fill_between(times, [np.min(e2es) for e2es in e2ess[::DS]], [np.max(e2es) for e2es in e2ess[::DS]], alpha=.3)\n", + "\n", + "def scatter_topic(topic_name, y=0, **scatter_kwargs):\n", + " for pub in topics.by_name[topic_name].publishers:\n", + " if not pub:\n", + " continue\n", + "\n", + " inst_timestamps = [inst.timestamp - trees[0].head.timestamp for inst in pub.instances if inst.timestamp >= trees[0].head.timestamp]\n", + " scatter_kwargs_default = {\"marker\": \"x\", \"color\": \"indianred\"}\n", + " scatter_kwargs_default.update(scatter_kwargs)\n", + " ax.scatter(inst_timestamps, np.full(len(inst_timestamps), fill_value=y), **scatter_kwargs_default)\n", + "\n", + "scatter_topic(\"/autoware/engage\")\n", + "scatter_topic(\"/planning/scenario_planning/parking/trajectory\", y=-.04, color=\"cadetblue\")\n", + "scatter_topic(\"/planning/scenario_planning/lane_driving/trajectory\", y=-.08, color=\"darkgreen\")\n", + "scatter_topic(\"/initialpose2d\", y=-.12, color=\"orange\")\n", + "\n", + "ax.set_xlabel(\"Simulation time [s]\")\n", + "ax.set_ylabel(\"End-to-End latency [s]\")\n", + "ax2.set_ylabel(\"End-to-End path length\")\n", + "None" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "def critical_path(self):\n", + " return [self.head, *min(map(critical_path, self.deps), key=lambda ls: ls[-1].timestamp)]\n", + "\n", + "\n", + "def e2e_lat(self):\n", + " return self.head.timestamp - critical_path(self)[-1].timestamp\n", + "\n", + "\n", + "def get_relevant_tree(tree: DepTree, accept_leaf=lambda x: True, root=None):\n", + " if root is None:\n", + " root = tree.head\n", + " if not tree.deps:\n", + " if accept_leaf(tree.head, root):\n", + " return tree\n", + " return None\n", + "\n", + " relevant_deps = [get_relevant_tree(dep, accept_leaf, root) for dep in tree.deps]\n", + " if not any(relevant_deps):\n", + " return None\n", + "\n", + " return DepTree(tree.head, [dep for dep in relevant_deps if dep])\n", + "\n", + "\n", + "def fanout(self):\n", + " if not self.deps:\n", + " return 1\n", + "\n", + " return sum(map(fanout, self.deps))\n", + "\n", + "\n", + "def sort_subtree(subtree: DepTree, sort_func=lambda t: t.head.timestamp - e2e_lat(t)):\n", + " subtree.deps.sort(key=sort_func)\n", + " for dep in subtree.deps:\n", + " sort_subtree(dep, sort_func)\n", + " return subtree\n", + "\n", + "def _leaf_filter(inst, root):\n", + " if root.timestamp - inst.timestamp > E2E_TIME_LIMIT_S:\n", + " return False\n", + "\n", + " match inst:\n", + " case TrPublishInstance(publisher=pub):\n", + " return pub and any(f in pub.topic_name for f in start_topic_filters)\n", + " return False\n", + "\n", + "\n", + "relevant_trees = [get_relevant_tree(tree, _leaf_filter) for tree in trees]" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "from cycler import cycler\n", + "\n", + "def dict_safe_append(dictionary, key, value):\n", + " if key not in dictionary:\n", + " 
dictionary[key] = []\n",
+    "    dictionary[key].append(value)\n",
+    "\n",
+    "\n",
+    "fig, (ax, ax_rel) = plt.subplots(2, 1, sharex=True, figsize=(60, 30), num=\"crit_plot\")\n",
+    "ax.set_prop_cycle(cycler('color', [plt.cm.nipy_spectral(i/4) for i in range(5)]))\n",
+    "ax_rel.set_prop_cycle(cycler('color', [plt.cm.nipy_spectral(i/4) for i in range(5)]))\n",
+    "\n",
+    "critical_paths = [critical_path(tree) for tree in relevant_trees[::DS]]\n",
+    "\n",
+    "time_breakdown = {}\n",
+    "\n",
+    "for path in critical_paths:\n",
+    "    tmr_cb_calc_time = 0.0\n",
+    "    sub_cb_calc_time = 0.0\n",
+    "    tmr_cb_relevant_time = 0.0\n",
+    "    sub_cb_relevant_time = 0.0\n",
+    "    dds_time = 0.0\n",
+    "    idle_time = 0.0\n",
+    "\n",
+    "    last_pub_time = None\n",
+    "    last_cb_time = None\n",
+    "    for inst in path:\n",
+    "        match inst:\n",
+    "            case TrPublishInstance(timestamp=t):\n",
+    "                assert last_pub_time is None, \"Two publications without a callback in between\"\n",
+    "\n",
+    "                if last_cb_time is not None:\n",
+    "                    dds_time += last_cb_time - t\n",
+    "\n",
+    "                last_pub_time = t\n",
+    "                last_cb_time = None\n",
+    "            case TrCallbackInstance(callback_obj=cb, timestamp=t, duration=d):\n",
+    "                if last_pub_time is not None:\n",
+    "                    assert last_pub_time <= t+d, \"Publication out of CB instance timeframe\"\n",
+    "\n",
+    "                    match cb.owner:\n",
+    "                        case TrTimer():\n",
+    "                            tmr_cb_calc_time += d\n",
+    "                            tmr_cb_relevant_time += last_pub_time - t\n",
+    "                        case TrSubscriptionObject():\n",
+    "                            sub_cb_calc_time += d\n",
+    "                            sub_cb_relevant_time += last_pub_time - t\n",
+    "                elif last_cb_time is not None:\n",
+    "                    idle_time += last_cb_time - (t + d)\n",
+    "                last_pub_time = None\n",
+    "                last_cb_time = t\n",
+    "\n",
+    "    #dict_safe_append(time_breakdown, \"tmr_cb_calc_time\", tmr_cb_calc_time)\n",
+    "    #dict_safe_append(time_breakdown, \"sub_cb_calc_time\", sub_cb_calc_time)\n",
+    "    dict_safe_append(time_breakdown, \"Timer CB\", tmr_cb_relevant_time)\n",
+    "    dict_safe_append(time_breakdown, \"Subscription CB\", sub_cb_relevant_time)\n",
+    "    dict_safe_append(time_breakdown, \"DDS\", dds_time)\n",
+    "    dict_safe_append(time_breakdown, \"Idle\", idle_time)\n",
+    "\n",
+    "time_breakdown = {k: np.array(v) for k, v in time_breakdown.items()}\n",
+    "\n",
+    "timer_cb_times = [sum(inst.duration for inst in path if isinstance(inst, TrCallbackInstance) and inst.callback_obj and isinstance(inst.callback_obj.owner, TrTimer)) for path in critical_paths]\n",
+    "sub_cb_times = [sum(inst.duration for inst in path if isinstance(inst, TrCallbackInstance) and inst.callback_obj and isinstance(inst.callback_obj.owner, TrSubscriptionObject)) for path in critical_paths]\n",
+    "\n",
+    "labels, values = list(zip(*time_breakdown.items()))\n",
+    "\n",
+    "#ax.plot(range(len(relevant_trees[::DS])), [e2e_lat(tree) for tree in relevant_trees[::DS]], label=\"Total E2E\")\n",
+    "ax.stackplot(range(len(relevant_trees[::DS])), values, labels=labels)\n",
+    "ax.legend()\n",
+    "ax.set_title(\"End-to-End Latency Breakdown\")\n",
+    "ax.set_ylabel(\"End-to-End Latency [s]\")\n",
+    "\n",
+    "timestep_mags = np.array([sum(vs) for vs in zip(*values)])\n",
+    "ax_rel.stackplot(range(len(relevant_trees[::DS])), [val / timestep_mags for val in values], labels=labels)\n",
+    "ax_rel.set_title(\"End-to-End Latency Breakdown (relative)\")\n",
+    "ax_rel.set_ylabel(\"End-to-End Latency Fraction\")\n",
+    "ax_rel.set_xlabel(\"Timestep\")\n",
+    "ax_rel.legend()\n",
+    "\n",
+    "None"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "from scipy import stats\n",
+    "\n",
+    "fig, ax = plt.subplots(figsize=(60, 15), num=\"crit_pdf\")\n",
+    "ax.set_prop_cycle(cycler('color', [plt.cm.nipy_spectral(i/4) for i in range(5)]))\n",
+    "\n",
+    "kde = stats.gaussian_kde(timestep_mags)\n",
+    "xs = np.linspace(timestep_mags.min(), timestep_mags.max(), 1000)\n",
+    "ax.plot(xs, kde(xs), label=\"End-to-End Latency\")\n",
+    "perc = 90\n",
+    "ax.axvline(np.percentile(timestep_mags, perc), label=f\"{perc}th percentile\")\n",
+    "\n",
+    "ax2 = ax.twinx()\n",
+    "ax2.hist(timestep_mags, 200)\n",
+    "ax2.set_ylim(0, ax2.get_ylim()[1])\n",
+    "\n",
+    "ax.set_title(\"Time Distribution for E2E Breakdown\")\n",
+    "ax.set_xlabel(\"Time [s]\")\n",
+    "ax.set_ylabel(\"Density\")\n",
+    "ax.set_xlim(0, 2.01)\n",
+    "ax.set_ylim(0, ax.get_ylim()[1])\n",
+    "ax.legend()\n",
+    "\n",
+    "None"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "from tracing_interop.tr_types import TrSubscription\n",
+    "from matching.subscriptions import sanitize\n",
+    "import matplotlib.patches as mpatch\n",
+    "from matplotlib.text import Text\n",
+    "import math\n",
+    "\n",
+    "i = 3900\n",
+    "tree = trees[i]\n",
+    "e2es = e2ess[i]\n",
+    "e2e_paths = e2e_pathss[i]\n",
+    "margin_y = .2\n",
+    "margin_x = 0\n",
+    "arr_width = 1 - margin_y\n",
+    "\n",
+    "def cb_str(inst: TrCallbackInstance):\n",
+    "    cb: TrCallbackObject = inst.callback_obj\n",
+    "    if not cb:\n",
+    "        return None\n",
+    "    ret_str = f\"- {inst.duration*1e3:07.3f}ms \"\n",
+    "    #if cb.callback_symbol:\n",
+    "    #    ret_str = repr(sanitize(cb.callback_symbol.symbol))\n",
+    "\n",
+    "    match cb.owner:\n",
+    "        case TrSubscriptionObject(subscription=sub):\n",
+    "            sub: TrSubscription\n",
+    "            ret_str = f\"{ret_str}{sub.node.path if sub.node else None} <- {sub.topic_name}\"\n",
+    "        case TrTimer(period=p, node=node):\n",
+    "            p: int\n",
+    "            node: TrNode\n",
+    "            ret_str = f\"{ret_str}{node.path if node else None} <- @{1/(p*1e-9):.2f}Hz\"\n",
+    "    return ret_str\n",
+    "\n",
+    "\n",
+    "def ts_str(inst, prev):\n",
+    "    return f\"{(inst.timestamp - prev)*1e3:+09.3f}ms\"\n",
+    "\n",
+    "def bw_str(bw):\n",
+    "    return bytes_str(bw['mean'])\n",
+    "\n",
+    "def trunc_chars(string, w, e2e):\n",
+    "    if w < e2e * .005:\n",
+    "        return \"\"\n",
+    "    n_chars = max(math.floor(w / (e2e * .17) * 65) - 5, 0)\n",
+    "    if n_chars < 4:\n",
+    "        return \"\"\n",
+    "\n",
+    "    return \"...\" + string[-n_chars:] if n_chars < len(string) else string\n",
+    "\n",
+    "#e2e_paths = sorted(e2e_paths, key=lambda path: path[-1].timestamp - path[0].timestamp, reverse=True)\n",
+    "#for y, e2e_path in reversed(list(enumerate(e2e_paths))):\n",
+    "#    last_pub_ts = None\n",
+    "#    last_cb_end = None\n",
+    "#    print(f\"=== {y}:\")\n",
+    "#    for inst in e2e_path:\n",
+    "\n",
+    "tree = sort_subtree(get_relevant_tree(tree, _leaf_filter))\n",
+    "\n",
+    "t_start = tree.head.timestamp\n",
+    "t_min = t_start - e2e_lat(tree)\n",
+    "t_e2e = t_start - t_min\n",
+    "\n",
+    "legend_entries = {}\n",
+    "\n",
+    "def plot_subtree(subtree: DepTree, ax: plt.Axes, y_labels, y=0, next_cb_start=0):\n",
+    "    height = fanout(subtree)\n",
+    "    inst = subtree.head\n",
+    "\n",
+    "    match inst:\n",
+    "        case TrCallbackInstance(timestamp=t_cb, duration=d_cb):\n",
+    "            is_sub = isinstance(inst.callback_obj.owner, TrSubscriptionObject)\n",
+    "\n",
+    "            r_x = t_cb - t_start + margin_x / 2\n",
+    "            r_y = y + margin_y / 2\n",
+    "            r_w = max(d_cb - margin_x, 0)\n",
+    "            r_h = height - margin_y\n",
+    "\n",
+    "            r = mpatch.Rectangle((r_x, r_y), r_w, r_h,\n",
+    "                                 ec=\"cadetblue\" 
if is_sub else \"indianred\", fc=\"lightblue\" if is_sub else \"lightcoral\", zorder=9000)\n", + " ax.add_artist(r)\n", + "\n", + " text = repr(sanitize(inst.callback_obj.callback_symbol.symbol)) if inst.callback_obj and inst.callback_obj.callback_symbol else \"??\"\n", + " text = trunc_chars(text, r_w, t_e2e)\n", + " if text:\n", + " ax.text(r_x + r_w / 2, r_y + r_h / 2, text, ha=\"center\", va=\"center\", backgroundcolor=(1,1,1,.5), zorder=11000)\n", + "\n", + " if is_sub and \"Subscription CB\" not in legend_entries:\n", + " legend_entries[\"Subscription CB\"] = r\n", + " elif not is_sub and \"Timer CB\" not in legend_entries:\n", + " legend_entries[\"Timer CB\"] = r\n", + "\n", + " if next_cb_start is not None:\n", + " r_x = t_cb - t_start + d_cb - margin_x / 2\n", + " r_y = y + .5 - arr_width/2\n", + " r_w = next_cb_start - (t_cb + d_cb) + margin_x\n", + " r_h = arr_width\n", + " r = mpatch.Rectangle((r_x, r_y), r_w, r_h, color=\"orange\")\n", + " ax.add_artist(r)\n", + "\n", + " if is_sub:\n", + " node = inst.callback_obj.owner.subscription.node\n", + " else:\n", + " node = inst.callback_obj.owner.node\n", + " text = node.path\n", + "\n", + " text = trunc_chars(text, r_w, t_e2e)\n", + " if text:\n", + " ax.text(r_x + r_w / 2, r_y + r_h / 2, text, ha=\"center\", va=\"center\", backgroundcolor=(1,1,1,.5), zorder=11000)\n", + "\n", + " if \"Idle\" not in legend_entries:\n", + " legend_entries[\"Idle\"] = r\n", + "\n", + " next_cb_start = t_cb\n", + " case TrPublishInstance(timestamp=t_pub, publisher=pub):\n", + " if not subtree.deps:\n", + " y_labels.append(pub.topic_name if pub else None)\n", + "\n", + " scatter = ax.scatter(t_pub - t_start, y+.5, color=\"cyan\", marker=\".\", zorder=10000)\n", + "\n", + " if \"Publication\" not in legend_entries:\n", + " legend_entries[\"Publication\"] = scatter\n", + "\n", + " if next_cb_start is not None:\n", + " r_x = t_pub - t_start\n", + " r_y = y + .5 - arr_width/2\n", + " r_w = max(next_cb_start - t_pub + margin_x / 2, 0)\n", + " r = mpatch.Rectangle((r_x, r_y), r_w, arr_width, color=\"lightgreen\")\n", + " ax.add_artist(r)\n", + " if pub:\n", + " text = pub.topic_name\n", + " text = trunc_chars(text, r_w, t_e2e)\n", + " if text:\n", + " ax.text(r_x + r_w / 2, r_y + arr_width / 2, text, ha=\"center\", va=\"center\", backgroundcolor=(1,1,1,.5), zorder=11000)\n", + "\n", + " if \"DDS\" not in legend_entries:\n", + " legend_entries[\"DDS\"] = r\n", + "\n", + " topic_stats = msg_sizes.get(pub.topic_name)\n", + " if topic_stats:\n", + " size_str = bw_str(topic_stats)\n", + " ax.text(r_x + r_w / 2, r_y + arr_width + margin_y, size_str, ha=\"center\", backgroundcolor=(1,1,1,.5), zorder=11000)\n", + " else:\n", + " print(\"[WARN] Tried to publish to another PublishInstance\")\n", + " next_cb_start = None\n", + "\n", + " acc_fanout = 0\n", + " for dep in subtree.deps:\n", + " acc_fanout += plot_subtree(dep, ax, y_labels, y + acc_fanout, next_cb_start)\n", + " return height\n", + "\n", + "\n", + "\n", + "fig, ax = plt.subplots(figsize=(36, 20), num=\"path_viz\")\n", + "ax.set_ylim(0, len(e2es))\n", + "\n", + "y_labels = []\n", + "plot_subtree(tree, ax, y_labels)\n", + "\n", + "tree_e2e = e2e_lat(tree)\n", + "plot_margin_x = .01 * tree_e2e\n", + "ax.set_xlim(-tree_e2e - plot_margin_x, plot_margin_x)\n", + "ax.set_yticks(np.array(range(len(y_labels))) + .5, y_labels)\n", + "ax.set_title(f\"Timestep {i}: {(tree.head.timestamp - trees[0].head.timestamp):10.6f}s\")\n", + "ax.set_xlabel(\"Time relative to output message [s]\")\n", + "ax.set_ylabel(\"Start 
topic\")\n", + "\n", + "labels, handles = list(zip(*legend_entries.items()))\n", + "ax.legend(handles, labels)\n", + "print(len(y_labels))" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, { "cell_type": "code", "execution_count": null, diff --git a/tracing_interop/tr_types.py b/tracing_interop/tr_types.py index e14d3d6..0f541d8 100644 --- a/tracing_interop/tr_types.py +++ b/tracing_interop/tr_types.py @@ -1,100 +1,108 @@ +from collections import namedtuple, UserList from dataclasses import dataclass, field -from functools import cached_property -from typing import List, Dict - -import pandas as pd -from tqdm.notebook import tqdm +from typing import List, Dict, Optional, Set, TypeVar, Generic, Iterable +import bisect from tracetools_analysis.processor.ros2 import Ros2Handler -from tracetools_analysis.utils.ros2 import Ros2DataModelUtil -from .utils import list_to_dict, df_to_type_list +from .utils import df_to_type_list + +IdxItemType = TypeVar("IdxItemType") +Timestamp = namedtuple("Timestamp", ["timestamp"]) + + +class Index(Generic[IdxItemType]): + def __init__(self, items: Iterable[IdxItemType], **idx_fields): + sort_key = lambda item: item.timestamp + + self.__items = list(items) + self.__items.sort(key=sort_key) + self.__indices = {} + + for idx_name, is_multi in idx_fields.items(): + index = {} + self.__indices[idx_name] = index + + if is_multi: + for item in self.__items: + key = getattr(item, idx_name) + if key not in index: + index[key] = [] + index[key].append(item) # Also sorted since items are processed in order and only filtered here + else: + for item in self.__items: + key = getattr(item, idx_name) + if key in index: + print(repr(ValueError(f"Duplicate key: {idx_name}={key}; old={index[key]}; new={item}"))) + index[key] = item + + def __iter__(self): + return iter(self.__items) + + def __len__(self): + return len(self.__items) + + def __getattr__(self, item: str): + if not item.startswith("by_"): + return AttributeError( + f"Not found in index: '{item}'. 
Index lookups must be of the shape 'by_'.") + + return self.__indices[item.removeprefix("by_")] + + def __getstate__(self): + return vars(self) + + def __setstate__(self, state): + vars(self).update(state) @dataclass class TrContext: - nodes: Dict[int, 'TrNode'] - publishers: Dict[int, 'TrPublisher'] - subscriptions: Dict[int, 'TrSubscription'] - timers: Dict[int, 'TrTimer'] - timer_node_links: Dict[int, 'TrTimerNodeLink'] - subscription_objects: Dict[int, 'TrSubscriptionObject'] - callback_objects: Dict[int, 'TrCallbackObject'] - callback_symbols: Dict[int, 'TrCallbackSymbol'] - publish_instances: List['TrPublishInstance'] - callback_instances: List['TrCallbackInstance'] - topics: Dict[str, 'TrTopic'] - - util: Ros2DataModelUtil | None - handler: Ros2Handler | None - - def __init__(self, util: Ros2DataModelUtil, handler: Ros2Handler): - self.util = util - self.handler = handler + nodes: Index['TrNode'] + publishers: Index['TrPublisher'] + subscriptions: Index['TrSubscription'] + timers: Index['TrTimer'] + timer_node_links: Index['TrTimerNodeLink'] + subscription_objects: Index['TrSubscriptionObject'] + callback_objects: Index['TrCallbackObject'] + callback_symbols: Index['TrCallbackSymbol'] + publish_instances: Index['TrPublishInstance'] + callback_instances: Index['TrCallbackInstance'] + topics: Index['TrTopic'] + def __init__(self, handler: Ros2Handler): print("[TrContext] Processing ROS 2 objects from traces...") - self.nodes = list_to_dict(df_to_type_list(handler.data.nodes, TrNode, _c=self)) - print(f" ├─ Processed {len(self.nodes):<8d} nodes") - self.publishers = list_to_dict(df_to_type_list(handler.data.rcl_publishers, TrPublisher, _c=self)) - print(f" ├─ Processed {len(self.publishers):<8d} publishers") - self.subscriptions = list_to_dict(df_to_type_list(handler.data.rcl_subscriptions, TrSubscription, _c=self)) - print(f" ├─ Processed {len(self.subscriptions):<8d} subscriptions") - self.timers = list_to_dict(df_to_type_list(handler.data.timers, TrTimer, _c=self)) - print(f" ├─ Processed {len(self.timers):<8d} timers") - self.timer_node_links = list_to_dict(df_to_type_list(handler.data.timer_node_links, TrTimerNodeLink)) - print(f" ├─ Processed {len(self.timer_node_links):<8d} timer-node links") - self.subscription_objects = list_to_dict( - df_to_type_list(handler.data.subscription_objects, TrSubscriptionObject, _c=self)) - print(f" ├─ Processed {len(self.subscription_objects):<8d} subscription objects") - self.callback_objects = list_to_dict(df_to_type_list(handler.data.callback_objects, TrCallbackObject, _c=self)) - print(f" ├─ Processed {len(self.callback_objects):<8d} callback objects") - self.callback_symbols = list_to_dict(df_to_type_list(handler.data.callback_symbols, TrCallbackSymbol, _c=self)) - print(f" ├─ Processed {len(self.callback_symbols):<8d} callback symbols") - self.publish_instances = df_to_type_list(handler.data.rcl_publish_instances, TrPublishInstance, _c=self) - print(f" ├─ Processed {len(self.publish_instances):<8d} publish instances") - self.callback_instances = df_to_type_list(handler.data.callback_instances, TrCallbackInstance, _c=self) - print(f" ├─ Processed {len(self.callback_instances):<8d} callback instances") + self.nodes = Index(df_to_type_list(handler.data.nodes, TrNode, _c=self), + id=False) + self.publishers = Index(df_to_type_list(handler.data.rcl_publishers, TrPublisher, _c=self), + id=False, node_handle=True, topic_name=True) + self.subscriptions = Index(df_to_type_list(handler.data.rcl_subscriptions, TrSubscription, _c=self), + id=False, 
node_handle=True, topic_name=True) + self.timers = Index(df_to_type_list(handler.data.timers, TrTimer, _c=self), + id=False) + self.timer_node_links = Index(df_to_type_list(handler.data.timer_node_links, TrTimerNodeLink), + id=False, node_handle=True) + self.subscription_objects = Index( + df_to_type_list(handler.data.subscription_objects, TrSubscriptionObject, _c=self), + id=False, subscription_handle=False) + self.callback_objects = Index(df_to_type_list(handler.data.callback_objects, TrCallbackObject, _c=self), + id=False, callback_object=False) + self.callback_symbols = Index(df_to_type_list(handler.data.callback_symbols, TrCallbackSymbol, _c=self), + id=False) + self.publish_instances = Index(df_to_type_list(handler.data.rcl_publish_instances, TrPublishInstance, _c=self, + mappers={"timestamp": lambda t: t * 1e-9}), + publisher_handle=True) + self.callback_instances = Index(df_to_type_list(handler.data.callback_instances, TrCallbackInstance, _c=self, + mappers={"timestamp": lambda t: t.timestamp(), + "duration": lambda d: d.total_seconds()}), + callback_object=True) - _unique_topic_names = {*(pub.topic_name for pub in self.publishers.values()), - *(sub.topic_name for sub in self.subscriptions.values())} - self.topics = list_to_dict(map(lambda name: TrTopic(name=name, _c=self), _unique_topic_names), key="name") - print(f" └─ Processed {len(self.topics):<8d} topics\n") + _unique_topic_names = {*(pub.topic_name for pub in self.publishers), + *(sub.topic_name for sub in self.subscriptions)} - print("[TrContext] Caching dynamic properties...") - - p = tqdm(desc=" ├─ Processing nodes", total=len(self.nodes.values())) - [(o.path, o.publishers, o.subscriptions, o.timers, p.update()) for o in self.nodes.values()] - print(" ├─ Cached node properties") - p = tqdm(desc=" ├─ Processing publishers", total=len(self.publishers.values())) - [(o.instances, o.subscriptions, p.update()) for o in self.publishers.values()] - print(" ├─ Cached publisher properties") - p = tqdm(desc=" ├─ Processing subscriptions", total=len(self.subscriptions.values())) - [(o.publishers, o.subscription_objects, p.update()) for o in self.subscriptions.values()] - print(" ├─ Cached subscription properties") - p = tqdm(desc=" ├─ Processing timers", total=len(self.timers.values())) - [(o.nodes, p.update()) for o in self.timers.values()] - print(" ├─ Cached timer properties") - p = tqdm(desc=" ├─ Processing CB objects", total=len(self.callback_objects.values())) - [(o.callback_instances, o.owner, p.update()) for o in self.callback_objects.values()] - print(" ├─ Cached callback object properties") - p = tqdm(desc=" ├─ Processing CB symbols", total=len(self.callback_symbols.values())) - [(o.callback_objs, p.update()) for o in self.callback_symbols.values()] - print(" ├─ Cached callback symbol properties") - p = tqdm(desc=" ├─ Processing topics", total=len(self.topics.values())) - [(o.publishers, o.subscriptions, p.update()) for o in self.topics.values()] - print(" └─ Cached topic properties\n") - - def __getstate__(self): - state = self.__dict__.copy() - del state["util"] - del state["handler"] - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self.util = None - self.handler = None + self.topics = Index((TrTopic(name=name, _c=self) for name in _unique_topic_names), + name=False) def __repr__(self): return f"TrContext" @@ -110,26 +118,30 @@ class TrNode: namespace: str _c: TrContext = field(repr=False) - @cached_property + @property def path(self) -> str: return '/'.join((self.namespace, 
self.name)).replace('//', '/') - @cached_property + @property def publishers(self) -> List['TrPublisher']: - return list(filter(lambda pub: pub.node_handle == self.id, self._c.publishers.values())) + return self._c.publishers.by_node_handle.get(self.id) or [] - @cached_property + @property def subscriptions(self) -> List['TrSubscription']: - return list(filter(lambda sub: sub.node_handle == self.id, self._c.subscriptions.values())) + return self._c.subscriptions.by_node_handle.get(self.id) or [] - @cached_property + @property def timers(self) -> List['TrTimer']: - links = [link.id for link in self._c.timer_node_links.values() if link.node_handle == self.id] - return list(filter(lambda timer: timer.id in links, self._c.timers.values())) + links = self._c.timer_node_links.by_node_handle.get(self.id) or [] + timers = [self._c.timers.by_id.get(link.id) for link in links] + return [t for t in timers if t is not None] def __hash__(self): return hash(self.id) + def __eq__(self, other): + return self.__hash__() == other.__hash__() + @dataclass class TrPublisher: @@ -142,24 +154,27 @@ class TrPublisher: _c: TrContext = field(repr=False) @property - def node(self) -> 'TrNode': - return self._c.nodes[self.node_handle] - - @cached_property - def subscriptions(self) -> List['TrSubscription']: - return list(filter(lambda sub: sub.topic_name == self.topic_name, self._c.subscriptions.values())) - - @cached_property - def instances(self) -> List['TrPublishInstance']: - return list(filter(lambda inst: inst.publisher_handle == self.id, self._c.publish_instances)) + def node(self) -> Optional['TrNode']: + return self._c.nodes.by_id.get(self.node_handle) @property - def topic(self) -> 'TrTopic': - return self._c.topics[self.topic_name] + def subscriptions(self) -> List['TrSubscription']: + return self._c.subscriptions.by_topic_name.get(self.topic_name) or [] + + @property + def instances(self) -> List['TrPublishInstance']: + return self._c.publish_instances.by_publisher_handle.get(self.id) or [] + + @property + def topic(self) -> Optional['TrTopic']: + return self._c.topics.by_name.get(self.topic_name) def __hash__(self): return hash(self.id) + def __eq__(self, other): + return self.__hash__() == other.__hash__() + @dataclass class TrSubscription: @@ -172,25 +187,27 @@ class TrSubscription: _c: TrContext = field(repr=False) @property - def node(self) -> 'TrNode': - return self._c.nodes[self.node_handle] - - @cached_property - def publishers(self) -> List['TrPublisher']: - return list(filter(lambda pub: pub.topic_name == self.topic_name, self._c.publishers.values())) - - @cached_property - def subscription_objects(self) -> List['TrSubscriptionObject']: - return list( - filter(lambda sub_obj: sub_obj.subscription_handle == self.id, self._c.subscription_objects.values())) + def node(self) -> Optional['TrNode']: + return self._c.nodes.by_id.get(self.node_handle) @property - def topic(self) -> 'TrTopic': - return self._c.topics[self.topic_name] + def publishers(self) -> List['TrPublisher']: + return self._c.publishers.by_topic_name.get(self.topic_name) or [] + + @property + def subscription_object(self) -> Optional['TrSubscriptionObject']: + return self._c.subscription_objects.by_subscription_handle.get(self.id) + + @property + def topic(self) -> Optional['TrTopic']: + return self._c.topics.by_name.get(self.topic_name) def __hash__(self): return hash(self.id) + def __eq__(self, other): + return self.__hash__() == other.__hash__() + @dataclass class TrTimer: @@ -200,18 +217,23 @@ class TrTimer: tid: int _c: 
TrContext = field(repr=False) - @cached_property - def nodes(self) -> List['TrNode']: - links = [link.node_handle for link in self._c.timer_node_links.values() if link.id == self.id] - return list(filter(lambda node: node.id in links, self._c.nodes.values())) + @property + def node(self) -> Optional['TrNode']: + link = self._c.timer_node_links.by_id.get(self.id) + if link is None: + return None + return self._c.nodes.by_id.get(link.node_handle) @property - def callback_object(self) -> 'TrCallbackObject': - return self._c.callback_objects[self.id] + def callback_object(self) -> Optional['TrCallbackObject']: + return self._c.callback_objects.by_id.get(self.id) def __hash__(self): return hash(self.id) + def __eq__(self, other): + return self.__hash__() == other.__hash__() + @dataclass class TrTimerNodeLink: @@ -219,25 +241,34 @@ class TrTimerNodeLink: timestamp: int node_handle: int + def __hash__(self): + return hash((self.id, self.node_handle)) + + def __eq__(self, other): + return self.__hash__() == other.__hash__() + @dataclass class TrSubscriptionObject: - id: int # subscription + id: int timestamp: int subscription_handle: int _c: TrContext = field(repr=False) @property - def subscription(self) -> 'TrSubscription': - return self._c.subscriptions[self.subscription_handle] + def subscription(self) -> Optional['TrSubscription']: + return self._c.subscriptions.by_id.get(self.subscription_handle) @property - def callback_object(self) -> 'TrCallbackObject': - return self._c.callback_objects[self.id] + def callback_object(self) -> Optional['TrCallbackObject']: + return self._c.callback_objects.by_id.get(self.id) def __hash__(self): return hash((self.id, self.timestamp, self.subscription_handle)) + def __eq__(self, other): + return self.__hash__() == other.__hash__() + @dataclass class TrCallbackObject: @@ -246,62 +277,67 @@ class TrCallbackObject: callback_object: int _c: TrContext = field(repr=False) - @cached_property + @property def callback_instances(self) -> List['TrCallbackInstance']: - return list(filter(lambda inst: inst.callback_object == self.callback_object, self._c.callback_instances)) + return self._c.callback_instances.by_callback_object.get(self.callback_object) or [] @property - def callback_symbol(self) -> 'TrCallbackSymbol': - return self._c.callback_symbols[self.id] + def callback_symbol(self) -> Optional['TrCallbackSymbol']: + return self._c.callback_symbols.by_id.get(self.callback_object) - @cached_property + @property def owner(self): - if self.id in self._c.timers: - return self._c.timers[self.id] - if self.id in self._c.publishers: - return self._c.publishers[self.id] - if self.id in self._c.subscription_objects: - return self._c.subscription_objects[self.id] - if self.id in self._c.handler.data.services.index: - return 'Service' - if self.id in self._c.handler.data.clients.index: - return 'Client' + if self.id in self._c.timers.by_id: + return self._c.timers.by_id[self.id] + if self.id in self._c.publishers.by_id: + return self._c.publishers.by_id[self.id] + if self.id in self._c.subscription_objects.by_id: + return self._c.subscription_objects.by_id[self.id] return None def __hash__(self): return hash((self.id, self.timestamp, self.callback_object)) + def __eq__(self, other): + return self.__hash__() == other.__hash__() + @dataclass class TrPublishInstance: publisher_handle: int - timestamp: int + timestamp: float message: int _c: TrContext = field(repr=False) @property - def publisher(self) -> 'TrPublisher': - return self._c.publishers[self.publisher_handle] + def 
publisher(self) -> Optional['TrPublisher']:
+        return self._c.publishers.by_id.get(self.publisher_handle)
 
     def __hash__(self):
         return hash((self.publisher_handle, self.timestamp, self.message))
 
+    def __eq__(self, other):
+        return self.__hash__() == other.__hash__()
+
 
 @dataclass
 class TrCallbackInstance:
     callback_object: int
-    timestamp: pd.Timestamp
-    duration: pd.Timedelta
+    timestamp: float
+    duration: float
     intra_process: bool
     _c: TrContext = field(repr=False)
 
     @property
-    def callback_obj(self) -> 'TrCallbackObject':
-        return self._c.callback_objects[self.callback_object]
+    def callback_obj(self) -> Optional['TrCallbackObject']:
+        return self._c.callback_objects.by_callback_object.get(self.callback_object)
 
     def __hash__(self):
         return hash((self.callback_object, self.timestamp, self.duration))
 
+    def __eq__(self, other):
+        return self.__hash__() == other.__hash__()
+
 
 @dataclass
 class TrCallbackSymbol:
@@ -310,13 +346,16 @@ class TrCallbackSymbol:
     symbol: str
     _c: TrContext = field(repr=False)
 
-    @cached_property
-    def callback_objs(self) -> List['TrCallbackObject']:
-        return list(filter(lambda cb_obj: cb_obj.callback_object == self.id, self._c.callback_objects.values()))
+    @property
+    def callback_obj(self) -> Optional['TrCallbackObject']:
+        return self._c.callback_objects.by_callback_object.get(self.id)
 
     def __hash__(self):
         return hash((self.id, self.timestamp, self.symbol))
 
+    def __eq__(self, other):
+        return self.__hash__() == other.__hash__()
+
 
 #######################################
 # Self-defined (not from ROS2DataModel)
@@ -326,14 +365,18 @@ class TrCallbackSymbol:
 class TrTopic:
     name: str
     _c: TrContext = field(repr=False)
+    timestamp: int = 0
 
-    @cached_property
+    @property
     def publishers(self) -> List['TrPublisher']:
-        return list(filter(lambda pub: pub.topic_name == self.name, self._c.publishers.values()))
+        return self._c.publishers.by_topic_name.get(self.name) or []
 
-    @cached_property
+    @property
     def subscriptions(self) -> List['TrSubscription']:
-        return list(filter(lambda sub: sub.topic_name == self.name, self._c.subscriptions.values()))
+        return self._c.subscriptions.by_topic_name.get(self.name) or []
 
     def __hash__(self):
         return hash(self.name)
+
+    def __eq__(self, other):
+        return self.__hash__() == other.__hash__()
diff --git a/tracing_interop/utils.py b/tracing_interop/utils.py
index 7c91a4b..7249e49 100644
--- a/tracing_interop/utils.py
+++ b/tracing_interop/utils.py
@@ -8,27 +8,18 @@ def row_to_type(row, type, **type_kwargs):
     return type(**row, **type_kwargs)
 
 
-def df_to_type_list(df, type, **type_kwargs):
+def df_to_type_list(df, type, mappers=None, **type_kwargs):
+    # Apply column converters up front; note that this mutates the caller's df.
+    if mappers is not None:
+        for col, mapper in mappers.items():
+            df[col] = df[col].map(mapper)
+
     has_idx = not isinstance(df.index, pd.RangeIndex)
     ret_list = []
-    p = tqdm(desc=" ├─ Processing", total=len(df))
-    for row in df.itertuples(index=has_idx):
-        p.update()
+    for row in tqdm(df.itertuples(index=has_idx), desc=f" ├─ Processing {type.__name__}s", total=len(df)):
         row_dict = row._asdict()
         if has_idx:
             row_dict["id"] = row.Index
             del row_dict["Index"]
         ret_list.append(row_to_type(row_dict, type, **type_kwargs))
     return ret_list
-
-
-def by_index(df, index, type):
-    return df_to_type_list(df.loc[index], type)
-
-
-def by_column(df, column_name, column_val, type):
-    return df_to_type_list(df[df[column_name] == column_val], type)
-
-
-def list_to_dict(ls, key='id'):
-    return {getattr(item, key): item for item in ls}
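
The Index container that replaces the removed dict-and-filter lookups is not part of this diff. A minimal sketch consistent with its usage above — assuming a keyword value of True builds a multi-valued index (by_<field>.get() returns a list, as with publishers.by_node_handle) and False a unique index (a single object or None, as with nodes.by_id) — could look like this:

from collections import defaultdict
from typing import Generic, Iterable, TypeVar

T = TypeVar("T")


class Index(Generic[T]):
    """Indexes a collection of objects by the given attribute names.

    Sketch only: True is assumed to mean a multi-valued index
    (by_<field>.get(key) -> List[T]), False a unique one
    (by_<field>.get(key) -> Optional[T]).
    """

    def __init__(self, items: Iterable[T], **index_fields: bool):
        self._items = list(items)
        for field_name, is_multi in index_fields.items():
            if is_multi:
                mapping = defaultdict(list)
                for item in self._items:
                    mapping[getattr(item, field_name)].append(item)
                mapping = dict(mapping)
            else:
                mapping = {getattr(item, field_name): item for item in self._items}
            # Lookups are exposed as plain `index.by_<field>` dicts.
            setattr(self, f"by_{field_name}", mapping)

    def __iter__(self):
        return iter(self._items)

    def __len__(self):
        return len(self._items)

Building these dicts once turns every former O(n) filter() scan into an O(1) lookup, which is what lets the patch demote the cached_property accessors to plain properties and drop the __getstate__/__setstate__ pickling workarounds.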
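
The new mappers hook in df_to_type_list normalizes columns before the dataclasses are constructed (nanosecond integers become float-second timestamps, pandas Timestamp/Timedelta become plain floats). A small hypothetical usage, with a stand-in Sample type rather than the real Tr* classes:

from dataclasses import dataclass

import pandas as pd

from tracing_interop.utils import df_to_type_list


@dataclass
class Sample:  # hypothetical stand-in for the Tr* dataclasses
    id: int
    timestamp: float


# A non-RangeIndex index is picked up as the "id" field, and the mapper
# converts nanosecond integers to float seconds before construction,
# mirroring the TrPublishInstance timestamp handling above.
df = pd.DataFrame({"timestamp": [1_500_000_000, 3_000_000_000]}, index=[7, 8])
samples = df_to_type_list(df, Sample, mappers={"timestamp": lambda t: t * 1e-9})
# -> [Sample(id=7, timestamp=1.5), Sample(id=8, timestamp=3.0)]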