wala / ML

Eclipse Public License 2.0

Missing custom decorators with decorator parameters #189

Closed: khatchad closed this issue 5 months ago.

khatchad commented 5 months ago

Related to https://github.com/wala/ML/issues/188.

Consider the following code:

import tensorflow as tf

def mama(test=None):
    assert test == "Hello"

    def _mama(func):

        def core(*args, **kwargs):
            assert isinstance(args[0], tf.Tensor)
            return func(*args, **kwargs)

        return core

    return _mama

@mama(test="Hello")
def f(x):
    assert isinstance(x, tf.Tensor)
    return 5

res = f(tf.constant(1))
assert res == 5

The above seems to be a common pattern for decorators that take arguments and whose decorated functions also take arguments. To support this, the functions are nested three levels deep. In Ariadne, however, we only "peel" back two of these levels. The following IR is produced from the above code at commit 644c562f26f6ce5243dba770786c799986177650:

callees of node Lscript tf2_test_decorated_method12.py : [import, mama, constant, _mama]

IR of node 2, context CallStringContext: [ com.ibm.wala.FakeRootClass.fakeRootMethod()V@2 ]
<Code body of function Lscript tf2_test_decorated_method12.py>
CFG:
BB0[-1..-2]
    -> BB1
BB1[0..111]
    -> BB2
    -> BB4
BB2[112..112]
    -> BB3
    -> BB4
BB3[113..116]
    -> BB4
BB4[-1..-2]
Instructions:
BB0
BB1
0   global:global script tf2_test_decorated_method12.py = v1<no information>
1   v5 = new <PythonLoader,Lwala/builtin/enumerate>@1<no information> [5=[enumerate]]
2   v6 = new <PythonLoader,Lwala/builtin/int>@2<no information> [6=[int]]
3   v7 = new <PythonLoader,Lwala/builtin/round>@3<no information> [7=[round]]
4   v8 = new <PythonLoader,Lwala/builtin/len>@4<no information> [8=[len]]
5   v9 = new <PythonLoader,Lwala/builtin/list>@5<no information> [9=[list]]
6   v10 = new <PythonLoader,Lwala/builtin/range>@6<no information> [10=[range]]
7   v11 = new <PythonLoader,Lwala/builtin/sorted>@7<no information> [11=[sorted]]
8   v12 = new <PythonLoader,Lwala/builtin/str>@8<no information> [12=[str]]
9   v13 = new <PythonLoader,Lwala/builtin/sum>@9<no information> [13=[sum]]
10   v14 = new <PythonLoader,Lwala/builtin/type>@10<no information> [14=[type]]
11   v15 = new <PythonLoader,Lwala/builtin/zip>@11<no information> [15=[zip]]
12   v16 = new <PythonLoader,Lwala/builtin/slice>@12<no information> [16=[slice]]
13   v17 = new <PythonLoader,Lwala/builtin/__delete__>@13<no information> [17=[__delete__]]
14   v3 = new <PythonLoader,Lwala/builtin/print>@14<no information> [3=[print]]
15   v19 = new <PythonLoader,Lwala/builtin/iter>@15<no information> [19=[iter]]
16   v20 = new <PythonLoader,Lwala/builtin/next>@16<no information> [20=[next]]
17   v21 = new <PythonLoader,Lwala/builtin/isinstance>@17<no information> [21=[isinstance]]
18   lexical:isinstance@Lscript tf2_test_decorated_method12.py = v21<no information> [21=[isinstance]]
21   v24 = invokestatic < PythonLoader, LBaseException, import()LBaseException; > @21 exception:v25<no information> [24=[BaseException]]
22   v27 = invokestatic < PythonLoader, LDeprecationWarning, import()LDeprecationWarning; > @22 exception:v28<no information> [27=[DeprecationWarning]]
23   v30 = invokestatic < PythonLoader, LException, import()LException; > @23 exception:v31<no information> [30=[Exception]]
24   v33 = invokestatic < PythonLoader, LFutureWarning, import()LFutureWarning; > @24 exception:v34<no information> [33=[FutureWarning]]
25   v36 = invokestatic < PythonLoader, LNameError, import()LNameError; > @25 exception:v37<no information> [36=[NameError]]
26   v39 = invokestatic < PythonLoader, LNone, import()LNone; > @26 exception:v40<no information> [39=[None]]
27   v42 = invokestatic < PythonLoader, LRuntimeError, import()LRuntimeError; > @27 exception:v43<no information> [42=[RuntimeError]]
28   v45 = invokestatic < PythonLoader, LStopIteration, import()LStopIteration; > @28 exception:v46<no information> [45=[StopIteration]]
29   v48 = invokestatic < PythonLoader, LTypeError, import()LTypeError; > @29 exception:v49<no information> [48=[TypeError]]
30   v51 = invokestatic < PythonLoader, LUserWarning, import()LUserWarning; > @30 exception:v52<no information> [51=[UserWarning]]
31   v54 = invokestatic < PythonLoader, LValueError, import()LValueError; > @31 exception:v55<no information> [54=[ValueError]]
32   v57 = invokestatic < PythonLoader, L__doc__, import()L__doc__; > @32 exception:v58<no information> [57=[__doc__]]
33   v60 = invokestatic < PythonLoader, L__file__, import()L__file__; > @33 exception:v61<no information> [60=[__file__]]
34   v63 = invokestatic < PythonLoader, L__name__, import()L__name__; > @34 exception:v64<no information> [63=[__name__]]
35   v66 = invokestatic < PythonLoader, Labs, import()Labs; > @35 exception:v67<no information> [66=[abs]]
36   v69 = invokestatic < PythonLoader, Lall, import()Lall; > @36 exception:v70<no information> [69=[all]]
37   v72 = invokestatic < PythonLoader, Lany, import()Lany; > @37 exception:v73<no information> [72=[any]]
38   v75 = invokestatic < PythonLoader, Lbin, import()Lbin; > @38 exception:v76<no information> [75=[bin]]
39   v78 = invokestatic < PythonLoader, Lbool, import()Lbool; > @39 exception:v79<no information> [78=[bool]]
40   v81 = invokestatic < PythonLoader, Lbytes, import()Lbytes; > @40 exception:v82<no information> [81=[bytes]]
41   v84 = invokestatic < PythonLoader, Lcallable, import()Lcallable; > @41 exception:v85<no information> [84=[callable]]
42   v87 = invokestatic < PythonLoader, Lchr, import()Lchr; > @42 exception:v88<no information> [87=[chr]]
43   v90 = invokestatic < PythonLoader, Lcomplex, import()Lcomplex; > @43 exception:v91<no information> [90=[complex]]
44   v93 = invokestatic < PythonLoader, Ldel, import()Ldel; > @44 exception:v94<no information> [93=[del]]
45   v96 = invokestatic < PythonLoader, Ldict, import()Ldict; > @45 exception:v97<no information> [96=[dict]]
46   v99 = invokestatic < PythonLoader, Ldir, import()Ldir; > @46 exception:v100<no information> [99=[dir]]
47   v102 = invokestatic < PythonLoader, Ldivmod, import()Ldivmod; > @47 exception:v103<no information> [102=[divmod]]
48   v105 = invokestatic < PythonLoader, Leval, import()Leval; > @48 exception:v106<no information> [105=[eval]]
49   v108 = invokestatic < PythonLoader, Lexec, import()Lexec; > @49 exception:v109<no information> [108=[exec]]
50   v111 = invokestatic < PythonLoader, Lexit, import()Lexit; > @50 exception:v112<no information> [111=[exit]]
51   v114 = invokestatic < PythonLoader, Lfilter, import()Lfilter; > @51 exception:v115<no information> [114=[filter]]
52   v117 = invokestatic < PythonLoader, Lfloat, import()Lfloat; > @52 exception:v118<no information> [117=[float]]
53   v120 = invokestatic < PythonLoader, Lformat, import()Lformat; > @53 exception:v121<no information> [120=[format]]
54   v123 = invokestatic < PythonLoader, Lfrozenset, import()Lfrozenset; > @54 exception:v124<no information> [123=[frozenset]]
55   v126 = invokestatic < PythonLoader, Lget_ipython, import()Lget_ipython; > @55 exception:v127<no information> [126=[get_ipython]]
56   v129 = invokestatic < PythonLoader, Lgetattr, import()Lgetattr; > @56 exception:v130<no information> [129=[getattr]]
57   v132 = invokestatic < PythonLoader, Lglobals, import()Lglobals; > @57 exception:v133<no information> [132=[globals]]
58   v135 = invokestatic < PythonLoader, Lhasattr, import()Lhasattr; > @58 exception:v136<no information> [135=[hasattr]]
59   v138 = invokestatic < PythonLoader, Lhelp, import()Lhelp; > @59 exception:v139<no information> [138=[help]]
60   v141 = invokestatic < PythonLoader, Lhex, import()Lhex; > @60 exception:v142<no information> [141=[hex]]
61   v144 = invokestatic < PythonLoader, Lid, import()Lid; > @61 exception:v145<no information> [144=[id]]
62   v147 = invokestatic < PythonLoader, Linput, import()Linput; > @62 exception:v148<no information> [147=[input]]
63   v150 = invokestatic < PythonLoader, Lisinstance, import()Lisinstance; > @63 exception:v151<no information> [150=[isinstance]]
64   lexical:isinstance@Lscript tf2_test_decorated_method12.py = v150<no information> [150=[isinstance]]
67   v153 = invokestatic < PythonLoader, Llocals, import()Llocals; > @67 exception:v154<no information> [153=[locals]]
68   v156 = invokestatic < PythonLoader, Lmap, import()Lmap; > @68 exception:v157<no information> [156=[map]]
69   v159 = invokestatic < PythonLoader, Lmax, import()Lmax; > @69 exception:v160<no information> [159=[max]]
70   v162 = invokestatic < PythonLoader, Lmin, import()Lmin; > @70 exception:v163<no information> [162=[min]]
71   v165 = invokestatic < PythonLoader, Lobject, import()Lobject; > @71 exception:v166<no information> [165=[object]]
72   v168 = invokestatic < PythonLoader, Lopen, import()Lopen; > @72 exception:v169<no information> [168=[open]]
73   v171 = invokestatic < PythonLoader, Lord, import()Lord; > @73 exception:v172<no information> [171=[ord]]
74   v174 = invokestatic < PythonLoader, Lpow, import()Lpow; > @74 exception:v175<no information> [174=[pow]]
75   v177 = invokestatic < PythonLoader, Lprint, import()Lprint; > @75 exception:v178<no information> [177=[print]]
77   v180 = invokestatic < PythonLoader, Lproperty, import()Lproperty; > @77 exception:v181<no information> [180=[property]]
78   v183 = invokestatic < PythonLoader, Lrepr, import()Lrepr; > @78 exception:v184<no information> [183=[repr]]
79   v186 = invokestatic < PythonLoader, Lreversed, import()Lreversed; > @79 exception:v187<no information> [186=[reversed]]
80   v189 = invokestatic < PythonLoader, Lset, import()Lset; > @80 exception:v190<no information> [189=[set]]
81   v192 = invokestatic < PythonLoader, Lsuper, import()Lsuper; > @81 exception:v193<no information> [192=[super]]
82   v195 = invokestatic < PythonLoader, Ltuple, import()Ltuple; > @82 exception:v196<no information> [195=[tuple]]
83   v198 = invokestatic < PythonLoader, Lvars, import()Lvars; > @83 exception:v199<no information> [198=[vars]]
84   v201 = invokestatic < PythonLoader, LNotImplementedError, import()LNotImplementedError; > @84 exception:v202<no information> [201=[NotImplementedError]]
85   v204 = invokestatic < PythonLoader, LWarning, import()LWarning; > @85 exception:v205<no information> [204=[Warning]]
86   v207 = invokestatic < PythonLoader, Lcd, import()Lcd; > @86 exception:v208<no information> [207=[cd]]
87   v210 = invokestatic < PythonLoader, Lclear, import()Lclear; > @87 exception:v211<no information> [210=[clear]]
88   v213 = invokestatic < PythonLoader, Lpylab, import()Lpylab; > @88 exception:v214<no information> [213=[pylab]]
89   v216 = invokestatic < PythonLoader, LRuntimeWarning, import()LRuntimeWarning; > @89 exception:v217<no information> [216=[RuntimeWarning]]
90   v219 = invokestatic < PythonLoader, Lhist, import()Lhist; > @90 exception:v220<no information> [219=[hist]]
91   v222 = invokestatic < PythonLoader, Lmatplotlib, import()Lmatplotlib; > @91 exception:v223<no information> [222=[matplotlib]]
92   v225 = invokestatic < PythonLoader, Lrecall, import()Lrecall; > @92 exception:v226<no information> [225=[recall]]
93   v228 = invokestatic < PythonLoader, Lhistory, import()Lhistory; > @93 exception:v229<no information> [228=[history]]
94   v231 = invokestatic < PythonLoader, Ltime, import()Ltime; > @94 exception:v232<no information> [231=[time]]
95   v234 = invokestatic < PythonLoader, LKeyError, import()LKeyError; > @95 exception:v235<no information> [234=[KeyError]]
96   v237 = invokestatic < PythonLoader, Ldisplay, import()Ldisplay; > @96 exception:v238<no information> [237=[display]]
97   v240 = invokestatic < PythonLoader, Ltensorflow, import()Ltensorflow; > @97 exception:v241tf2_test_decorated_method12.py [3:0] -> [3:23] [240=[tf]]
98   lexical:tf@Lscript tf2_test_decorated_method12.py = v240tf2_test_decorated_method12.py [1:0] -> [1:0] [240=[tf]]
101   global:global mama_default_0 = v242:#nulltf2_test_decorated_method12.py [3:0] -> [27:15]
102   v246 = new <PythonLoader,Lscript tf2_test_decorated_method12.py/mama>@102<no information> [246=[mama]]
103   global:global script tf2_test_decorated_method12.py/mama = v246<no information> [246=[mama]]
104   putfield v1.< PythonLoader, LRoot, mama, <PythonLoader,LRoot> > = v246<no information> [246=[mama]]
105   global:global Lscript tf2_test_decorated_method12.py/mama_defaults_1 = v242:#nulltf2_test_decorated_method12.py [6:0] -> [17:16]
106   v4 = new <PythonLoader,Lscript tf2_test_decorated_method12.py/f>@106<no information> [4=[f]]
107   v18 = invokeFunction < PythonLoader, LCodeBody, do()LRoot; > v246,v4 @107 exception:v251tf2_test_decorated_method12.py [3:0] -> [27:15] [18=[f]246=[mama]4=[f]]
108   global:global script tf2_test_decorated_method12.py/f = v18<no information> [18=[f]]
109   putfield v1.< PythonLoader, LRoot, f, <PythonLoader,LRoot> > = v18<no information> [18=[f]]
110   v255 = fieldref v240.v256:#constant    tf2_test_decorated_method12.py [26:8] -> [26:19] [240=[tf]]
111   v254 = invokeFunction < PythonLoader, LCodeBody, do()LRoot; > v255,v257:#1 @111 exception:v258tf2_test_decorated_method12.py [26:8] -> [26:22]
BB2
112   v253 = invokeFunction < PythonLoader, LCodeBody, do()LRoot; > v18,v254 @112 exception:v259tf2_test_decorated_method12.py [26:6] -> [26:23] [253=[res]18=[f]]
BB3
115   v262 = binaryop(eq) v253 , v260:#5     tf2_test_decorated_method12.py [27:7] -> [27:15] [253=[res]260=[cmp0]]
116   assert v262 (fromSpec: true)           tf2_test_decorated_method12.py [27:0] -> [27:15]
BB4
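For reference, Python evaluates the decorator line roughly as `f = mama(test="Hello")(f)`, so there are three levels of calls between the script and the body of `f`. Below is a minimal plain-Python sketch of those three levels. It swaps `tf.Tensor` for `int` (an assumption made only so it runs without TensorFlow), and the intermediate name `decorator` is hypothetical.

```python
def mama(test=None):
    assert test == "Hello"

    def _mama(func):
        def core(*args, **kwargs):
            # Stand-in check: int replaces tf.Tensor so no TensorFlow is needed.
            assert isinstance(args[0], int)
            return func(*args, **kwargs)

        return core

    return _mama


def f(x):
    assert isinstance(x, int)
    return 5


# Level 1: calling the factory with the decorator arguments returns _mama.
decorator = mama(test="Hello")
# Level 2: applying _mama to f returns core; the name f is rebound to it.
f = decorator(f)
# Level 3: calling f runs core, which forwards to the original f.
res = f(1)

print(decorator.__name__, f.__name__, res)  # _mama core 5
```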

The function invoked at the very end is actually not f() but the middle function, _mama():

[Node: <Code body of function Lscript tf2_test_decorated_method12.py> Context: CallStringContext: [ com.ibm.wala.FakeRootClass.fakeRootMethod()V@2 ], v253] --> [SMIK:SITE_IN_NODE{<Code body of function Lscript tf2_test_decorated_method12.py/mama/_mama>:Lscript tf2_test_decorated_method12.py/mama/_mama/core in CallStringContext: [ script tf2_test_decorated_method12.py.do()LRoot;@112 ]}@creator:Node: <Code body of function Lscript tf2_test_decorated_method12.py/mama/_mama> Context: CallStringContext: [ script tf2_test_decorated_method12.py.do()LRoot;@112 ]]
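To see why `res` ends up pointing to `core`, one can replay in plain Python the two calls the IR actually makes: instruction 107 calls `mama` with `f` as its only argument (no `"Hello"` constant appears anywhere in the IR), and instruction 112 calls the result with the constant. This is a sketch under two assumptions: the `assert test == "Hello"` is dropped, because `test` receives `f` in this replay, and an `int` stands in for `tf.constant(1)`. The name `f_modeled` is hypothetical.

```python
def mama(test=None):
    # Assumption: the original `assert test == "Hello"` is dropped, because in this
    # replay `test` receives f, just as v4 is passed to mama at instruction 107.
    def _mama(func):
        def core(*args, **kwargs):
            return func(*args, **kwargs)

        return core

    return _mama


def f(x):
    return 5


# Instruction 107: v18 = mama(f); the call mama(test="Hello") never happens.
f_modeled = mama(f)
# Instruction 112: v253 = v18(v254); an int stands in for tf.constant(1).
res = f_modeled(1)

print(f_modeled.__name__)  # _mama
print(res.__name__)  # core
```

In this replay the decorated name is bound to `_mama` and `res` to `core`, which matches the callee reported for instruction 112 and the points-to set for v253.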

The core() function should be called, but it's not:

callees of node _mama : []

IR of node 6, context CallStringContext: [ script tf2_test_decorated_method12.py.do()LRoot;@112 ]
<Code body of function Lscript tf2_test_decorated_method12.py/mama/_mama>
CFG:
BB0[-1..-2]
    -> BB1
BB1[0..4]
    -> BB2
BB2[-1..-2]
Instructions:
BB0
BB1
0   lexical:func@Lscript tf2_test_decorated_method12.py/mama/_mama = v2tf2_test_decorated_method12.py [9:4] -> [17:3] [2=[func]]
1   v5 = new <PythonLoader,Lscript tf2_test_decorated_method12.py/mama/_mama/core>@1<no information> [5=[core]]
2   global:global script tf2_test_decorated_method12.py/mama/_mama/core = v5<no information> [5=[core]]
3   putfield v1.< PythonLoader, LRoot, core, <PythonLoader,LRoot> > = v5<no information> [1=[the function]5=[core]]
4   return v5                                tf2_test_decorated_method12.py [15:8] -> [15:19] [5=[core]]
BB2

The points-to set for v5, the value _mama() returns:

[Node: <Code body of function Lscript tf2_test_decorated_method12.py/mama/_mama> Context: CallStringContext: [ script tf2_test_decorated_method12.py.do()LRoot;@112 ], v5] --> [SMIK:SITE_IN_NODE{<Code body of function Lscript tf2_test_decorated_method12.py/mama/_mama>:Lscript tf2_test_decorated_method12.py/mama/_mama/core in CallStringContext: [ script tf2_test_decorated_method12.py.do()LRoot;@112 ]}@creator:Node: <Code body of function Lscript tf2_test_decorated_method12.py/mama/_mama> Context: CallStringContext: [ script tf2_test_decorated_method12.py.do()LRoot;@112 ]]

I see core() there, but it's not invoked. It looks to me like we need one more function invocation here.

khatchad commented 5 months ago

If I revert 1c88d8682ec042f65ce63ded26b24462ac282b06, now I see core().

khatchad commented 5 months ago

But the problem is that, for simpler decorators, visiting the annotation in the IR translator then invokes the decorator function without passing a parameter, so we basically get null for the function to invoke.
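The tension between the two cases comes from the two shapes a decorator expression can take. A bare name such as `@simple` is called once, with the function; a call such as `@mama(test="Hello")` must first be evaluated by calling the factory with its own arguments, and only its result is applied to the function. The sketch below uses Python's standard `ast` module, not Ariadne's IR translator, just to show that the two shapes are syntactically distinguishable; `simple`, `g`, `SOURCE`, and `decorator_calls` are hypothetical names.

```python
import ast

# Hypothetical input: one bare-name decorator and one parameterized decorator.
SOURCE = '''
@simple
def g(x):
    return x

@mama(test="Hello")
def f(x):
    return 5
'''


def decorator_calls(source):
    """Map each decorated function to the number of calls each of its decorators implies."""
    report = {}
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.decorator_list:
            counts = []
            for dec in node.decorator_list:
                if isinstance(dec, ast.Call):
                    # @factory(args): call the factory, then call its result with the function.
                    counts.append(2)
                else:
                    # @name or @obj.attr: call the decorator once with the function.
                    counts.append(1)
            report[node.name] = counts
    return report


print(decorator_calls(SOURCE))  # {'g': [1], 'f': [2]}
```

Neither count includes the later call through the decorated name (core() in the example above), which happens separately at each call site.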

khatchad commented 5 months ago

Since we can't expand *args and **kwargs anyway, it makes sense to leave this broken for now and go with the simpler case.
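One way to see why the *args/**kwargs limitation matters here: once decorated, the name `f` is bound to `core`, whose only parameters are `*args` and `**kwargs`. Whether a value lands in `args[0]` or in `kwargs` is decided by the packing at each call site, so relating the tensor back to `f`'s parameter `x` requires modeling that packing. A pure-Python illustration follows; the `tf.Tensor` checks are omitted (an assumption) so it runs without TensorFlow.

```python
import inspect


def mama(test=None):
    assert test == "Hello"

    def _mama(func):
        def core(*args, **kwargs):
            # core forwards whatever it receives; it never names x.
            return func(*args, **kwargs)

        return core

    return _mama


@mama(test="Hello")
def f(x):
    return 5


# After decoration, f is bound to core, so its visible signature is generic.
print(inspect.signature(f))  # (*args, **kwargs)
# The same argument may arrive in args or in kwargs, depending on the call site.
print(f(1), f(x=1))  # 5 5
```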

khatchad commented 5 months ago

On second thought, I don't think it makes a ton of sense just to swap one problem for another here. I reverted the change in 2a6e5226db8ffb9aa8a4fd5dde637a452b55bbb2.