wala / ML

Eclipse Public License 2.0

Not tracking tensors returned by `tf.reshape()` for data sources other than MNIST #195

Closed. khatchad closed this issue 1 month ago.

khatchad commented 1 month ago

Consider the following code:

```python
# tf2_test_reshape.py
# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/reshape

import tensorflow as tf

def f(a):
    pass

t1 = tf.ones([2, 3])
t2 = tf.reshape(t1, [6])
f(t2)
```

t2 should be a (reshaped) tensor, and the argument to f() should also be tracked as a tensor. Instead, I'm seeing this tensor analysis result:

```
[INFO] Tensor analysis: answer:
[Node: synthetic < PythonLoader, Ltensorflow/functions/reshape, do()LRoot; > Context: CallStringContext: [ script tf2_test_reshape.py.do()LRoot;@105 ], v2][{[D:Symbolic,n, D:Compound,[D:Constant,28, D:Constant,28]] of pixel}]
[Node: <Code body of function Lscript tf2_test_reshape.py> Context: CallStringContext: [ com.ibm.wala.FakeRootClass.fakeRootMethod()V@2 ], v245][{[D:Symbolic,n, D:Compound,[D:Constant,28, D:Constant,28]] of pixel}]
[Ret-V:Node: synthetic < PythonLoader, Ltensorflow/functions/ones, do()LRoot; > Context: CallStringContext: [ script tf2_test_reshape.py.do()LRoot;@100 ]][{[D:Symbolic,n, D:Compound,[D:Constant,28, D:Constant,28]] of pixel}]
[Node: synthetic < PythonLoader, Ltensorflow/functions/ones, do()LRoot; > Context: CallStringContext: [ script tf2_test_reshape.py.do()LRoot;@100 ], v5][{[D:Symbolic,n, D:Compound,[D:Constant,28, D:Constant,28]] of pixel}]
```

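In the IR, v245 refers to the return value of tf.ones(). That is the only value in this file that the analysis reports as a tensor; t2 and the argument to f() are missing. Moreover, every reported value has the MNIST-style type [n, 28, 28] of pixel, even though the program never uses MNIST data.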
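For comparison, the shapes a shape-aware analysis would be expected to infer here follow from the usual reshape rule: the element count must be preserved, and at most one target dimension may be -1 (inferred from the others). The plain-Python sketch below illustrates that rule only; it uses neither WALA nor TensorFlow, and the function name `reshape_shape` is hypothetical.

```python
from math import prod


def reshape_shape(in_shape, new_shape):
    """Hypothetical helper: static output shape of a reshape.

    The total element count must be preserved, and at most one
    target dimension may be -1 (inferred from the others)."""
    total = prod(in_shape)
    unknown = [i for i, d in enumerate(new_shape) if d == -1]
    if len(unknown) > 1:
        raise ValueError("at most one dimension may be -1")
    known = prod(d for d in new_shape if d != -1)
    out = list(new_shape)
    if unknown:
        # Infer the -1 dimension from the remaining element count.
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot reshape {in_shape} into {new_shape}")
        out[unknown[0]] = total // known
    elif known != total:
        raise ValueError(f"cannot reshape {in_shape} into {new_shape}")
    return out


# Shapes from tf2_test_reshape.py: t1 = tf.ones([2, 3]); t2 = tf.reshape(t1, [6])
t1_shape = [2, 3]
t2_shape = reshape_shape(t1_shape, [6])
print(t2_shape)  # [6]
```

Under this rule, t1 would be a [2, 3] tensor and t2 (and hence the parameter of f()) a [6] tensor.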

Regression

khatchad commented 1 month ago

I added a test for this case:

https://github.com/ponder-lab/ML/commit/2db36e27afd9399c70a788a6800b675c19505379

I believe the problem is that the data sources are hard-coded:

https://github.com/wala/ML/blob/ddba21e7881f2a9cc825f1857aea5a5ea89f1bc3/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java#L606-L610

khatchad commented 1 month ago

However, it could also have something to do with how other APIs are modeled. For example, the points-to set for tf.ones() is empty.

khatchad commented 1 month ago

Looking at the summary of tf.reshape(), I see that there's a data copy:

https://github.com/wala/ML/blob/ddba21e7881f2a9cc825f1857aea5a5ea89f1bc3/com.ibm.wala.cast.python.ml/data/tensorflow.xml#L379-L388

That may mean that if the data source is wrong, any data copied from it inherits the problem. Thus, the problem may lie not with the tf.reshape() operation itself but with how data sources other than MNIST are constructed.

khatchad commented 1 month ago

That said, copy_data() in the above summary doesn't use its argument, so the value returned by tf.reshape() may not be connected to the input tensor at all.

khatchad commented 1 month ago

Thus, my best guess is that the problem involves a combination of the (new) XML summaries and the hard-coded initialization of the dataflow.
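The combined hypothesis can be illustrated with a toy worklist propagation over a def-use graph. This is a sketch under stated assumptions, not WALA's implementation: the node names, the `MNIST_TYPE` seed, the `propagate` function, and both graphs are hypothetical, loosely mirroring tf2_test_reshape.py. With a hard-coded seed but an intact reshape summary, t2 would be tracked, just with the wrong shape. If the summary's copy also ignores its argument (no edge from reshape's parameter to its return value), t2 and f()'s parameter are not tracked at all, which resembles the reported result.

```python
from collections import deque

# Hypothetical seed: every recognized source gets the same MNIST-like type,
# regardless of which API produced it.
MNIST_TYPE = "[n, 28, 28] of pixel"


def propagate(edges, sources, seed_type):
    """Toy worklist propagation: each source gets seed_type, and types flow
    along def-use edges (src -> dst). All seeds share one type here, so each
    node is typed once and no join is needed. Returns {node: type}."""
    types = {}
    work = deque()
    for s in sources:
        types[s] = seed_type
        work.append(s)
    while work:
        node = work.popleft()
        for dst in edges.get(node, ()):
            if dst not in types:
                types[dst] = types[node]
                work.append(dst)
    return types


# Def-use edges loosely mirroring tf2_test_reshape.py (names are illustrative).
intact = {
    "ones.ret": ["t1"],
    "t1": ["reshape.arg"],
    "reshape.arg": ["reshape.ret"],  # the copy uses its argument
    "reshape.ret": ["t2"],
    "t2": ["f.param"],
}

# Same graph, but the summary's copy ignores its argument: no arg -> ret edge.
broken = {k: v for k, v in intact.items() if k != "reshape.arg"}

print(sorted(propagate(intact, ["ones.ret"], MNIST_TYPE)))
# ['f.param', 'ones.ret', 'reshape.arg', 'reshape.ret', 't1', 't2']
print(sorted(propagate(broken, ["ones.ret"], MNIST_TYPE)))
# ['ones.ret', 'reshape.arg', 't1']
```

In the intact graph, every reached value (including t2) gets the hard-coded MNIST type rather than [2, 3] or [6]; in the broken graph, propagation stops at reshape's parameter.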