wala / ML

Eclipse Public License 2.0

Workaround missing calls to tf.keras.Model.call() #120

Closed: khatchad closed this 9 months ago

khatchad commented 10 months ago

Works around #106. The real fix should use the XML summaries.

khatchad commented 10 months ago

The real fix is blocked on #107.

khatchad commented 10 months ago

Sorry, this might be DOA. Investigating.

khatchad commented 10 months ago

We're good.

khatchad commented 10 months ago

#108 adds support for __call__(), but there's another case: a Keras model can instead override call() for the same purpose. When the model object is then used as a callable, Python dispatches to the __call__() inherited from the Keras superclass, which in turn invokes call() in the subclass.
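To make the dispatch path concrete, here is a minimal pure-Python sketch that mimics it without TensorFlow. FakeKerasModel and MyModel are hypothetical stand-ins, not the real tf.keras.Model (whose __call__() does considerably more work); the point is only that the subclass never defines __call__(), yet model(...) still ends up in its call().

```python
trace = []  # records which methods actually run


class FakeKerasModel:
    """Hypothetical stand-in for tf.keras.Model; the real __call__() does much more."""

    def __call__(self, inputs):
        # Using the object as a callable lands here, in the superclass...
        trace.append("FakeKerasModel.__call__")
        # ...which dispatches dynamically to whatever call() the subclass defines.
        return self.call(inputs)

    def call(self, inputs):
        raise NotImplementedError("subclasses are expected to override call()")


class MyModel(FakeKerasModel):
    # The subclass overrides only call(); it never defines __call__() itself.
    def call(self, inputs):
        trace.append("MyModel.call")
        return [2 * x for x in inputs]


result = MyModel()([1, 2, 3])
print(result)  # [2, 4, 6]
print(trace)   # ['FakeKerasModel.__call__', 'MyModel.call']
```

An analysis that only looks for a __call__() declared in MyModel itself finds nothing here; the edge to MyModel.call appears only once the __call__() inherited from the superclass is taken into account, which is why inheritance support is needed.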

We need inheritance support to analyze this case, but we don't have it yet (I believe that is #107). This PR adds a simple (perhaps too simple?) workaround for this case by treating call() like __call__(). It works, but it may add spurious edges to the call graph for classes that implement call() and whose objects are used as callables but that do not inherit from tf.keras.Model. I would estimate that there are not many such cases, however.
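The trade-off can be seen in a toy model of call-graph construction. The sketch below is hypothetical and greatly simplified (it is not WALA's API or this PR's actual code, and the class table and callable_targets() helper are invented for illustration): for a call site where an object is used as a callable, it adds an edge to the receiver class's own __call__() and, as the workaround, also to its own call(), without consulting superclasses.

```python
from typing import Dict, List, Set

# Hypothetical class table: class name -> methods declared directly in that class.
# Superclasses are deliberately not recorded, mirroring the missing inheritance support.
CLASSES: Dict[str, Set[str]] = {
    "MyModel": {"call"},      # subclasses tf.keras.Model and overrides call()
    "Adder": {"__call__"},    # ordinary class that defines __call__() itself
    "Job": {"call"},          # NOT a Keras model; gets __call__() from some other base
}


def callable_targets(cls: str, workaround: bool) -> List[str]:
    """Targets added for a call site `obj(...)` where obj is an instance of cls."""
    declared = CLASSES[cls]
    targets: List[str] = []
    if "__call__" in declared:
        targets.append(f"{cls}.__call__")
    # The workaround: also treat call() as a target, as if it were __call__().
    # Without superclass information there is no way to check for tf.keras.Model.
    if workaround and "call" in declared:
        targets.append(f"{cls}.call")
    return targets


print(callable_targets("MyModel", workaround=False))  # [] : the real edge is missed
print(callable_targets("MyModel", workaround=True))   # ['MyModel.call']
print(callable_targets("Job", workaround=True))       # ['Job.call'] : possibly spurious
```

The Job case is the cost described in the comment: if the __call__() that Job inherits from its non-Keras base never reaches call(), the edge is spurious, and only inheritance information could rule it out.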

@tatianacv Do I have this right? Can you elaborate?

tatianacv commented 10 months ago

Yes, as seen in the TF docs (https://www.tensorflow.org/api_docs/python/tf/keras/Model#call), the call() method of tf.keras.Model should not be called directly. It is only meant to be overridden when subclassing tf.keras.Model. To call a model on an input, always use the __call__() method, i.e., model(inputs), which relies on the underlying call() method.
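One reason for that advice is that Keras's __call__() does bookkeeping before delegating to call(), for example building the layer's weights the first time it runs. The sketch below is a hypothetical, simplified imitation of that pattern: SketchModel and Scaler are invented names, and the built flag and build() hook are only loosely modeled on Keras. Going through model(inputs) triggers the one-time setup, whereas calling model.call(inputs) directly skips it.

```python
class SketchModel:
    """Hypothetical, simplified imitation of the tf.keras.Model calling convention."""

    def __init__(self):
        self.built = False
        self.weight = None

    def build(self, inputs):
        # One-time setup, run lazily the first time the model is called.
        self.weight = 3
        self.built = True

    def __call__(self, inputs):
        # The public entry point: set up state if needed, then delegate.
        if not self.built:
            self.build(inputs)
        return self.call(inputs)

    def call(self, inputs):
        raise NotImplementedError("subclasses are expected to override call()")


class Scaler(SketchModel):
    def call(self, inputs):
        # Depends on self.weight, which is only set once build() has run.
        return [self.weight * x for x in inputs]


proper = Scaler()
proper_result = proper([1, 2])  # __call__() runs build() first, then call()
print(proper_result)            # [3, 6]

direct = Scaler()
try:
    direct.call([1, 2])  # bypasses __call__(), so build() never ran
    direct_failed = False
except TypeError:        # None * int raises TypeError
    direct_failed = True
print(direct_failed)     # True
```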

msridhar commented 10 months ago

This does seem a bit hacky, but perhaps it is the best solution for the moment. If this is blocking you all, and we don't think we can get inheritance working in the short term, I'm ok with this change, but let's add a comment indicating we should remove this code if/when inheritance is supported.