wala / ML

Eclipse Public License 2.0

Can't track tensors involved in operator overloading #135

Open khatchad opened 5 months ago

khatchad commented 5 months ago

Consider the following code:

# From https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/GradientTape#gradient
import tensorflow as tf

x = tf.ragged.constant([[1.0, 2.0], [3.0]])
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x  # HERE
g.gradient(y, x)

Consider the multiplication expression marked "HERE" above. In this context, x is a tensor, and multiplying two tensors yields a tensor. However, I cannot see a way to express that in tensorflow.xml.
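For context, Python evaluates an expression like x * x by calling the __mul__ method of the left operand's type, so the multiplication is really an implicit method call with no function name at the call site. Here is a minimal sketch of that dispatch. FakeTensor is a hypothetical stand-in class made up for illustration; it is not a TensorFlow type and says nothing about how tensorflow.xml models tensors.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor type; not part of TensorFlow."""

    def __init__(self, values):
        self.values = list(values)

    def __mul__(self, other):
        # Invoked implicitly for `self * other`; the result is again a FakeTensor.
        return FakeTensor(a * b for a, b in zip(self.values, other.values))


x = FakeTensor([1.0, 2.0, 3.0])
y = x * x  # same result as the explicit call FakeTensor.__mul__(x, x)
```

The point of the sketch is only that the operand and result types are tied together by a method that is never named where the operator is used.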

Regression

The function tf.multiply() seems to work fine.
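One plausible explanation for the difference, offered as a guess rather than a confirmed diagnosis: the two forms are different syntactic constructs, so an analysis that keys on calls to named functions would see one but not the other. The sketch below uses CPython's standard ast module to show this. That is an assumption made for illustration; the analysis may well use a different parser and a different intermediate representation. The helper name assigned_expr_kind is made up here.

```python
import ast


def assigned_expr_kind(source):
    """Return the AST node type name of the value assigned in `source`."""
    assign = ast.parse(source).body[0]
    return type(assign.value).__name__


# Parsing only; neither `x` nor `tf` needs to exist for this to run.
operator_kind = assigned_expr_kind("y = x * x")
call_kind = assigned_expr_kind("y = tf.multiply(x, x)")
print(operator_kind, call_kind)  # prints: BinOp Call
```

The operator form contains no callee name to match against, which may be why there is no obvious place to describe it in a summary keyed on function names.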