# dataflow-analysis/matching/subscriptions.py
# Maximilian Schmeller 4ddc69562f More documentation
# 2022-12-29 15:09:59 +09:00
#
# 567 lines
# 20 KiB
# Python

import pickle
import re
import sys
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Iterable, Set, Tuple
from bidict import bidict
from termcolor import colored
sys.path.append("../../autoware/build/tracetools_read/")
sys.path.append("../../autoware/build/tracetools_analysis/")
from clang_interop.cl_types import ClMethod, ClContext, ClSubscription
from tracing_interop.tr_types import TrContext, TrSubscriptionObject, TrSubscription, TrCallbackSymbol, TrTimer
class TKind(Enum):
    """
    Token kinds recognized when tokenizing a (pseudo-)C++ signature.

    The value of every member is a regular expression consisting of a single named group whose name equals
    the member's name. `tokenize()` joins all values with `|` in declaration order, so the order of the
    members below is significant: for a given position, the first alternative that matches wins.
    `unknown_token` is declared last and matches any single character (except a newline) that no earlier
    alternative accepted.
    """
    # Optionally namespace-qualified names, (signed) integer literals and quoted literals
    # language=PythonRegExp
    identifier = r"(?P<identifier>(?:[\w$-]+::)*[\w$-]+|[+-]?[0-9]+|\".*?\"|'.*?')"
    # language=PythonRegExp
    ang_open = r"(?P<ang_open><)"
    # language=PythonRegExp
    ang_close = r"(?P<ang_close>>)"
    # language=PythonRegExp
    par_open = r"(?P<par_open>\()"
    # language=PythonRegExp
    par_close = r"(?P<par_close>\))"
    # language=PythonRegExp
    hash = r"(?P<hash>#)"
    # language=PythonRegExp
    curl_open = r"(?P<curl_open>\{)"
    # language=PythonRegExp
    curl_close = r"(?P<curl_close>})"
    # language=PythonRegExp
    brack_open = r"(?P<brack_open>\[)"
    # language=PythonRegExp
    brack_close = r"(?P<brack_close>])"
    # language=PythonRegExp
    whitespace = r"(?P<whitespace>\s+)"
    # language=PythonRegExp
    ref = r"(?P<ref>&)"
    # language=PythonRegExp
    ptr = r"(?P<ptr>\*)"
    # language=PythonRegExp
    comma = r"(?P<comma>,)"
    # A `::` that is not part of an identifier matched above
    # language=PythonRegExp
    ns_sep = r"(?P<ns_sep>::)"
    # Catch-all for any single character not matched by an earlier alternative
    # language=PythonRegExp
    unknown_token = r"(?P<unknown_token>.)"

    def __repr__(self):
        # Print only the member name (e.g. `ang_open`) instead of the default `<TKind.ang_open: ...>`
        return self.name
class ASTEntry:
    """
    Common base type of all AST elements (`ASTLeaf` and `ASTNode`).
    """

    def get_token_stream(self):
        """
        Placeholder that subclasses override to return their tokens as a flat list.
        The base implementation does nothing and returns `None`.
        """
        return None
@dataclass
class ASTLeaf(ASTEntry):
    """
    A single token of the signature: its kind and the exact text it was spelled with.
    """
    kind: TKind
    spelling: str

    def get_token_stream(self):
        """
        :return: A one-element list containing this leaf itself.
        """
        return [self]

    def __repr__(self):
        # Identifiers are shown in double quotes, every other token is shown verbatim
        if self.kind == TKind.identifier:
            return f'"{self.spelling}"'
        return self.spelling
