From 0261b8200f3e28f292fa6e92f19ca1fbcbf2e538 Mon Sep 17 00:00:00 2001 From: Maximilian Schmeller Date: Tue, 9 Aug 2022 18:36:40 +0200 Subject: [PATCH] Hierarchical latency graph, bugfixes, renamed types.py to not interfere with other python packages. --- clang_interop/{types.py => cl_types.py} | 94 ++-- clang_interop/process_clang_output.py | 162 +++---- latency_graph/__init__.py | 0 latency_graph/latency_graph.py | 329 ++++++++++++++ matching/__init__.py | 0 matching/subscriptions.py | 518 ++++++++++++++++++++++ misc/utils.py | 12 + requirements.txt | 1 + tracing_interop/{types.py => tr_types.py} | 60 +-- tracing_interop/utils.py | 16 +- 10 files changed, 1022 insertions(+), 170 deletions(-) rename clang_interop/{types.py => cl_types.py} (64%) create mode 100644 latency_graph/__init__.py create mode 100644 latency_graph/latency_graph.py create mode 100644 matching/__init__.py create mode 100644 matching/subscriptions.py rename tracing_interop/{types.py => tr_types.py} (84%) diff --git a/clang_interop/types.py b/clang_interop/cl_types.py similarity index 64% rename from clang_interop/types.py rename to clang_interop/cl_types.py index 5eecc79..d76c264 100644 --- a/clang_interop/types.py +++ b/clang_interop/cl_types.py @@ -1,24 +1,36 @@ import os +import re from dataclasses import dataclass, field from typing import List, Literal, Dict, Set @dataclass class ClTranslationUnit: - dependencies: Dict[int, Set[int]] - publications: Dict[int, Set[int]] - nodes: Dict[int, 'ClNode'] - publishers: Dict[int, 'ClPublisher'] - subscriptions: Dict[int, 'ClSubscription'] - timers: Dict[int, 'ClTimer'] - fields: Dict[int, 'ClField'] - methods: Dict[int, 'ClMethod'] - accesses: List['ClMemberRef'] + filename: str + + def __hash__(self): + return hash(self.filename) @dataclass class ClContext: - translation_units: Dict[str, 'ClTranslationUnit'] = field(default_factory=dict) + translation_units: Set['ClTranslationUnit'] + + nodes: Set['ClNode'] + publishers: Set['ClPublisher'] + subscriptions: Set['ClSubscription'] + timers: Set['ClTimer'] + + fields: Set['ClField'] + methods: Set['ClMethod'] + + accesses: List['ClMemberRef'] + + dependencies: Dict['ClMethod', Set['ClMethod']] + publications: Dict['ClMethod', Set['ClPublisher']] + + def __repr__(self): + return f"ClContext({len(self.translation_units)} TUs)" @dataclass @@ -50,15 +62,17 @@ class ClSourceRange: @dataclass class ClNode: + tu: 'ClTranslationUnit' = field(repr=False) id: int qualified_name: str - source_range: 'ClSourceRange' + source_range: 'ClSourceRange' = field(repr=False) field_ids: List[int] | None method_ids: List[int] | None ros_name: str | None ros_namespace: str | None - def __init__(self, json_obj): + def __init__(self, json_obj, tu): + self.tu = tu self.id = json_obj['id'] self.qualified_name = json_obj['qualified_name'] self.source_range = ClSourceRange(json_obj['source_range']) @@ -68,23 +82,43 @@ class ClNode: self.ros_namespace = json_obj['ros_namespace'] if 'ros_namespace' in json_obj else None def __hash__(self): - return hash(self.id) + return hash((self.tu, self.id)) @dataclass class ClMethod: + tu: 'ClTranslationUnit' = field(repr=False) id: int qualified_name: str - source_range: 'ClSourceRange' + source_range: 'ClSourceRange' = field(repr=False) return_type: str | None parameter_types: List[str] | None + is_lambda: bool | None - def __init__(self, json_obj): + @property + def signature(self): + # Lambda definitions end in this suffix + class_name = self.qualified_name.removesuffix("::(anonymous class)::operator()") + + # If the definition is no lambda (and hence no suffix has been removed), the last part after :: is the method + # name. Remove it to get the class name. + if class_name == self.qualified_name: + class_name = "::".join(class_name.split("::")[:-1]) + + if self.is_lambda: + return f"{class_name}$lambda" + + param_str = ','.join(self.parameter_types) if self.parameter_types is not None else '' + return f"{self.return_type if self.return_type else ''} ({class_name})({param_str})" + + def __init__(self, json_obj, tu): + self.tu = tu self.id = json_obj['id'] self.qualified_name = json_obj['qualified_name'] self.source_range = ClSourceRange(json_obj['source_range']) self.return_type = json_obj['signature']['return_type'] if 'signature' in json_obj else None self.parameter_types = json_obj['signature']['parameter_types'] if 'signature' in json_obj else None + self.is_lambda = json_obj['is_lambda'] if 'is_lambda' in json_obj else None def __hash__(self): return hash(self.id) @@ -92,11 +126,13 @@ class ClMethod: @dataclass class ClField: + tu: 'ClTranslationUnit' = field(repr=False) id: int qualified_name: str - source_range: 'ClSourceRange' + source_range: 'ClSourceRange' = field(repr=False) - def __init__(self, json_obj): + def __init__(self, json_obj, tu): + self.tu = tu self.id = json_obj['id'] self.qualified_name = json_obj['qualified_name'] self.source_range = ClSourceRange(json_obj['source_range']) @@ -107,13 +143,15 @@ class ClField: @dataclass class ClMemberRef: + tu: 'ClTranslationUnit' = field(repr=False) type: Literal["read", "write", "call", "arg", "pub"] | None member_chain: List[int] method_id: int | None node_id: int | None - source_range: 'ClSourceRange' + source_range: 'ClSourceRange' = field(repr=False) - def __init__(self, json_obj): + def __init__(self, json_obj, tu): + self.tu = tu access_type = json_obj['context']['access_type'] if access_type == 'none': access_type = None @@ -129,11 +167,13 @@ class ClMemberRef: @dataclass class ClSubscription: + tu: 'ClTranslationUnit' = field(repr=False) topic: str | None callback_id: int | None - source_range: 'ClSourceRange' + source_range: 'ClSourceRange' = field(repr=False) - def __init__(self, json_obj): + def __init__(self, json_obj, tu): + self.tu = tu self.topic = json_obj['topic'] if 'topic' in json_obj else None self.callback_id = json_obj['callback']['id'] if 'callback' in json_obj else None self.source_range = ClSourceRange(json_obj['source_range']) @@ -144,14 +184,16 @@ class ClSubscription: @dataclass class ClPublisher: + tu: 'ClTranslationUnit' = field(repr=False) topic: str | None member_id: int | None - source_range: 'ClSourceRange' + source_range: 'ClSourceRange' = field(repr=False) def update(self, t2: 'ClTimer'): return self - def __init__(self, json_obj): + def __init__(self, json_obj, tu): + self.tu = tu self.topic = json_obj['topic'] if 'topic' in json_obj else None self.member_id = json_obj['member']['id'] if 'member' in json_obj else None self.source_range = ClSourceRange(json_obj['source_range']) @@ -162,10 +204,12 @@ class ClPublisher: @dataclass class ClTimer: + tu: 'ClTranslationUnit' = field(repr=False) callback_id: int | None - source_range: 'ClSourceRange' + source_range: 'ClSourceRange' = field(repr=False) - def __init__(self, json_obj): + def __init__(self, json_obj, tu): + self.tu = tu self.callback_id = json_obj['callback']['id'] if 'callback' in json_obj else None self.source_range = ClSourceRange(json_obj['source_range']) diff --git a/clang_interop/process_clang_output.py b/clang_interop/process_clang_output.py index d554b51..abea652 100644 --- a/clang_interop/process_clang_output.py +++ b/clang_interop/process_clang_output.py @@ -2,14 +2,12 @@ import functools import json import os import pickle -import re -from typing import Tuple, Iterable +from typing import Iterable import numpy as np import pandas as pd -import termcolor -from clang_interop.types import ClNode, ClField, ClTimer, ClMethod, ClPublisher, ClSubscription, ClMemberRef, ClContext, \ +from clang_interop.cl_types import ClNode, ClField, ClTimer, ClMethod, ClPublisher, ClSubscription, ClMemberRef, ClContext, \ ClTranslationUnit IN_DIR = "/home/max/Projects/ma-ros2-internal-dependency-analyzer/output" @@ -123,14 +121,14 @@ def dedup(elems): ret_list.append(elem) print(f"Fused {len(elems)} {type(elem)}s") - return ret_list + return set(ret_list) def dictify(elems, key='id'): return {getattr(e, key): e for e in elems} -def definitions_from_json(cb_dict): +def definitions_from_json(cb_dict, tu): nodes = [] pubs = [] subs = [] @@ -141,145 +139,85 @@ def definitions_from_json(cb_dict): if "nodes" in cb_dict: for node in cb_dict["nodes"]: - nodes.append(ClNode(node)) + nodes.append(ClNode(node, tu)) for field in node["fields"]: - fields.append(ClField(field)) + fields.append(ClField(field, tu)) for method in node["methods"]: - methods.append(ClMethod(method)) + methods.append(ClMethod(method, tu)) if "publishers" in cb_dict: for publisher in cb_dict["publishers"]: - pubs.append(ClPublisher(publisher)) + pubs.append(ClPublisher(publisher, tu)) if "subscriptions" in cb_dict: for subscription in cb_dict["subscriptions"]: - subs.append(ClSubscription(subscription)) + subs.append(ClSubscription(subscription, tu)) + if "callback" in subscription: + methods.append(ClMethod(subscription["callback"], tu)) if "timers" in cb_dict: for timer in cb_dict["timers"]: - timers.append(ClTimer(timer)) + timers.append(ClTimer(timer, tu)) + if "callback" in timer: + methods.append(ClMethod(timer["callback"], tu)) if "accesses" in cb_dict: for access_type in cb_dict["accesses"]: for access in cb_dict["accesses"][access_type]: - accesses.append(ClMemberRef(access)) + accesses.append(ClMemberRef(access, tu)) + if "method" in access["context"]: + methods.append(ClMethod(access["context"]["method"], tu)) - nodes = dictify(dedup(nodes)) - pubs = dictify(dedup(pubs), key='member_id') - subs = dictify(dedup(subs), key='callback_id') - timers = dictify(dedup(timers), key='callback_id') - fields = dictify(dedup(fields)) - methods = dictify(dedup(methods)) + nodes = dedup(nodes) + pubs = dedup(pubs) + subs = dedup(subs) + timers = dedup(timers) + fields = dedup(fields) + methods = dedup(methods) return nodes, pubs, subs, timers, fields, methods, accesses -def highlight(substr: str, text: str): - regex = r"(?<=\W)({substr})(?=\W)|^({substr})$" - return re.sub(regex.format(substr=substr), termcolor.colored(r"\1\2", 'magenta', attrs=['bold']), text) - - -def prompt_user(file: str, cb: str, idf: str, text: str) -> Tuple[str, bool, bool]: - print('\n' * 5) - print(f"{file.rstrip('.cpp').rstrip('.hpp')}\n->{cb}:") - print(highlight(idf.split('::')[-1], text)) - answer = input(f"{highlight(idf, idf)}\n" - f"write (w), read (r), both (rw), ignore future (i) exit and save (q), undo (z), skip (Enter): ") - if answer not in ["", "r", "w", "rw", "q", "z", "i"]: - print(f"Invalid answer '{answer}', try again.") - answer = prompt_user(file, cb, idf, text) - - if answer == 'i': - ignored_idfs.add(idf) - elif any(x in answer for x in ['r', 'w']): - ignored_idfs.discard(idf) - - return answer, answer == "q", answer == "z" - - -def main(cbs): - open_files = {} - cb_rw_dict = {} - - jobs = [] - - for cb_id, cb_dict in cbs.items(): - cb_rw_dict[cb_dict['qualified_name']] = {'reads': set(), 'writes': set()} - for ref_dict in cb_dict['member_refs']: - if ref_dict['file'] not in open_files: - with open(ref_dict['file'], 'r') as f: - open_files[ref_dict['file']] = f.readlines() - - ln = ref_dict['start_line'] - 1 - text = open_files[ref_dict['file']] - line = termcolor.colored(text[ln], None, "on_cyan") - lines = [*text[ln - 3:ln], line, *text[ln + 1:ln + 4]] - text = ''.join(lines) - jobs.append((ref_dict['file'], cb_dict['qualified_name'], ref_dict['qualified_name'], text)) - - i = 0 - do_undo = False - while i < len(jobs): - file, cb, idf, text = jobs[i] - - if do_undo: - ignored_idfs.discard(idf) - cb_rw_dict[cb]['reads'].discard(idf) - cb_rw_dict[cb]['writes'].discard(idf) - do_undo = False - - if idf in ignored_idfs: - print("Ignoring", idf) - i += 1 - continue - - if idf in cb_rw_dict[cb]['reads'] and idf in cb_rw_dict[cb]['writes']: - print(f"{idf} is already written to and read from in {cb}, skipping.") - i += 1 - continue - - classification, answ_quit, answ_undo = prompt_user(file, cb, idf, text) - - if answ_quit: - del cb_rw_dict[file][cb] - break - elif answ_undo: - i -= 1 - do_undo = True - continue - - if 'r' in classification: - cb_rw_dict[cb]['reads'].add(idf) - if 'w' in classification: - cb_rw_dict[cb]['writes'].add(idf) - if not any(x in classification for x in ['r', 'w']): - print(f"Ignoring occurences of {idf} in cb.") - - i += 1 - - with open("deps.json", "w") as f: - json.dump(cb_rw_dict, f, cls=SetEncoder) - - print("Done.") - - def process_clang_output(directory=IN_DIR): - clang_context = ClContext() + all_tus = set() + all_nodes = set() + all_pubs = set() + all_subs = set() + all_timers = set() + all_fields = set() + all_methods = set() + all_accesses = [] + all_deps = {} + all_publications = {} for filename in os.listdir(IN_DIR): source_filename = SRC_FILE_NAME(filename) print(f"Processing {source_filename}") + with open(os.path.join(IN_DIR, filename), "r") as f: cb_dict = json.load(f) if cb_dict is None: print(f" [WARN ] Empty tool output detected in {filename}") continue - nodes, pubs, subs, timers, fields, methods, accesses = definitions_from_json(cb_dict) + tu = ClTranslationUnit(source_filename) + all_tus.add(tu) + + nodes, pubs, subs, timers, fields, methods, accesses = definitions_from_json(cb_dict, tu) deps, publications = find_data_deps(accesses) - tu = ClTranslationUnit(deps, publications, nodes, pubs, subs, timers, fields, methods, accesses) - clang_context.translation_units[source_filename] = tu + all_nodes.update(nodes) + all_pubs.update(pubs) + all_subs.update(subs) + all_timers.update(timers) + all_fields.update(fields) + all_methods.update(methods) + all_accesses += accesses + all_deps.update(deps) + all_publications.update(publications) + + clang_context = ClContext(all_tus, all_nodes, all_pubs, all_subs, all_timers, all_fields, all_methods, all_accesses, + all_deps, all_publications) return clang_context diff --git a/latency_graph/__init__.py b/latency_graph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/latency_graph/latency_graph.py b/latency_graph/latency_graph.py new file mode 100644 index 0000000..8ddae26 --- /dev/null +++ b/latency_graph/latency_graph.py @@ -0,0 +1,329 @@ +from dataclasses import dataclass +from itertools import combinations +from multiprocessing import Pool +from typing import Optional, Set, List, Iterable, Dict, Tuple + +from tqdm.notebook import tqdm +from tqdm.contrib import concurrent + +from matching.subscriptions import sanitize +from tracing_interop.tr_types import TrContext, TrCallbackObject, TrCallbackSymbol, TrNode, TrPublisher, TrSubscription, \ + TrTimer, TrPublishInstance, TrSubscriptionObject, TrTopic, TrCallbackInstance + + +TOPIC_FILTERS = ["/parameter_events", "/tf_static", "/robot_description", "diagnostics"] + + +def _map_cb_times(args): + cb_id, inst_times, pub_timestamps = args + pub_cb_overlaps = {i: set() for i in range(len(pub_timestamps))} + + inst_times.sort(key=lambda tup: tup[0]) # tup[0] is start time + + inst_iter = iter(inst_times) + pub_iter = iter(enumerate(pub_timestamps)) + + inst_start, inst_end = next(inst_iter, (None, None)) + i, t = next(pub_iter, (None, None)) + while inst_start is not None and i is not None: + if inst_start <= t <= inst_end: + pub_cb_overlaps[i].add(cb_id) + + if t <= inst_end: + i, t = next(pub_iter, (None, None)) + else: + inst_start, inst_end = next(inst_iter, (None, None)) + + return pub_cb_overlaps + + +def _get_cb_owner_node(cb: TrCallbackObject) -> TrNode | None: + match cb.owner: + case TrTimer(nodes=nodes): + owner_nodes = nodes + case TrSubscriptionObject(subscription=sub): + owner_nodes = [sub.node] + case _: + owner_nodes = [] + + if len(owner_nodes) > 1: + raise RuntimeError(f"CB has owners {', '.join(map(lambda n: n.path, owner_nodes))}") + elif not owner_nodes: + print("[WARN] CB has no owners") + return None + + return owner_nodes[0] + + +def _hierarchize(lg_nodes: Iterable['LGHierarchyLevel']): + base = LGHierarchyLevel(None, [], "", []) + + def _insert(parent, node, path): + match path: + case []: + parent.children.append(node) + node.parent = parent + case [head, *tail]: + next_node = next(iter(n for n in parent.children if n.name == head), None) + if next_node is None: + next_node = LGHierarchyLevel(parent, [], head, []) + parent.children.append(next_node) + _insert(next_node, node, tail) + + for node in lg_nodes: + path = node.name.strip("/").split("/") + node.name = path[-1] + _insert(base, node, path[:-1]) + + return base + + +def inst_runtime_interval(cb_inst: TrCallbackInstance): + inst_t_min = cb_inst.timestamp.timestamp() + inst_t_max = inst_t_min + cb_inst.duration.total_seconds() + return inst_t_min, inst_t_max + + +def _get_publishing_cbs(cbs: Set[TrCallbackObject], pub: TrPublisher): + """ + Counts number of publication instances that lie within one of the cb_intervals. + """ + pub_timestamps = [inst.timestamp * 1e-9 for inst in pub.instances] + + # Algorithm: Two-pointer method + # With both the pub_timestamps and cb_intervals sorted ascending, + # we can cut down the O(m*n) comparisons to O(m+n). + pub_timestamps.sort() + + cb_id_to_cb = {cb.id: cb for cb in cbs} + _map_args = [(cb.id, [inst_runtime_interval(inst) for inst in cb.callback_instances], pub_timestamps) for cb in cbs] + + with Pool() as p: + cb_wise_overlaps = p.map(_map_cb_times, _map_args) + + pub_cb_overlaps = {i: set() for i in range(len(pub_timestamps))} + for overlap_dict in cb_wise_overlaps: + for i, cb_ids in overlap_dict.items(): + cbs = [cb_id_to_cb[cb_id] for cb_id in cb_ids] + pub_cb_overlaps[i].update(cbs) + + pub_cbs = set() + cb_cb_overlaps = set() + for i, i_cbs in pub_cb_overlaps.items(): + if not i_cbs: + print(f"[WARN] Publication on {pub.topic_name} without corresponding callback!") + elif len(i_cbs) == 1: + pub_cbs.update(i_cbs) + else: # Multiple CBs in i_cbs + cb_cb_overlaps.update(iter(combinations(i_cbs, 2))) + + for cb1, cb2 in cb_cb_overlaps: + cb1_subset_of_cb2 = True + cb2_subset_of_cb1 = True + + for i_cbs in pub_cb_overlaps.values(): + if cb1 in i_cbs and cb2 not in i_cbs: + cb1_subset_of_cb2 = False + if cb2 in i_cbs and cb1 not in i_cbs: + cb2_subset_of_cb1 = False + + if cb1_subset_of_cb2 and cb2_subset_of_cb1: + print(f"[WARN] Callbacks {cb1.id} and {cb2.id} always run in parallel") + elif cb1_subset_of_cb2: + pub_cbs.discard(cb1) + elif cb2_subset_of_cb1: + pub_cbs.discard(cb2) + # else: discard none of them + + return pub_cbs + + +def _get_cb_topic_deps(nodes_to_cbs: Dict[TrNode, Set[TrCallbackObject]]): + cbs_subbed_to_topic: Dict[TrTopic, Set[TrCallbackObject]] = {} + + # Find topics the callback EXPLICITLY depends on + # - Timer callbacks: no EXPLICIT dependencies + # - Subscription callbacks: CB depends on the subscribed topic. Possibly also has other IMPLICIT dependencies + p = tqdm(desc="Processing CB subscriptions", total=sum(map(len, nodes_to_cbs.values()))) + for node, cbs in nodes_to_cbs.items(): + for cb in cbs: + p.update() + + if type(cb.owner) == TrSubscriptionObject: + dep_topics = [cb.owner.subscription.topic] + elif type(cb.owner) == TrTimer: + dep_topics = [] + elif cb.owner is None: + continue + else: + raise RuntimeError( + f"Callback owners other than timers/subscriptions cannot be handled: {cb.owner}") + + for topic in dep_topics: + if topic not in cbs_subbed_to_topic: + cbs_subbed_to_topic[topic] = set() + cbs_subbed_to_topic[topic].add(cb) + + # Find topics the callback publishes to (HEURISTICALLY!) + # For topics published to during the runtime of the callback's instances, + # assume that they are published by the callback + cbs_publishing_topic: Dict[TrTopic, Set[TrCallbackObject]] = {} + p = tqdm(desc="Processing node publications", total=len(nodes_to_cbs)) + for node, cbs in nodes_to_cbs.items(): + p.update() + if node is None: + continue + for pub in node.publishers: + if any(f in pub.topic_name for f in TOPIC_FILTERS): + continue + pub_cbs = _get_publishing_cbs(cbs, pub) + if pub.topic not in cbs_publishing_topic: + cbs_publishing_topic[pub.topic] = set() + + cbs_publishing_topic[pub.topic].update(pub_cbs) + + return cbs_subbed_to_topic, cbs_publishing_topic + + +@dataclass +class LGCallback: + name: str + in_topics: List[TrTopic] + out_topics: List[TrTopic] + + def id(self): + return self.name + + +@dataclass +class LGTrCallback(LGCallback): + cb: TrCallbackObject + sym: TrCallbackSymbol | None + node: TrNode | None + + def id(self): + return str(self.cb.id) + + +@dataclass +class LGHierarchyLevel: + parent: Optional['LGHierarchyLevel'] + children: List['LGHierarchyLevel'] + name: str + callbacks: List[LGCallback] + + @property + def full_name(self): + if self.parent is None: + return f"{self.name}" + + return f"{self.parent.full_name}/{self.name}" + + +@dataclass +class LGEdge: + start: LGCallback + end: LGCallback + + +@dataclass +class LatencyGraph: + top_node: LGHierarchyLevel + edges: List[LGEdge] + + def __init__(self, tr: TrContext): + ################################################## + # Annotate nodes with their callbacks + ################################################## + + # Note that nodes can also be None! + nodes_to_cbs = {} + p = tqdm(desc="Finding CB nodes", total=len(tr.callback_objects)) + for cb in tr.callback_objects.values(): + p.update() + node = _get_cb_owner_node(cb) + + if node not in nodes_to_cbs: + nodes_to_cbs[node] = set() + nodes_to_cbs[node].add(cb) + + ################################################## + # Find in/out topics for each callback + ################################################## + + cbs_subbed_to_topic, cbs_publishing_topic = _get_cb_topic_deps(nodes_to_cbs) + + ################################################## + # Map topics to their messages + ################################################## + + topics_to_messages = {} + p = tqdm(desc="Mapping messages to topics", total=len(tr.publish_instances)) + for pub_inst in tr.publish_instances: + p.update() + try: + topic = pub_inst.publisher.topic + except KeyError: + continue + + if topic not in topics_to_messages: + topics_to_messages[topic] = [] + topics_to_messages[topic].append(pub_inst) + + ################################################## + # Define nodes and edges on lowest level + ################################################## + + input = LGCallback("INPUT", [], [topic for topic in tr.topics.values() if not topic.publishers]) + output = LGCallback("OUTPUT", [topic for topic in tr.topics.values() if not topic.subscriptions], []) + + in_node = LGHierarchyLevel(None, [], "INPUT", [input]) + out_node = LGHierarchyLevel(None, [], "OUTPUT", [output]) + + lg_nodes = [in_node, out_node] + + tr_to_lg_cb = {} + + p = tqdm("Building graph nodes", total=sum(map(len, nodes_to_cbs.values()))) + for node, cbs in nodes_to_cbs.items(): + node_callbacks = [] + + for cb in cbs: + p.update() + try: + sym = cb.callback_symbol + pretty_sym = sanitize(sym.symbol) + except KeyError: + sym = None + pretty_sym = cb.id + in_topics = [topic for topic, cbs in cbs_subbed_to_topic.items() if cb in cbs] + out_topics = [topic for topic, cbs in cbs_publishing_topic.items() if cb in cbs] + lg_cb = LGTrCallback(pretty_sym, in_topics, out_topics, cb, sym, node) + node_callbacks.append(lg_cb) + tr_to_lg_cb[cb] = lg_cb + + lg_node = LGHierarchyLevel(None, [], node.path if node else "[NONE]", node_callbacks) + lg_nodes.append(lg_node) + + edges = [] + p = tqdm("Building graph edges", total=len(tr.topics)) + for topic in tr.topics.values(): + p.update() + sub_cbs = cbs_subbed_to_topic[topic] if topic in cbs_subbed_to_topic else [] + pub_cbs = cbs_publishing_topic[topic] if topic in cbs_publishing_topic else [] + + for sub_cb in sub_cbs: + for pub_cb in pub_cbs: + lg_edge = LGEdge(tr_to_lg_cb[pub_cb], tr_to_lg_cb[sub_cb]) + edges.append(lg_edge) + + self.edges = edges + + ################################################## + # Nodes into hierarchy levels + ################################################## + + self.top_node = _hierarchize(lg_nodes) + + def to_gv(self): + pass diff --git a/matching/__init__.py b/matching/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/matching/subscriptions.py b/matching/subscriptions.py new file mode 100644 index 0000000..13743ab --- /dev/null +++ b/matching/subscriptions.py @@ -0,0 +1,518 @@ +import pickle +import re +import sys +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional, Iterable, Set, Tuple + +from bidict import bidict +from termcolor import colored + +sys.path.append("../../autoware/build/tracetools_read/") +sys.path.append("../../autoware/build/tracetools_analysis/") + +from clang_interop.cl_types import ClMethod, ClContext, ClSubscription +from tracing_interop.tr_types import TrContext, TrSubscriptionObject, TrSubscription, TrCallbackSymbol, TrTimer + + +class TKind(Enum): + # language=PythonRegExp + identifier = r"(?P(?:[\w$-]+::)*[\w$-]+|[+-]?[0-9]+|\".*?\"|'.*?')" + # language=PythonRegExp + ang_open = r"(?P<)" + # language=PythonRegExp + ang_close = r"(?P>)" + # language=PythonRegExp + par_open = r"(?P\()" + # language=PythonRegExp + par_close = r"(?P\))" + # language=PythonRegExp + hash = r"(?P#)" + # language=PythonRegExp + curl_open = r"(?P\{)" + # language=PythonRegExp + curl_close = r"(?P})" + # language=PythonRegExp + brack_open = r"(?P\[)" + # language=PythonRegExp + brack_close = r"(?P])" + # language=PythonRegExp + whitespace = r"(?P\s+)" + # language=PythonRegExp + ref = r"(?P&)" + # language=PythonRegExp + ptr = r"(?P\*)" + # language=PythonRegExp + comma = r"(?P,)" + # language=PythonRegExp + ns_sep = r"(?P::)" + # language=PythonRegExp + unknown_symbol = r"(?P\?)" + + def __repr__(self): + return self.name + + +class ASTEntry: + def get_token_stream(self): + pass + + +@dataclass +class ASTLeaf(ASTEntry): + kind: TKind + spelling: str + + def get_token_stream(self): + return [self] + + def __repr__(self): + return self.spelling if self.kind != TKind.identifier else f'"{self.spelling}"' + + +@dataclass +class ASTNode(ASTEntry): + type: str + children: List[ASTEntry] = field(default_factory=list) + parent: Optional['ASTNode'] = field(default_factory=lambda: None) + start: ASTLeaf | None = field(default_factory=lambda: None) + end: ASTLeaf | None = field(default_factory=lambda: None) + + def get_token_stream(self) -> List[ASTLeaf]: + stream = [] + for c in self.children: + stream += c.get_token_stream() + if self.start: + stream.insert(0, self.start) + if self.end: + stream.append(self.end) + return stream + + def __repr__(self): + tokens = self.get_token_stream() + ret = "" + last_tkind = None + for t in tokens: + match t.kind: + case TKind.identifier: + if last_tkind == TKind.identifier: + ret += " " + ret += t.spelling + last_tkind = t.kind + return ret + + +BRACK_MAP = bidict({ + TKind.curl_open: TKind.curl_close, + TKind.par_open: TKind.par_close, + TKind.brack_open: TKind.brack_close, + TKind.ang_open: TKind.ang_close +}) + +BRACK_SPELLING_MAP = bidict({ + '{': '}', + '(': ')', + '[': ']', + '<': '>' +}) + +TR_BLACKLIST = [ + "tf2_ros::TransformListener", + "rtc_auto_approver::RTCAutoApproverInterface", + "rclcpp::TimeSource", + "rclcpp::ParameterService", + "rclcpp_components::ComponentManager", + "rosbag2", + "std_srvs::srv" +] + + +def cl_deps_to_tr_deps(matches: Set[Tuple], tr: TrContext, cl: ClContext): + + ################################################## + # Narrow down matches + ################################################## + # The `match()` function returns an n-to-m + # mapping between cl and tr symbols. + # This n-to-m mapping has to be narrowed down + # to 1-to-1. This is done by building cohorts + # of cl symbols which belong to the same node + # and filtering out outliers. + ################################################## + + final_tr_deps = dict() + for cl_cb_id, cl_dep_ids in cl.dependencies.items(): + cl_cb_id: int + cl_dep_ids: Iterable[int] + + ################################################## + # 1. For all cl dependencies, build all possible + # tr dependencies with the matches we have. + ################################################## + + cl_cb = next(filter(lambda m: m.id == cl_cb_id, cl.methods), None) + cl_deps = set(filter(lambda m: m.id in cl_dep_ids, cl.methods)) + + if cl_cb is None or len(cl_deps) < len(cl_deps): + print(colored(f"[ERROR][CL] Callback has not all CL methods defined", "red")) + if cl_cb is None: + continue # for cl.dependencies.items() + + # Because the mapping is n-to-m, we have tr_cbs as a set, instead of a single tr_cb + tr_cbs = set(tr_obj for cl_obj, tr_obj, *_ in matches if cl_obj == cl_cb) + tr_deps = set(tr_obj for cl_obj, tr_obj, *_ in matches if cl_obj in cl_deps) + + ################################################## + # 2. Filter out all combinations where + # dependencies leave a node. + ################################################## + + def owner_node(sym: TrCallbackSymbol): + cb_objs = sym.callback_objs + owners = [cb_obj.owner for cb_obj in cb_objs] + owner_nodes = set() + for owner in owners: + match owner: + case TrSubscriptionObject() as sub: + sub: TrSubscriptionObject + owner_nodes.add(sub.subscription.node) + case TrTimer() as tmr: + tmr: TrTimer + owner_nodes.update(tmr.nodes) + + if len(owner_nodes) == 1: + return owner_nodes.pop() + + return None + + viable_matchings = {} + for tr_cb in tr_cbs: + tr_cb: TrCallbackSymbol + owner = owner_node(tr_cb) + if owner is None: + continue # for tr_cbs + + valid_deps = set(dep for dep in tr_deps if owner_node(dep) == owner) + if not valid_deps: + continue # for tr_cbs + + viable_matchings[tr_cb] = valid_deps + + if not viable_matchings: + print(colored(f"[ERROR][CL] Callback has not all TR equivalents for CL: {cl_cb.signature}", "red")) + continue # for cl.dependencies.items() + + ################################################## + # 3. Select the matching with the highest number + # of mapped dependencies (= the smallest number + # of unmapped cl deps) + ################################################## + + print(len(viable_matchings), ', '.join(map(str, map(len, viable_matchings.values())))) + + final_tr_deps.update(viable_matchings) + + return final_tr_deps + + +def match(tr: TrContext, cl: ClContext): + def _is_excluded(symbol: str): + return any(item in symbol for item in TR_BLACKLIST) + + cl_methods = [cb for cb in cl.methods + if any(sub.callback_id == cb.id for sub in cl.subscriptions) + or any(tmr.callback_id == cb.id for tmr in cl.timers)] + + tr_callbacks = [(sym.symbol, sym) for sym in tr.callback_symbols.values() if not _is_excluded(sym.symbol)] + cl_callbacks = [(cb.signature, cb) for cb in cl_methods] + + tr_callbacks = [(repr(sanitize(k)), v) for k, v in tr_callbacks] + cl_callbacks = [(repr(sanitize(k)), v) for k, v in cl_callbacks] + + matches_sig = set() + + tr_matched = set() + cl_matched = set() + + for cl_sig, cl_obj in cl_callbacks: + matches = set(tr_obj for tr_sig, tr_obj in tr_callbacks if tr_sig == cl_sig) + tr_matched |= matches + if matches: + cl_matched.add(cl_obj) + for tr_obj in matches: + matches_sig.add((cl_obj, tr_obj, cl_sig)) + + matches_topic = set() + for _, cl_obj in cl_callbacks: + # Get subscription of the callback (if any) + cl_sub: ClSubscription | None = next((sub for sub in cl.subscriptions if sub.callback_id == cl_obj.id), None) + + if not cl_sub: + continue + + cl_topic = re.sub(r"~/(input/)?", "", cl_sub.topic) + + matches = set() + for _, tr_obj in tr_callbacks: + tr_cb = tr_obj.callback_objs[0] if len(tr_obj.callback_objs) == 1 else None + if not tr_cb: + continue + + match tr_cb.owner: + case TrSubscriptionObject(subscription=tr_sub): + tr_sub: TrSubscription + tr_topic = tr_sub.topic_name + if not tr_topic: + continue + case _: + continue + + if tr_topic.endswith(cl_topic): + matches_topic.add((cl_obj, tr_obj, cl_topic, tr_topic)) + matches.add(tr_obj) + + tr_matched |= matches + if matches: + cl_matched.add(cl_obj) + + all_matches = matches_sig | matches_topic + + def count_dup(matches): + cl_dup = 0 + tr_dup = 0 + for (cl_obj, tr_obj, *_) in matches: + n_cl_dups = len([cl2 for cl2, *_ in matches if cl2 == cl_obj]) + if n_cl_dups > 1: + cl_dup += 1 / n_cl_dups + + n_tr_dups = len([tr2 for _, tr2, *_ in matches if tr2 == tr_obj]) + if n_tr_dups > 1: + tr_dup += 1 / n_tr_dups + + print(int(cl_dup), int(tr_dup)) + + count_dup(all_matches) + + tr_unmatched = set(tr_obj for _, tr_obj in tr_callbacks) - tr_matched + cl_unmatched = set(cl_obj for _, cl_obj in cl_callbacks) - cl_matched + + return all_matches, tr_unmatched, cl_unmatched + + +def match_and_modify_children(node: ASTEntry, match_func): + if not isinstance(node, ASTNode): + return node + + for i in range(len(node.children)): + seq_head = node.children[:i] + seq_tail = node.children[i:] + match_result = match_func(seq_head, seq_tail, node) + if match_result is not None: + node.children = match_result + + return node + + +def sanitize(sig: str): + ast = build_ast(sig) + + def _remove_qualifiers(node: ASTEntry): + match node: + case ASTLeaf(TKind.identifier, 'class' | 'struct' | 'const'): + return None + return node + + def _remove_std_wrappers(node: ASTEntry): + def _child_seq_matcher(head, tail, _): + match tail: + case [ASTLeaf(TKind.identifier, "std::allocator"), ASTNode('<>'), *rest]: + return head + rest + case [ASTLeaf(TKind.identifier, "std::shared_ptr"), ASTNode('<>', ptr_type), *rest]: + return head + ptr_type + rest + return None + + return match_and_modify_children(node, _child_seq_matcher) + + def _remove_std_bind(node: ASTEntry): + def _child_seq_matcher(head, tail, parent): + match tail: + case [ASTLeaf(TKind.identifier, "std::_Bind"), + ASTNode(type='<>', children=[ + callee_ret, + ASTNode('()', children=[*callee_ptr, ASTNode('()', bind_args)]), + ASTNode('()') as replacement_args])]: + + return [callee_ret] + head + [ASTNode('()', callee_ptr, parent, + ASTLeaf(TKind.par_open, '('), + ASTLeaf(TKind.par_close, ')')), + replacement_args] + return None + + return match_and_modify_children(node, _child_seq_matcher) + + def _unwrap_lambda(node: ASTEntry): + def _child_seq_matcher(head, tail, parent): + match tail: + case [ASTNode(type='()') as containing_method_args, + ASTLeaf(TKind.ns_sep), + ASTNode(type='{}', + children=[ + ASTLeaf(TKind.identifier, "lambda"), + ASTNode(type='()') as lambda_sig, + ASTLeaf(TKind.hash), + ASTLeaf(TKind.identifier) + ]), + *_]: + return [ASTLeaf(TKind.identifier, "void")] + \ + [ASTNode('()', + head + [containing_method_args], + parent, + ASTLeaf(TKind.par_open, '('), + ASTLeaf(TKind.par_close, ')'))] + \ + [lambda_sig] + tail[3:] + + return None + + return match_and_modify_children(node, _child_seq_matcher) + + def _remove_artifacts(node: ASTEntry): + def _child_seq_matcher(head, tail, _): + match tail: + case [ASTLeaf(TKind.ns_sep), ASTLeaf(TKind.ref | TKind.ptr), *rest]: + return head + rest + return None + + match node: + case ASTLeaf(TKind.identifier, spelling): + return ASTLeaf(TKind.identifier, re.sub(r"(_|const)$", "", spelling)) + case ASTNode('<>', []): + return None + case ASTNode(): + return match_and_modify_children(node, _child_seq_matcher) + case ASTLeaf(TKind.ref | TKind.ptr | TKind.ns_sep | TKind.unknown_symbol): + return None + return node + + def _replace_verbose_types(node: ASTEntry): + match node: + case ASTLeaf(TKind.identifier, "_Bool"): + return ASTLeaf(TKind.identifier, "bool") + + return node + + def _replace_lambda_enumerations(node: ASTEntry): + match node: + case ASTNode(children=[*_, ASTLeaf(TKind.identifier, idf)]) as node: + if re.fullmatch(r"\$_[0-9]+", idf): + node.children = node.children[:-1] + [ASTLeaf(TKind.identifier, "$lambda")] + + return node + + def _remove_return_types(node: ASTEntry): + match node: + case ASTNode("ast", [ASTLeaf(TKind.identifier), qualified_name, ASTNode('()') as params]) as node: + match qualified_name: + case ASTNode('()', name_unwrapped): + qualified_name = name_unwrapped + case _: + qualified_name = [qualified_name] + node.children = qualified_name + [params] + + return node + + ast = traverse(ast, _remove_qualifiers) + ast = traverse(ast, _remove_std_wrappers) + ast = traverse(ast, _remove_std_bind) + ast = traverse(ast, _unwrap_lambda) + ast = traverse(ast, _remove_artifacts) + ast = traverse(ast, _replace_verbose_types) + ast = traverse(ast, _replace_lambda_enumerations) + #ast = _remove_return_types(ast) + return ast + + +def traverse(node: ASTEntry, action) -> ASTEntry | None: + match node: + case ASTNode(): + children = [] + for c in node.children: + c = traverse(c, action) + match c: + case list(): + children += c + case None: + pass + case _: + children.append(c) + + node.children = children + return action(node) + + +def build_ast(sig: str): + tokens = tokenize(sig) + + ast = ASTNode("ast", [], None) + parens_stack = [] + current_node = ast + for token in tokens: + match token.kind: + case TKind.ang_open | TKind.curl_open | TKind.brack_open | TKind.par_open: + parens_stack.append(token.kind) + brack_content_ast_node = ASTNode(f"{token.spelling}{BRACK_SPELLING_MAP[token.spelling]}", + [], + current_node, + start=token, + end=ASTLeaf(BRACK_MAP[token.kind], BRACK_SPELLING_MAP[token.spelling])) + current_node.children.append(brack_content_ast_node) + current_node = brack_content_ast_node + case TKind.ang_close | TKind.curl_close | TKind.brack_close | TKind.par_close: + if not parens_stack or BRACK_MAP.inv[token.kind] != parens_stack[-1]: + expect_str = parens_stack[-1] if parens_stack else "nothing" + raise ValueError( + f"Invalid brackets: encountered {token.spelling} when expecting {expect_str} in '{sig}'") + parens_stack.pop() + current_node = current_node.parent + case TKind.whitespace: + continue + case _: + current_node.children.append(token) + + if parens_stack: + raise ValueError(f"Token stream finished but unclosed brackets remain: {parens_stack} in '{sig}'") + + return ast + + +def tokenize(sig: str) -> List[ASTLeaf]: + token_matchers = [t.value for t in TKind] + tokens = list(re.finditer('|'.join(token_matchers), sig)) + + prev_end = 0 + for t in tokens: + t_start, t_end = t.span() + if t_start != prev_end: + raise ValueError(f"Tokenizer failed at char {t_start}: '{sig}'") + prev_end = t_end + + if prev_end != len(sig): + raise ValueError(f"Tokenization not exhaustive for: '{sig}'") + + tokens = [tuple(next(filter(lambda pair: pair[-1] is not None, t.groupdict().items()))) for t in tokens] + tokens = [ASTLeaf(TKind.__members__[k], v) for k, v in tokens] + return tokens + + +if __name__ == "__main__": + with open("../cache/cl_objects_7b616c9c48.pkl", "rb") as f: + print("Loading Clang Objects... ", end='') + cl: ClContext = pickle.load(f) + print("Done.") + + with open("../cache/tr_objects_c1e0d50b8d.pkl", "rb") as f: + print("Loading Tracing Objects... ", end='') + tr: TrContext = pickle.load(f) + print("Done.") + + matches, _, _ = match(tr, cl) + cl_deps_to_tr_deps(matches, tr, cl) diff --git a/misc/utils.py b/misc/utils.py index 2b68226..34eb34e 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -50,10 +50,22 @@ def cached(name, function, file_deps: List[str]): dep_time = 0.0 for file in file_deps: + # Get modified time of the current dependency m_time = os.path.getmtime(file) if os.path.exists(file) else 0. + + # Update dependency time to be the newest modified time of any dependency if m_time > dep_time: dep_time = m_time + # Check directories recursively to get the newest modified time + for root, dirs, files in os.walk(file): + for f in files + dirs: + filename = os.path.join(root, f) + m_time = os.path.getmtime(filename) + + if m_time > dep_time: + dep_time = m_time + deps_hash = stable_hash(sorted(file_deps)) pkl_filename = f"cache/{name}_{deps_hash}.pkl" diff --git a/requirements.txt b/requirements.txt index 7343431..76e0bb5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ matplotlib pyvis graphviz ruamel.yaml +fuzzywuzzy diff --git a/tracing_interop/types.py b/tracing_interop/tr_types.py similarity index 84% rename from tracing_interop/types.py rename to tracing_interop/tr_types.py index 88f306d..e14d3d6 100644 --- a/tracing_interop/types.py +++ b/tracing_interop/tr_types.py @@ -1,8 +1,9 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import cached_property from typing import List, Dict import pandas as pd +from tqdm.notebook import tqdm from tracetools_analysis.processor.ros2 import Ros2Handler from tracetools_analysis.utils.ros2 import Ros2DataModelUtil @@ -62,19 +63,26 @@ class TrContext: print("[TrContext] Caching dynamic properties...") - [(o.path, o.publishers, o.subscriptions, o.timers) for o in self.nodes.values()] + p = tqdm(desc=" ├─ Processing nodes", total=len(self.nodes.values())) + [(o.path, o.publishers, o.subscriptions, o.timers, p.update()) for o in self.nodes.values()] print(" ├─ Cached node properties") - [(o.instances, o.subscriptions) for o in self.publishers.values()] + p = tqdm(desc=" ├─ Processing publishers", total=len(self.publishers.values())) + [(o.instances, o.subscriptions, p.update()) for o in self.publishers.values()] print(" ├─ Cached publisher properties") - [(o.publishers, o.subscription_objects) for o in self.subscriptions.values()] + p = tqdm(desc=" ├─ Processing subscriptions", total=len(self.subscriptions.values())) + [(o.publishers, o.subscription_objects, p.update()) for o in self.subscriptions.values()] print(" ├─ Cached subscription properties") - [(o.nodes) for o in self.timers.values()] + p = tqdm(desc=" ├─ Processing timers", total=len(self.timers.values())) + [(o.nodes, p.update()) for o in self.timers.values()] print(" ├─ Cached timer properties") - [(o.callback_instances, o.owner, o.owner_info) for o in self.callback_objects.values()] + p = tqdm(desc=" ├─ Processing CB objects", total=len(self.callback_objects.values())) + [(o.callback_instances, o.owner, p.update()) for o in self.callback_objects.values()] print(" ├─ Cached callback object properties") - [(o.callback_objs) for o in self.callback_symbols.values()] + p = tqdm(desc=" ├─ Processing CB symbols", total=len(self.callback_symbols.values())) + [(o.callback_objs, p.update()) for o in self.callback_symbols.values()] print(" ├─ Cached callback symbol properties") - [(o.publishers, o.subscriptions) for o in self.topics.values()] + p = tqdm(desc=" ├─ Processing topics", total=len(self.topics.values())) + [(o.publishers, o.subscriptions, p.update()) for o in self.topics.values()] print(" └─ Cached topic properties\n") def __getstate__(self): @@ -88,6 +96,9 @@ class TrContext: self.util = None self.handler = None + def __repr__(self): + return f"TrContext" + @dataclass class TrNode: @@ -97,11 +108,11 @@ class TrNode: rmw_handle: int name: str namespace: str - _c: TrContext + _c: TrContext = field(repr=False) @cached_property def path(self) -> str: - return '/'.join((self.namespace, self.name)) + return '/'.join((self.namespace, self.name)).replace('//', '/') @cached_property def publishers(self) -> List['TrPublisher']: @@ -128,7 +139,7 @@ class TrPublisher: rmw_handle: int topic_name: str depth: int - _c: TrContext + _c: TrContext = field(repr=False) @property def node(self) -> 'TrNode': @@ -158,7 +169,7 @@ class TrSubscription: rmw_handle: int topic_name: str depth: int - _c: TrContext + _c: TrContext = field(repr=False) @property def node(self) -> 'TrNode': @@ -187,7 +198,7 @@ class TrTimer: timestamp: int period: int tid: int - _c: TrContext + _c: TrContext = field(repr=False) @cached_property def nodes(self) -> List['TrNode']: @@ -214,7 +225,7 @@ class TrSubscriptionObject: id: int # subscription timestamp: int subscription_handle: int - _c: TrContext + _c: TrContext = field(repr=False) @property def subscription(self) -> 'TrSubscription': @@ -233,7 +244,7 @@ class TrCallbackObject: id: int # (reference) = subscription_object.id | timer.id | .... timestamp: int callback_object: int - _c: TrContext + _c: TrContext = field(repr=False) @cached_property def callback_instances(self) -> List['TrCallbackInstance']: @@ -257,17 +268,6 @@ class TrCallbackObject: return 'Client' return None - @cached_property - def owner_info(self): - info = self._c.util.get_callback_owner_info(self.callback_object) - if info is None: - return None, None - - type_name, dict_str = info.split(" -- ") - kv_strs = dict_str.split(", ") - info_dict = {k: v for k, v in map(lambda kv_str: kv_str.split(": ", maxsplit=1), kv_strs)} - return type_name, info_dict - def __hash__(self): return hash((self.id, self.timestamp, self.callback_object)) @@ -277,7 +277,7 @@ class TrPublishInstance: publisher_handle: int timestamp: int message: int - _c: TrContext + _c: TrContext = field(repr=False) @property def publisher(self) -> 'TrPublisher': @@ -293,7 +293,7 @@ class TrCallbackInstance: timestamp: pd.Timestamp duration: pd.Timedelta intra_process: bool - _c: TrContext + _c: TrContext = field(repr=False) @property def callback_obj(self) -> 'TrCallbackObject': @@ -308,7 +308,7 @@ class TrCallbackSymbol: id: int # callback_object timestamp: int symbol: str - _c: TrContext + _c: TrContext = field(repr=False) @cached_property def callback_objs(self) -> List['TrCallbackObject']: @@ -325,7 +325,7 @@ class TrCallbackSymbol: @dataclass class TrTopic: name: str - _c: TrContext + _c: TrContext = field(repr=False) @cached_property def publishers(self) -> List['TrPublisher']: diff --git a/tracing_interop/utils.py b/tracing_interop/utils.py index 05b2329..7c91a4b 100644 --- a/tracing_interop/utils.py +++ b/tracing_interop/utils.py @@ -1,15 +1,25 @@ import sys import pandas as pd +from tqdm.notebook import tqdm -def row_to_type(row, type, has_idx, **type_kwargs): - return type(id=row.name, **row, **type_kwargs) if has_idx else type(**row, **type_kwargs) +def row_to_type(row, type, **type_kwargs): + return type(**row, **type_kwargs) def df_to_type_list(df, type, **type_kwargs): has_idx = not isinstance(df.index, pd.RangeIndex) - return [row_to_type(row, type, has_idx, **type_kwargs) for _, row in df.iterrows()] + ret_list = [] + p = tqdm(desc=" ├─ Processing", total=len(df)) + for row in df.itertuples(index=has_idx): + p.update() + row_dict = row._asdict() + if has_idx: + row_dict["id"] = row.Index + del row_dict["Index"] + ret_list.append(row_to_type(row_dict, type, **type_kwargs)) + return ret_list def by_index(df, index, type):