Skip to content

Commit e840a3d

Browse files
Initial Wir to Dag Transformation for pandas (Closes stefan-grafberger#12, Closes stefan-grafberger#7) (PR stefan-grafberger#14)
* Works somewhat with sketchy revisiting, instrumented_calls can maybe get instrumented * Think I found the issue now * Capturing the module works now * works now but still some issues with duplicate instrumented_call insertions * Updated some comments * Module inspection now working as intended + cleanup * Starting to use networkx * networkx may be a real option * Converted more tests to networkx * Converted more tests to networkx * Completed switch to networkx * moved test for pipeline into pipelines dir * Simplified code * Code cleanup * Added module info to extracted WIR for calls. todo: subscript-index * Executor and Inspector entry point on string directly * Updated forgotten test * Refined capturing of function info * Subscript instrumentation works * Cleanup * trying to find a good solution for some instrumentation edge cases with nested function calls * found the issue. if ast call nodes in arg, they get recomputed. solution: instrumented after and before separately * Closer to intended instrumentation: before_call_used_value works * wip: full efficient call capturing * Getting closer * Efficient call capturing for call args works. TODO: subscript and kwargs * Call args capturing works for both subscripts and calls * Fixed minor subscript bug * kwargs capturing works * wip: code cleanup * Cleanup. Ast instrumentation finally works as intended * Add a few tests * More cleanup * Removed unused code * Minor cleanup * tests with spy to ensure instrumentation works * Moved annotating wir with module info to separate step * Prepared implementation of wir2dag transformer * Updated comment * wip: remove_all_nodes_but_calls_and_subscripts * wip: remove_all_nodes_but_calls_and_subscripts * getting closer * First draft. Only calls and subscripts survive, but some edges are missing. 
WIP: writing a test and debugging * Found bug, todo: polish test * test_remove_all_nodes_but_calls_and_subscripts works * Added module info to test_remove_all_nodes_but_calls_and_subscripts test * Potential minor performance optimisation (sets should be stored as a constant here) * wip: pandas dag extraction * wip: adding lineno and col_offset everywhere * wip: adding lineno and col_offset everywhere * Updated tests to include line numbers * added label binarize to list of known functions * Removed some unused code * Minor DagVertex __repr__ changes * PipelineInspector now returns extracted dag
1 parent 05edb1b commit e840a3d

19 files changed

+1181
-240
lines changed

mlinspect/instrumentation/call_capture_transformer.py

Lines changed: 123 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,131 @@ def visit_Call(self, node):
1414
"""
1515
Instrument all function calls
1616
"""
17-
# pylint: disable=no-self-use, invalid-name
17+
# pylint: disable=invalid-name
18+
ast.NodeTransformer.generic_visit(self, node)
19+
code = astunparse.unparse(node)
20+
21+
self.add_before_call_used_value_capturing_call(code, node)
22+
self.add_before_call_used_args_capturing_call(code, node)
23+
self.add_before_call_used_kwargs_capturing_call(code, node)
24+
instrumented_call_node = self.add_after_call_used_capturing(code, node, False)
25+
26+
return instrumented_call_node
1827

28+
def visit_Subscript(self, node):
29+
"""
30+
Instrument all subscript calls
31+
"""
32+
# pylint: disable=invalid-name
33+
ast.NodeTransformer.generic_visit(self, node)
1934
code = astunparse.unparse(node)
20-
args = ast.List(node.args, ctx=ast.Load())
21-
args_code = ast.List([ast.Str(astunparse.unparse(arg).split("\n", 1)[0]) for arg in node.args],
22-
ctx=ast.Load())
2335

24-
instrumented_call_node = ast.Call(func=ast.Name(id='instrumented_call_used', ctx=ast.Load()),
25-
args=[args, args_code, node, ast.Str(s=code)], keywords=[])
26-
ast.copy_location(instrumented_call_node, node)
36+
self.add_before_call_used_value_capturing_subscript(code, node)
37+
self.add_before_call_used_args_capturing_subscript(code, node)
38+
instrumented_call_node = self.add_after_call_used_capturing(code, node, True)
2739

28-
# TODO: warn if unrecognized function call, expressions
40+
return instrumented_call_node
41+
42+
@staticmethod
43+
def add_before_call_used_value_capturing_call(code, node):
    """
    When the method of some object is called, capture the value of the object before executing the method

    The receiver expression ``node.func.value`` is replaced in place by a call
    ``before_call_used_value(False, code, value_code, <receiver>, lineno, col_offset)``.

    :param code: source text of the whole call expression; embedded as a string constant
    :param node: the ``ast.Call`` node to modify in place
    :return: None. Calls whose ``func`` has no ``value`` attribute (e.g. a plain name) are left untouched.
    """
    if not hasattr(node.func, "value"):
        return

    old_value_node = node.func.value
    value_code = astunparse.unparse(old_value_node)
    # ``value`` is the actual field of ast.Constant; the ``n`` alias is deprecated and no longer
    # populates the field on current Python versions.
    new_value_node = ast.Call(func=ast.Name(id='before_call_used_value', ctx=ast.Load()),
                              args=[ast.Constant(value=False, kind=None),  # False: not a subscript
                                    ast.Constant(value=code, kind=None),
                                    ast.Constant(value=value_code, kind=None),
                                    old_value_node,
                                    ast.Constant(value=node.lineno, kind=None),
                                    ast.Constant(value=node.col_offset, kind=None)],
                              keywords=[])
    node.func.value = new_value_node
59+
60+
@staticmethod
61+
def add_before_call_used_value_capturing_subscript(code, node):
    """
    When the __getitem__ method of some object is called, capture the value of the object before executing the
    method

    The subscripted expression ``node.value`` is replaced in place by a call
    ``before_call_used_value(True, code, value_code, <value>, lineno, col_offset)``.

    :param code: source text of the whole subscript expression; embedded as a string constant
    :param node: the ``ast.Subscript`` node to modify in place
    :return: None
    """
    old_value_node = node.value
    value_code = astunparse.unparse(old_value_node)
    # ``value`` is the actual field of ast.Constant; the ``n`` alias is deprecated and no longer
    # populates the field on current Python versions.
    new_value_node = ast.Call(func=ast.Name(id='before_call_used_value', ctx=ast.Load()),
                              args=[ast.Constant(value=True, kind=None),  # True: this is a subscript
                                    ast.Constant(value=code, kind=None),
                                    ast.Constant(value=value_code, kind=None),
                                    old_value_node,
                                    ast.Constant(value=node.lineno, kind=None),
                                    ast.Constant(value=node.col_offset, kind=None)],
                              keywords=[])
    node.value = new_value_node
77+
78+
@staticmethod
79+
def add_before_call_used_args_capturing_call(code, node):
    """
    When a method is called, capture the arguments of the method before executing it

    All positional arguments are replaced in place by a single starred call
    ``*before_call_used_args(False, code, [<arg source>, ...], lineno, col_offset, [<arg>, ...])``.

    :param code: source text of the whole call expression; embedded as a string constant
    :param node: the ``ast.Call`` node to modify in place
    :return: None
    """
    old_args_nodes_ast = ast.List(elts=node.args, ctx=ast.Load())
    # Only the first line of each unparsed argument is kept as its source text.
    old_args_code = ast.List(elts=[ast.Constant(value=astunparse.unparse(arg).split("\n", 1)[0], kind=None)
                                   for arg in node.args],
                             ctx=ast.Load())
    # ``value`` is the actual field of ast.Constant; the ``n`` alias is deprecated and no longer
    # populates the field on current Python versions.
    capture_call = ast.Call(func=ast.Name(id='before_call_used_args', ctx=ast.Load()),
                            args=[ast.Constant(value=False, kind=None),  # False: not a subscript
                                  ast.Constant(value=code, kind=None),
                                  old_args_code,
                                  ast.Constant(value=node.lineno, kind=None),
                                  ast.Constant(value=node.col_offset, kind=None),
                                  old_args_nodes_ast],
                            keywords=[])
    # The starred form unpacks whatever the capture call returns back into positional arguments.
    node.args = [ast.Starred(value=capture_call, ctx=ast.Load())]
95+
96+
@staticmethod
97+
def add_before_call_used_args_capturing_subscript(code, node):
    """
    When the __getitem__ method of some object is called, capture the arguments of the method before executing it

    The index expression is replaced in place by a call
    ``before_call_used_args(True, code, [<index source>], lineno, col_offset, [<index>])``.

    :param code: source text of the whole subscript expression; embedded as a string constant
    :param node: the ``ast.Subscript`` node to modify in place
    :return: None
    """
    # Python 3.8 wraps a simple index in ``ast.Index`` (expression in ``.value``); from 3.9 on
    # ``node.slice`` is the index expression itself. Support both layouts.
    index_cls = getattr(ast, "Index", None)
    wrapped_in_index = index_cls is not None and isinstance(node.slice, index_cls)
    old_index_node = node.slice.value if wrapped_in_index else node.slice
    # NOTE(review): slice expressions such as ``a[1:2]`` are not handled specially here - confirm
    # whether they need to be supported.

    old_args_nodes_ast = ast.List(elts=[old_index_node], ctx=ast.Load())
    # Only the first line of the unparsed index is kept as its source text.
    old_args_code = ast.List(elts=[ast.Constant(value=astunparse.unparse(old_index_node).split("\n", 1)[0],
                                                kind=None)],
                             ctx=ast.Load())
    # ``value`` is the actual field of ast.Constant; the ``n`` alias is deprecated and no longer
    # populates the field on current Python versions.
    new_args_node = ast.Call(func=ast.Name(id='before_call_used_args', ctx=ast.Load()),
                             args=[ast.Constant(value=True, kind=None),  # True: this is a subscript
                                   ast.Constant(value=code, kind=None),
                                   old_args_code,
                                   ast.Constant(value=node.lineno, kind=None),
                                   ast.Constant(value=node.col_offset, kind=None),
                                   old_args_nodes_ast],
                             keywords=[])
    if wrapped_in_index:
        node.slice.value = new_args_node
    else:
        node.slice = new_args_node
113+
114+
@staticmethod
115+
def add_before_call_used_kwargs_capturing_call(code, node):
    """
    When a method is called, capture the keyword arguments of the method before executing it

    All keyword arguments are replaced in place by a single double-starred call
    ``**before_call_used_kwargs(False, code, [<kwarg source>, ...], lineno, col_offset, <original kwargs>)``.

    :param code: source text of the whole call expression; embedded as a string constant
    :param node: the ``ast.Call`` node to modify in place
    :return: None
    """
    # The original keyword nodes are forwarded unchanged as keywords of the capture call.
    old_kwargs_nodes_ast = node.keywords
    old_kwargs_code = ast.List(elts=[ast.Constant(value=astunparse.unparse(kwarg), kind=None)
                                     for kwarg in node.keywords],
                               ctx=ast.Load())
    # ``value`` is the actual field of ast.Constant; the ``n`` alias is deprecated and no longer
    # populates the field on current Python versions.
    capture_call = ast.Call(func=ast.Name(id='before_call_used_kwargs', ctx=ast.Load()),
                            args=[ast.Constant(value=False, kind=None),  # False: not a subscript
                                  ast.Constant(value=code, kind=None),
                                  old_kwargs_code,
                                  ast.Constant(value=node.lineno, kind=None),
                                  ast.Constant(value=node.col_offset, kind=None)],
                            keywords=old_kwargs_nodes_ast)
    # ``arg=None`` makes this a ``**mapping`` style keyword.
    node.keywords = [ast.keyword(arg=None, value=capture_call)]
130+
131+
@staticmethod
132+
def add_after_call_used_capturing(code, node, subscript):
    """
    After a method got executed, capture the return value

    Wraps ``node`` in a call ``after_call_used(subscript, code, <node>, lineno, col_offset)``.

    :param code: source text of the wrapped expression; embedded as a string constant
    :param node: the call or subscript AST node to wrap
    :param subscript: True when ``node`` is a subscript, False when it is a call
    :return: the new ``ast.Call`` node, carrying the source location of ``node``
    """
    # ``value`` is the actual field of ast.Constant; the ``n`` alias is deprecated and no longer
    # populates the field on current Python versions.
    instrumented_call_node = ast.Call(func=ast.Name(id='after_call_used', ctx=ast.Load()),
                                      args=[ast.Constant(value=subscript, kind=None),
                                            ast.Constant(value=code, kind=None),
                                            node,
                                            ast.Constant(value=node.lineno, kind=None),
                                            ast.Constant(value=node.col_offset, kind=None)],
                                      keywords=[])
    # The wrapper takes over the location of the expression it replaces.
    instrumented_call_node = ast.copy_location(instrumented_call_node, node)
    return instrumented_call_node
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""
2+
The Vertices used in the WIR as nodes for the networkx.DiGraph
3+
"""
4+
5+
6+
class DagVertex:
    """
    A WIR Vertex

    Stores the identifying data of one vertex: ``node_id``, ``name`` and the optional
    ``lineno``, ``col_offset`` and ``module``. Two vertices are equal when all five
    attributes match; the hash is derived from ``node_id`` only.
    """

    def __init__(self, node_id, name, lineno=None, col_offset=None, module=None):
        # pylint: disable=too-many-arguments
        self.node_id = node_id
        self.name = name
        self.module = module
        self.lineno = lineno
        self.col_offset = col_offset

    def __repr__(self):
        message = "DagVertex(node_id={}: name='{}', module={}, lineno={}, col_offset={})" \
            .format(self.node_id, self.name, self.module, self.lineno, self.col_offset)
        return message

    def __eq__(self, other):
        # Comparing with a foreign type must not raise AttributeError; returning NotImplemented
        # lets Python fall back to its default handling (``==`` -> False, ``!=`` -> True).
        if not isinstance(other, DagVertex):
            return NotImplemented
        return self.node_id == other.node_id and \
            self.name == other.name and \
            self.module == other.module and \
            self.lineno == other.lineno and \
            self.col_offset == other.col_offset

    def __hash__(self):
        # Equal vertices always share node_id, so this is consistent with __eq__.
        return hash(self.node_id)

0 commit comments

Comments
 (0)