parrt / tensor-sensor

The goal of this library is to generate more helpful exception messages for matrix algebra expressions for numpy, pytorch, jax, tensorflow, keras, fastai.
https://github.com/parrt/tensor-sensor
MIT License

Unhandled statements cause exceptions (Was: Nested calls to clarify can raise stacked Exceptions) #15

Closed: clefourrier closed this issue 3 years ago.

clefourrier commented 4 years ago

Hello,

I created a decorator that wraps the forward function of my custom PyTorch models (derived from torch.nn.Module) in a call to clarify.

Said decorator looks like this:

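import functools
from typing import Callable

import tsensor

def clarify(function: Callable) -> Callable:
    """Clarify decorator: run `function` inside a tsensor.clarify() context."""

    @functools.wraps(function)  # keep the wrapped function's name and docstring
    def call_clarify(*args, **kwargs):
        with tsensor.clarify(fontname="DejaVu Sans"):
            return function(*args, **kwargs)

    return call_clarify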

When doing machine learning with PyTorch, models (derived from torch.nn.Module) are often nested. For example, in a translation task an EncoderDecoder's forward calls its Decoder's forward, which in turn calls the forward of an Attention module.

In such a case, this results in nested clarify calls, which raise a succession of exceptions because some of the outer clarify contexts do not exit correctly. More specifically, at line 124 of analysis.py, self.view can be None, in which case self.view.show() raises an AttributeError.

A quick fix (which I applied locally) was to add a check at line 131:

if self.view is not None:  # guard: self.view can be None at this point
    if self.show == 'viz':
        self.view.show()
    augment_exception(exc_value, self.view.offending_expr)

However, I'm not sure this is the best possible fix: I don't know whether this is a common problem, or whether and how it is meant to be handled. What do you think?

parrt commented 4 years ago

Oh wow, I hadn't thought of this, but yes, there may be an issue there. It seems like the deepest exception should be reported and then swallowed by the outer clarify, correct?

clefourrier commented 4 years ago

Yes, exactly! We want only the deepest clarify to report the exception, and none of the enclosing clarify calls.
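The behavior agreed on here (only the deepest clarify reports, enclosing ones leave the exception alone) can be sketched with a hypothetical clarify_once context manager. This is not tsensor's implementation: it marks the exception with an assumed _clarified attribute the first time it augments the message, and every enclosing context sees the mark and re-raises the exception untouched. The f/A/B structure mirrors the nested test case in this thread, with a plain ValueError standing in for the NumPy matmul error.

```python
import contextlib


@contextlib.contextmanager
def clarify_once(label):
    """Hypothetical stand-in for tsensor.clarify (not its real implementation).

    The innermost context to see an exception appends a "Cause:" line and
    marks the exception; enclosing contexts see the mark and skip augmentation.
    """
    try:
        yield
    except Exception as exc:
        if not getattr(exc, "_clarified", False):
            exc._clarified = True  # mark so outer contexts leave it alone
            msg = str(exc.args[0]) if exc.args else ""
            exc.args = (f"{msg}\nCause: detected inside {label}",) + exc.args[1:]
        raise  # always propagate; the exception itself is never suppressed


def f():
    raise ValueError("matmul: size mismatch")


def A():
    with clarify_once("A"):
        f()


def B():
    with clarify_once("B"):
        A()


try:
    B()
except ValueError as e:
    print(e)
```

Run as-is, this prints the original message followed by a single Cause: line naming A, rather than one Cause: line per nesting level.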

parrt commented 3 years ago

Here is a test case:

# Test for https://github.com/parrt/tensor-sensor/issues/15
# Nested clarify's and a deep one catches exception

import tsensor
import numpy as np

def f():
    np.ones(1) @ np.ones(2)

def A():
    with tsensor.clarify():
        f()

def B():
    with tsensor.clarify():
        A()

B()

I get this doubled clarification (the Cause: line is printed twice):

Traceback (most recent call last):
  File "/Users/parrt/github/tensor-sensor/testing/nested.py", line 18, in <module>
    B()
  File "/Users/parrt/github/tensor-sensor/testing/nested.py", line 16, in B
    A()
  File "/Users/parrt/github/tensor-sensor/testing/nested.py", line 12, in A
    f()
  File "/Users/parrt/github/tensor-sensor/testing/nested.py", line 8, in f
    np.ones(1) @ np.ones(2)
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 1)
Cause: @ on tensor operand np.ones(1) w/shape (1,) and operand np.ones(2) w/shape (2,)
Cause: @ on tensor operand np.ones(1) w/shape (1,) and operand np.ones(2) w/shape (2,)

I think you are hitting a different problem: the code where the exception happens can't be processed. Could it be that you are seeing two problems at once? The first is nested clarifications; the second is that I don't check for statements I can't display.
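The second problem, statements that can't be displayed, suggests checking whether a view can be built before using it. Below is a minimal sketch using Python's standard ast module rather than tsensor's own parser; build_view, clarify_message, and the SUPPORTED set of statement kinds are all hypothetical, and tsensor's real rules for what it can display may differ. When no view can be built, the original message is returned unchanged instead of failing on a None view.

```python
import ast

# Statement kinds this sketch treats as displayable (an assumption, not tsensor's rule).
SUPPORTED = (ast.Expr, ast.Assign, ast.AugAssign, ast.Return)


def build_view(stmt_source):
    """Hypothetical helper: return the parsed statement node, or None when
    the statement is one we cannot display."""
    try:
        tree = ast.parse(stmt_source.strip())
    except SyntaxError:
        # e.g. a lone "for ...:" header or one line of a multi-line call
        return None
    if len(tree.body) != 1 or not isinstance(tree.body[0], SUPPORTED):
        return None
    return tree.body[0]


def clarify_message(message, stmt_source):
    """Append a cause only when a view could be built; otherwise keep the
    original message instead of raising a secondary exception."""
    view = build_view(stmt_source)
    if view is None:
        return message
    return f"{message}\nCause: {stmt_source.strip()}"


print(clarify_message("boom", "c = a @ b"))
print(clarify_message("boom", "for i in range(3):"))
```

The point of the design is that an unsupported statement degrades to the plain exception message rather than producing a second, unrelated error.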