@dataclass
class ASTNode(ASTEntry):
    """
    An inner node of the AST: an ordered list of children, optionally enclosed by a `start` and an `end` token
    (e.g. the opening and closing bracket of a bracketed expression).

    `type` is a free-form label for the node, `parent` is the back-reference to the enclosing node
    (`None` for a root).

    `parent` is excluded from the generated `__eq__`: a child's parent contains the child again, so comparing
    two distinct but structurally equal trees field-by-field would otherwise recurse
    parent -> children -> parent -> ... until a `RecursionError` is raised.
    """
    type: str
    children: List[ASTEntry] = field(default_factory=list)
    parent: Optional['ASTNode'] = field(default=None, compare=False)
    start: ASTLeaf | None = None
    end: ASTLeaf | None = None

    def get_token_stream(self) -> List[ASTLeaf]:
        """
        Flatten this subtree into its tokens, in order.
        :return: `start` (if set), followed by the token streams of all children, followed by `end` (if set)
        """
        stream = [self.start] if self.start else []
        for child in self.children:
            stream.extend(child.get_token_stream())
        if self.end:
            stream.append(self.end)
        return stream

    def __repr__(self):
        """
        Render the subtree as a string by concatenating the spellings of all tokens.
        A single space is inserted only between two directly adjacent identifiers, since those would
        otherwise merge into one word.
        """
        parts = []
        last_tkind = None
        for token in self.get_token_stream():
            if token.kind == TKind.identifier and last_tkind == TKind.identifier:
                parts.append(" ")
            parts.append(token.spelling)
            last_tkind = token.kind
        return "".join(parts)
# Bidirectional mapping between the token kinds of opening brackets (keys) and their closing counterparts
# (values). The reverse direction (closing -> opening) is available through `BRACK_MAP.inv`.
BRACK_MAP = bidict({
    TKind.curl_open: TKind.curl_close,
    TKind.par_open: TKind.par_close,
    TKind.brack_open: TKind.brack_close,
    TKind.ang_open: TKind.ang_close
})
# Same mapping as `BRACK_MAP`, but between the spellings of the brackets instead of their token kinds.
BRACK_SPELLING_MAP = bidict({
    '{': '}',
    '(': ')',
    '[': ']',
    '<': '>'
})
# Substrings of tracing callback symbols that are ignored by `match()`: a symbol is excluded as soon as it
# contains any one of these entries.
TR_BLACKLIST = [
    "tf2_ros::TransformListener",
    "rtc_auto_approver::RTCAutoApproverInterface",
    "rclcpp::TimeSource",
    "rclcpp::ParameterService",
    "rclcpp_components::ComponentManager",
    "rosbag2",
    "std_srvs::srv"
]
def cl_deps_to_tr_deps(matches: Set[Tuple], tr: TrContext, cl: ClContext):
    """
    Translate the callback dependencies found by Clang (`cl.dependencies`) into dependencies between tracing
    callback symbols, using the cl<->tr `matches` returned by `match()`.

    The `match()` function returns an n-to-m mapping between cl and tr symbols. For every cl callback, all tr
    candidates for the callback and for its dependencies are collected, and only those dependencies are kept
    whose owning node is the same as the owning node of the tr callback candidate.

    :param matches: Tuples whose first two elements are `(cl_obj, tr_obj)`; any further elements are ignored.
    :param tr: The tracing context. Currently not read by this function.
    :param cl: The Clang context, providing `dependencies` (callback id -> dependency ids) and `methods`.
    :return: A dict mapping each viable tr callback symbol to the set of tr callback symbols it depends on.
    """

    def owner_node(sym: TrCallbackSymbol):
        """Return the node owning `sym`'s callback object, or `None` if its owner is neither a subscription
        object nor a timer."""
        # NOTE(review): this reads `callback_obj` (singular) while `match()` reads `callback_objs` (plural)
        # on the same kind of object -- confirm that both attributes exist on `TrCallbackSymbol`.
        owner = sym.callback_obj.owner
        if isinstance(owner, TrSubscriptionObject):
            return owner.subscription.node
        if isinstance(owner, TrTimer):
            return owner.node
        return None

    final_tr_deps = dict()

    for cl_cb_id, cl_dep_ids in cl.dependencies.items():
        # De-duplicate the ids so that their count is comparable with the number of methods found below
        cl_dep_ids = set(cl_dep_ids)

        ##################################################
        # 1. For all cl dependencies, build all possible
        #    tr dependencies with the matches we have.
        ##################################################

        cl_cb = next((m for m in cl.methods if m.id == cl_cb_id), None)
        cl_deps = {m for m in cl.methods if m.id in cl_dep_ids}

        # Report callbacks whose own method or any of whose dependency methods are unknown to `cl.methods`.
        # (The number of resolved methods has to be compared with the number of requested ids.)
        if cl_cb is None or len(cl_deps) < len(cl_dep_ids):
            print(colored("[ERROR][CL] Callback has not all CL methods defined", "red"))
        if cl_cb is None:
            continue  # for cl.dependencies.items()

        # Because the mapping is n-to-m, we have tr_cbs as a set, instead of a single tr_cb
        tr_cbs = set(tr_obj for cl_obj, tr_obj, *_ in matches if cl_obj == cl_cb)
        tr_deps = set(tr_obj for cl_obj, tr_obj, *_ in matches if cl_obj in cl_deps)

        ##################################################
        # 2. Filter out all combinations where
        #    dependencies leave a node.
        ##################################################

        viable_matchings = {}
        for tr_cb in tr_cbs:
            owner = owner_node(tr_cb)
            if owner is None:
                continue  # for tr_cbs

            valid_deps = set(dep for dep in tr_deps if owner_node(dep) == owner)
            if not valid_deps:
                continue  # for tr_cbs

            viable_matchings[tr_cb] = valid_deps

        if not viable_matchings:
            print(colored(f"[ERROR][CL] Callback has not all TR equivalents for CL: {cl_cb.signature}", "red"))
            continue  # for cl.dependencies.items()

        ##################################################
        # 3. Select the matching with the highest number
        #    of mapped dependencies (= the smallest number
        #    of unmapped cl deps)
        ##################################################

        # NOTE: no selection is performed yet. The number of viable matchings and their sizes are printed
        # and ALL viable matchings are added to the result.
        print(len(viable_matchings), ', '.join(map(str, map(len, viable_matchings.values()))))
        final_tr_deps.update(viable_matchings)

    return final_tr_deps
