beartype / plum

Multiple dispatch in Python
https://beartype.github.io/plum
MIT License

Truncating dispatch error message (types, instead of instances) #125

Open femtomc opened 1 year ago

femtomc commented 1 year ago

Hi!

By default, plum.resolver.NotFoundLookupError messages seem to be quite large. For example, I get errors that look like this:

NotFoundLookupError: For function `update` of `genjax._src.core.datatypes.generative.GenerativeFunction`, `(StaticGenerativeFunction(source=<function _inner at 0x105f94d60>), Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=1/0)>, StaticTrace(gen_fn=StaticGenerativeFunction(source=<function _inner at 0x105f94d60>), args=(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>,), retval=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, address_choices=StaticLanguageChoiceMap(addrs=[PytreeConst(const='y1')], subtraces=[DistributionTrace(gen_fn=TFPDistribution(make_distribution=<class 'tensorflow_probability.substrates.jax.distributions.normal.Normal'>), args=(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), value=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, score=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)]), cache=Trie(inner={}), score=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>), EmptyChoice(), (Diff(primal=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, tangent=_UnknownChange()),))` could not be resolved.

Now, while all the information is there, it's quite hard to parse. What I really need to see is just the top-level type of each argument.

Is there a way to provide a global option to plum to ask for this type of truncation?

wesselb commented 1 year ago

Hey @femtomc! You're totally right that this is hard to parse.

The problem is that the top-level types do not always tell the whole story. For example, if a method accepts Literal[1]s, but the argument is a 2, then you would see that the type int is not matched by Literal[1]. For debugging, it is sometimes helpful to know that the int was a 2. This was the main motivation for printing the whole objects instead of just the top-level type.
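A small illustration of this point, using only the standard `typing` module (nothing plum-specific): the values `1` and `2` share the top-level type `int`, yet only one of them satisfies `Literal[1]`, so a types-only message would hide why the match failed.

```python
from typing import Literal, get_args

# get_args on a Literal returns the tuple of permitted values: (1,)
allowed = get_args(Literal[1])

for value in (1, 2):
    # Both values report the same top-level type; only the value differs.
    print(type(value).__name__, value in allowed)
# int True
# int False
```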

However, perhaps we can get the best of both worlds by printing the top-level types first and only then appending the full arguments. I'm thinking of something like:

NotFoundLookupError: For function `update` of `genjax._src.core.datatypes.generative.GenerativeFunction`, arguments with types `(A, B)` could not be resolved. The given arguments are `(...)`.

How would something like that look?
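A minimal sketch of how such a two-part message could be assembled in plain Python. The function name `format_not_found` and the 80-character cap are illustrative assumptions, not plum's actual implementation. The standard-library `reprlib` module shortens each argument's repr so that large objects, like the traced values in the error above, do not flood the second half of the message.

```python
import reprlib


def format_not_found(function_name, args, max_repr=80):
    """Hypothetical formatter: top-level types first, shortened arguments after."""
    # Top-level type names, e.g. "int, str".
    type_names = ", ".join(type(arg).__qualname__ for arg in args)
    # reprlib caps the length of each argument's repr to keep the message readable.
    shortener = reprlib.Repr()
    shortener.maxother = max_repr   # limit for arbitrary objects
    shortener.maxstring = max_repr  # limit for strings
    arg_reprs = ", ".join(shortener.repr(arg) for arg in args)
    return (
        f"For function `{function_name}`, arguments with types `({type_names})` "
        f"could not be resolved. The given arguments are `({arg_reprs})`."
    )


print(format_not_found("update", (1, "a")))
```

Putting the types first means the short part of the message is always readable, while the truncated values remain available for cases like the `Literal` mismatch described earlier.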

femtomc commented 1 year ago

Well, that looks incredible :) Do you need any help from me to support this sort of thing?

Worst case, I can catch the exception with a context manager and implement this myself on the library side.
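A minimal sketch of that library-side workaround. It assumes the caller passes in both the exception class to intercept (in practice plum's `NotFoundLookupError`, whose import path may differ between plum versions) and the arguments of the call being wrapped; it also assumes that exception class accepts a single message string. The helper `types_only_lookup_errors` and the stand-in `DemoLookupError` are hypothetical. Chaining with `from` keeps the original, complete message attached as `__cause__`, so nothing is lost for debugging.

```python
from contextlib import contextmanager


@contextmanager
def types_only_lookup_errors(exc_type, *args):
    """Hypothetical helper: re-raise `exc_type` with only the top-level argument types."""
    try:
        yield
    except exc_type as exc:
        type_names = ", ".join(type(arg).__qualname__ for arg in args)
        # The full original error stays attached as __cause__.
        raise exc_type(f"No method matched argument types ({type_names}).") from exc


class DemoLookupError(LookupError):
    """Stand-in for plum's NotFoundLookupError in this example."""


try:
    with types_only_lookup_errors(DemoLookupError, 1, "a"):
        raise DemoLookupError("a very long message ...")
except DemoLookupError as err:
    print(err)  # No method matched argument types (int, str).
```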

PhilipVinc commented 1 year ago

Technically, you could even go through the candidate signatures to determine whether we need to display the type or the value, and display the value only when it is needed.
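One possible reading of this suggestion, sketched with the standard `typing` helpers: for each argument position, show the value only if some candidate signature places a `Literal[...]` there, because that is where the type alone cannot explain a mismatch. Representing signatures as plain tuples of type hints, and the helper names `needs_value` and `describe_arguments`, are assumptions for illustration, not plum's internals.

```python
from typing import Literal, get_origin


def needs_value(position, signatures):
    """True if any candidate signature uses a Literal type at this position."""
    return any(
        position < len(sig) and get_origin(sig[position]) is Literal
        for sig in signatures
    )


def describe_arguments(args, signatures):
    """Hypothetical: describe each argument by type, adding its value only when needed."""
    parts = []
    for i, arg in enumerate(args):
        name = type(arg).__qualname__
        if needs_value(i, signatures):
            parts.append(f"{name} = {arg!r}")
        else:
            parts.append(name)
    return "(" + ", ".join(parts) + ")"


# Two hypothetical candidate signatures for a two-argument function.
candidates = [(Literal[1], str), (Literal[3], bytes)]
print(describe_arguments((2, "a"), candidates))  # (int = 2, str)
```

This only inspects top-level hints; a `Literal` nested inside a `Union` or another generic would need a recursive check.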

By the way, I'm hoping that @wesselb will soon merge #110, which will greatly improve how errors look, so if he agrees, it might be worth basing a PR on top of that.

wesselb commented 1 year ago

I think building a PR on top of #110 makes a whole lot of sense. :) Yes, I'm prioritising #110 and hope to provide a review and merge in the upcoming days!