parrt / tensor-sensor

The goal of this library is to generate more helpful exception messages for matrix algebra expressions for numpy, pytorch, jax, tensorflow, keras, fastai.
https://github.com/parrt/tensor-sensor
MIT License

Identify nested clarify() calls and ignore nested ones. #18

Closed: parrt closed this issue 3 years ago

parrt commented 3 years ago

This is one of the issues at https://github.com/parrt/tensor-sensor/issues/15

parrt commented 3 years ago

File testing/nested.py:

import tsensor
import numpy as np

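def f():
    np.ones(1) @ np.ones(2)  # shapes (1,) and (2,) don't match: raises ValueError

def A():
    with tsensor.clarify():  # inner (nested) clarify block
        f()

def B():
    with tsensor.clarify():  # outer clarify block
        A()

B()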

This was displaying the visualization twice and augmenting the exception message twice (note the duplicated Cause: line):

ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 1)
Cause: @ on tensor operand np.ones(1) w/shape (1,) and operand np.ones(2) w/shape (2,)
Cause: @ on tensor operand np.ones(1) w/shape (1,) and operand np.ones(2) w/shape (2,)