def match(tr: TrContext, cl: ClContext):
def _is_excluded(symbol: str):
return any(item in symbol for item in TR_BLACKLIST)
cl_methods = [cb for cb in cl.methods
if any(sub.callback_id == cb.id for sub in cl.subscriptions)
or any(tmr.callback_id == cb.id for tmr in cl.timers)]
tr_callbacks = [(sym.symbol, sym) for sym in tr.callback_symbols.values() if not _is_excluded(sym.symbol)]
cl_callbacks = [(cb.signature, cb) for cb in cl_methods]
tr_callbacks = [(repr(sanitize(k)), v) for k, v in tr_callbacks]
cl_callbacks = [(repr(sanitize(k)), v) for k, v in cl_callbacks]
matches_sig = set()
tr_matched = set()
cl_matched = set()
for cl_sig, cl_obj in cl_callbacks:
matches = set(tr_obj for tr_sig, tr_obj in tr_callbacks if tr_sig == cl_sig)
tr_matched |= matches
if matches:
cl_matched.add(cl_obj)
for tr_obj in matches:
matches_sig.add((cl_obj, tr_obj, cl_sig))
matches_topic = set()
for _, cl_obj in cl_callbacks:
# Get subscription of the callback (if any)
cl_sub: ClSubscription | None = next((sub for sub in cl.subscriptions if sub.callback_id == cl_obj.id), None)
if not cl_sub:
continue
cl_topic = re.sub(r"~/(input/)?", "", cl_sub.topic)
matches = set()
for _, tr_obj in tr_callbacks:
tr_cb = tr_obj.callback_objs[0] if len(tr_obj.callback_objs) == 1 else None
if not tr_cb:
continue
match tr_cb.owner:
case TrSubscriptionObject(subscription=tr_sub):
tr_sub: TrSubscription
tr_topic = tr_sub.topic_names # TODO: str->list type refactoring
if not tr_topic:
continue
case _:
continue
if tr_topic.endswith(cl_topic):
matches_topic.add((cl_obj, tr_obj, cl_topic, tr_topic))
matches.add(tr_obj)
tr_matched |= matches
if matches:
cl_matched.add(cl_obj)
all_matches = matches_sig | matches_topic
def count_dup(matches):
cl_dup = 0
tr_dup = 0
for (cl_obj, tr_obj, *_) in matches:
n_cl_dups = len([cl2 for cl2, *_ in matches if cl2 == cl_obj])
if n_cl_dups > 1:
cl_dup += 1 / n_cl_dups
n_tr_dups = len([tr2 for _, tr2, *_ in matches if tr2 == tr_obj])
if n_tr_dups > 1:
tr_dup += 1 / n_tr_dups
print(int(cl_dup), int(tr_dup))
count_dup(all_matches)
tr_unmatched = set(tr_obj for _, tr_obj in tr_callbacks) - tr_matched
cl_unmatched = set(cl_obj for _, cl_obj in cl_callbacks) - cl_matched
return all_matches, tr_unmatched, cl_unmatched
def match_and_modify_children(node: ASTEntry, match_func):
    """
    If `node` is an `ASTNode`, try to apply `match_func` for every possible
    tuple (children[:i], children[i:], node).
    Whenever `match_func` returns a non-None value, `node.children` is overwritten with that result and the
    following split positions operate on the new children.
    The number of split positions is determined once, from the number of children before the first call.
    This function is not recursive but `match_func` can be.
    Return `node` in all cases.

    :param node: The node of which the children are applied `match_func` to
    :param match_func: The match/modify function. Has to return either `None` or a list of `ASTEntry` and has to accept arguments `seq_head`, `seq_tail`, `node`.
    :return: `node`
    """
    if isinstance(node, ASTNode):
        n_split_points = len(node.children)
        for split_at in range(n_split_points):
            current_children = node.children
            replacement = match_func(current_children[:split_at], current_children[split_at:], node)
            if replacement is None:
                continue
            node.children = replacement
    return node
