patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

IPython `inspect.getsource()` failure due to incorrect co_firstlineno #160

Open davideger opened 10 months ago

davideger commented 10 months ago

This colab shows an unexpected side effect of enabling automatic jaxtyping type checking in IPython: it causes inspect.getsource to return the source text of the wrong function.

That is, if I run these two cells:

def where(q, a, b):
    "Use this function to replace an if-statement."
    return (q * a) + (~q) * b

def arange(i: int) -> jaxtyping.Int32[torch.Tensor, "i"]:
    "Use this function to replace a for-loop."
    return torch.tensor(range(i))

def ones(i: int) -> jaxtyping.Int32[torch.Tensor, "{i}"]:
    return arange(i) - arange(i) + 1

import inspect

inspect.getsource(ones)

Then I get:

'def ones(i: int) -> jaxtyping.Int32[torch.Tensor, "{i}"]:\n    return arange(i) - arange(i) + 1\n'

But if I turn on type checking by using:

%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

Then running the same inspect.getsource(ones) yields:

'def where(q, a, b):\n    "Use this function to replace an if-statement."\n    return (q * a) + (~q) * b\n'

davideger commented 10 months ago

I guess some context might help motivate this one.

I was converting Sasha Rush's tensor puzzles from torchtyping to jaxtyping, and the final cell of his exercise is a competition to make the "shortest one liner" to accomplish some standard torch function. After I finished the port, I noticed this "score board" wasn't working properly, and tracked it down to the above interaction.

patrick-kidger commented 10 months ago

Hmm, that's a little odd! Digging into this a little, it seems like this is due to

inspect.unwrap(ones).__code__.co_firstlineno

(which is used inside inspect.getsource) returning 1 rather than the function's true first line number.

Moreover I've checked and this doesn't seem to happen in normal usage; this is specific to the IPython extension.
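
To illustrate why a wrong co_firstlineno is enough to produce the symptom, here is a minimal sketch that uses only the standard library, with no IPython or jaxtyping involved. It writes a small stand-in for the notebook cell to a temporary file, then forces co_firstlineno to 1 by hand. The module name demo_cell, the helper load_demo_module and the simplified function bodies are assumptions made for this illustration; the sketch simulates the bad value and says nothing about how it arises in IPython.

```python
import importlib.util
import inspect
import tempfile
import textwrap

# Simplified stand-in for the notebook cell (no torch or jaxtyping needed).
SOURCE = textwrap.dedent('''\
    def where(q, a, b):
        return (q * a) + (~q) * b

    def ones(i):
        return [1] * i
''')


def load_demo_module():
    """Write SOURCE to a temporary .py file and import it (hypothetical helper)."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(SOURCE)
        path = f.name
    spec = importlib.util.spec_from_file_location("demo_cell", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


ones = load_demo_module().ones

# With the real first line number, getsource finds the right function.
good = inspect.getsource(ones)

# Simulate the symptom: overwrite co_firstlineno with 1 by hand.
ones.__code__ = ones.__code__.replace(co_firstlineno=1)
bad = inspect.getsource(ones)

print(good)  # starts with "def ones"
print(bad)   # starts with "def where", the first function in the file
```

Because inspect.getsource starts reading at co_firstlineno, a value of 1 makes it return whatever function happens to be defined first in the file or cell.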

Tagging @knyazer -- what do you think might be going on -- something inside IPython itself maybe? Our AST transformations don't really do much beyond appending to the decorator list, after all.

knyazer commented 10 months ago

That's a very nice problem. After some digging, I came to the same conclusion: the root cause is that co_firstlineno is determined incorrectly, while most other properties of the function object are correct. The incorrect co_firstlineno effectively makes the whole __code__ object incorrect.

First of all, here is a quick workaround for @davideger: instead of having one big cell, use three separate cells, like here (colab link). Another option is to hand-annotate the functions instead of using the magic.

But this does not solve the issue, it just avoids it. So let's dig further...

I tracked the issue down to the level of the IPython magic: everything breaks as soon as we add any decorator via the AST transformers list. The most interesting part is that even a trivial pointer-copying decorator fails, even when wrapped with functools.wraps or wrapt. Look at this colab, for example: the same incorrect behavior, without jaxtyping!
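
For reference, the kind of transformation under discussion can be sketched in plain Python as follows. This is not a reproduction of the IPython behavior and is not jaxtyping's actual implementation: the names passthrough and AppendDecorator are hypothetical, and the decorator simply returns its argument unchanged. The sketch only shows how to append a decorator through an ast.NodeTransformer and then inspect co_firstlineno on the resulting functions.

```python
import ast
import textwrap

# Simplified stand-in for the notebook cell.
CELL = textwrap.dedent('''\
    def where(q, a, b):
        return (q * a) + (~q) * b

    def ones(i):
        return [1] * i
''')


def passthrough(fn):
    """Hypothetical decorator that returns the function unchanged."""
    return fn


class AppendDecorator(ast.NodeTransformer):
    """Append @passthrough to every function definition (illustrative only)."""

    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        decorator = ast.Name(id="passthrough", ctx=ast.Load())
        # Give the new node the same source location as the `def` it decorates.
        ast.copy_location(decorator, node)
        node.decorator_list.append(decorator)
        return node


tree = AppendDecorator().visit(ast.parse(CELL))
ast.fix_missing_locations(tree)

namespace = {"passthrough": passthrough}
exec(compile(tree, "<cell>", "exec"), namespace)

# Record the first line number the compiler assigned to each function.
first_lines = {
    name: namespace[name].__code__.co_firstlineno for name in ("where", "ones")
}
print(first_lines)
```

Running the same check on functions defined through the IPython magic is one way to compare the line numbers that come out of each path.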

In the end, I only have a hypothesis about what is wrong (namely, that IPython wraps the magics incorrectly, though I may well be wrong), and I would prefer to monkey-patch this at the level of IPython anyway, not jaxtyping. So I think we should transfer this issue to the IPython issue tracker.

davideger commented 10 months ago

Thanks for tracking down that this is not jaxtyping-specific, @knyazer. It sounds like IPython is the right place for this issue.