def sanitize(sig: str) -> ASTNode | None:
"""
Tokenizes and parses `sig` into pseudo-C++, then applies a set of rules on the resulting AST to filter unwanted boilerplate expressions.
Use `repr(sanitize(sig))` to get the sanitized `sig` as a string.
:param sig: A function signature, e.g. from tracing or Clang data.
:return: The sanitized AST.
"""
try:
ast = build_ast(sig)
except Exception as e:
print(f"[ERROR] Could not build AST for {sig}, returning it as-is. {e}")
return None
def _remove_qualifiers(node: ASTEntry):
match node:
case ASTLeaf(TKind.identifier, 'class' | 'struct' | 'const'):
return None
return node
def _remove_std_wrappers(node: ASTEntry):
def _child_seq_matcher(head, tail, _):
match tail:
case [ASTLeaf(TKind.identifier, "std::allocator"), ASTNode('<>'), *rest]:
return head + rest
case [ASTLeaf(TKind.identifier, "std::shared_ptr"), ASTNode('<>', ptr_type), *rest]:
return head + ptr_type + rest
return None
return match_and_modify_children(node, _child_seq_matcher)
def _remove_std_bind(node: ASTEntry):
def _child_seq_matcher(head, tail, parent):
match tail:
case [ASTLeaf(TKind.identifier, "std::_Bind"),
ASTNode(type='<>', children=[
callee_ret,
ASTNode('()', children=[*callee_ptr, ASTNode('()', bind_args)]),
ASTNode('()') as replacement_args])]:
return [callee_ret] + head + [ASTNode('()', callee_ptr, parent,
ASTLeaf(TKind.par_open, '('),
ASTLeaf(TKind.par_close, ')')),
replacement_args]
return None
return match_and_modify_children(node, _child_seq_matcher)
def _unwrap_lambda(node: ASTEntry):
def _child_seq_matcher(head, tail, parent):
match tail:
case [ASTNode(type='()') as containing_method_args,
ASTLeaf(TKind.ns_sep),
ASTNode(type='{}',
children=[
ASTLeaf(TKind.identifier, "lambda"),
ASTNode(type='()') as lambda_sig,
ASTLeaf(TKind.hash),
ASTLeaf(TKind.identifier)
]),
*_]:
return [ASTLeaf(TKind.identifier, "void")] + \
[ASTNode('()',
head + [containing_method_args],
parent,
ASTLeaf(TKind.par_open, '('),
ASTLeaf(TKind.par_close, ')'))] + \
[lambda_sig] + tail[3:]
return None
return match_and_modify_children(node, _child_seq_matcher)
def _remove_artifacts(node: ASTEntry):
def _child_seq_matcher(head, tail, _):
match tail:
case [ASTLeaf(TKind.ns_sep), ASTLeaf(TKind.ref | TKind.ptr), *rest]:
return head + rest
return None
match node:
case ASTLeaf(TKind.identifier, spelling):
return ASTLeaf(TKind.identifier, re.sub(r"(_|const)$", "", spelling))
case ASTNode('<>', []):
return None
case ASTNode():
return match_and_modify_children(node, _child_seq_matcher)
case ASTLeaf(TKind.ref | TKind.ptr | TKind.ns_sep | TKind.unknown_token):
return None
return node
def _replace_verbose_types(node: ASTEntry):
match node:
case ASTLeaf(TKind.identifier, "_Bool"):
return ASTLeaf(TKind.identifier, "bool")
return node
def _replace_lambda_enumerations(node: ASTEntry):
match node:
case ASTNode(children=[*_, ASTLeaf(TKind.identifier, idf)]) as node:
if re.fullmatch(r"\$_[0-9]+", idf):
node.children = node.children[:-1] + [ASTLeaf(TKind.identifier, "$lambda")]
return node
def _remove_return_types(node: ASTEntry):
match node:
case ASTNode("ast", [ASTLeaf(TKind.identifier), qualified_name, ASTNode('()') as params]) as node:
match qualified_name:
case ASTNode('()', name_unwrapped):
qualified_name = name_unwrapped
case _:
qualified_name = [qualified_name]
node.children = qualified_name + [params]
return node
sanitization_actions = [
_remove_qualifiers,
_remove_std_wrappers,
_remove_std_bind,
_unwrap_lambda,
_remove_artifacts,
_replace_verbose_types,
_replace_lambda_enumerations
]
for action in sanitization_actions:
try:
ast = traverse(ast, action)
except Exception as e:
print(f"[WARN] Could not apply {action.__name__} sanitization, skipped: {e}")
return ast
def traverse(node: ASTEntry, action) -> ASTEntry | None:
    """
    Apply `action` to every entry in the subtree rooted at `node`, children first and `node` itself last.

    For an `ASTNode`, the children are replaced by the results of their own traversal before `action` is
    called on the node: a `None` result removes the child, a list result is spliced into the children in
    place of the child, and any other result replaces the child.

    :param node: The `ASTEntry` to start traversal at
    :param action: A transformation taking one `ASTEntry` argument and returning an `ASTEntry`, a list of entries or `None`.
    :return: The result of `action(node)`, i.e. the transformed AST or `None`
    """
    if isinstance(node, ASTNode):
        rebuilt = []
        for child in node.children:
            outcome = traverse(child, action)
            if isinstance(outcome, list):
                rebuilt.extend(outcome)
            elif outcome is not None:
                rebuilt.append(outcome)
        node.children = rebuilt
    return action(node)
def build_ast(sig: str) -> ASTNode:
    """
    Tokenize `sig` and arrange the tokens into a tree in which the content of every bracket pair
    (`<>`, `{}`, `[]`, `()`) forms its own `ASTNode`. Whitespace tokens are discarded.

    :param sig: The signature to parse
    :return: The root node (of type `"ast"`) of the resulting tree
    :raises ValueError: If a closing bracket does not match the innermost open one, or if brackets are left open
    """
    ast = ASTNode("ast", [], None)

    # Stack of the kinds of all currently open brackets, innermost last. Holds opening kinds only.
    parens_stack = []
    # The node that receives the tokens read next; starts out as the root
    current_node = ast

    for token in tokenize(sig):
        kind = token.kind

        if kind == TKind.whitespace:
            # Whitespace carries no information for the tree
            continue

        if kind in BRACK_MAP:
            # Opening bracket: remember it and descend into a new node that holds the bracket's content
            closing_spelling = BRACK_SPELLING_MAP[token.spelling]
            parens_stack.append(kind)
            inner_node = ASTNode(token.spelling + closing_spelling,
                                 [],
                                 current_node,
                                 start=token,
                                 end=ASTLeaf(BRACK_MAP[kind], closing_spelling))
            current_node.children.append(inner_node)
            current_node = inner_node
        elif kind in BRACK_MAP.inv:
            # Closing bracket: it has to close the innermost open bracket, then work continues on the parent
            if not parens_stack or BRACK_MAP.inv[kind] != parens_stack[-1]:
                expect_str = parens_stack[-1] if parens_stack else "nothing"
                raise ValueError(
                    f"Invalid brackets: encountered {token.spelling} when expecting {expect_str} in '{sig}'")
            parens_stack.pop()
            current_node = current_node.parent
        else:
            # Every other token simply becomes a child of the node currently worked on
            current_node.children.append(token)

    # All brackets have to be closed once the tokens are exhausted
    if parens_stack:
        raise ValueError(f"Token stream finished but unclosed brackets remain: {parens_stack} in '{sig}'")

    return ast
def tokenize(sig: str) -> List[ASTLeaf]:
    """
    Split `sig` into tokens according to the rules defined in `TKind`.
    The patterns of all `TKind` members are combined into one alternation (in declaration order) and every
    match becomes one token. The matches have to cover `sig` completely and without gaps.

    :param sig: The signature to tokenize
    :return: The tokenization, as a list of `ASTLeaf`
    :raises ValueError: If the matches leave a gap or do not reach the end of `sig`
    """
    combined_pattern = '|'.join(kind.value for kind in TKind)
    found = list(re.finditer(combined_pattern, sig))

    # Every match has to begin exactly where the previous one ended
    prev_end = 0
    for found_match in found:
        t_start, t_end = found_match.span()
        if t_start != prev_end:
            raise ValueError(f"Tokenizer failed at char {t_start}: '{sig}'")
        prev_end = t_end

    if prev_end != len(sig):
        raise ValueError(f"Tokenization not exhaustive for: '{sig}'")

    leaves = []
    for found_match in found:
        # The first named group that took part in the match tells which `TKind` was recognized
        kind_name, spelling = next((name, text)
                                   for name, text in found_match.groupdict().items()
                                   if text is not None)
        leaves.append(ASTLeaf(TKind[kind_name], spelling))
    return leaves
# Debug code to troubleshoot errors: loads cached Clang and tracing contexts and runs the matching on them.
# NOTE: `pickle.load` can execute arbitrary code contained in the file, only load cache files you trust.
if __name__ == "__main__":
    with open("../cache/cl_objects_7b616c9c48.pkl", "rb") as cl_cache_file:
        print("Loading Clang Objects... ", end='')
        cl: ClContext = pickle.load(cl_cache_file)
        print("Done.")

    with open("../cache/tr_objects_c1e0d50b8d.pkl", "rb") as tr_cache_file:
        print("Loading Tracing Objects... ", end='')
        tr: TrContext = pickle.load(tr_cache_file)
        print("Done.")

    # Only the matches are needed here, the unmatched tr/cl objects are discarded
    matches, _tr_unmatched, _cl_unmatched = match(tr, cl)
    cl_deps_to_tr_deps(matches, tr, cl)