TuringLang / DynamicPPL.jl

Implementation of domain-specific language (DSL) for dynamic probabilistic programming
https://turinglang.org/DynamicPPL.jl/
MIT License

Type instability in unflatten for medium / large tuples #600

Open willtebbutt opened 4 months ago

willtebbutt commented 4 months ago

MWE:

```julia
using DynamicPPL, Test

x = randn(11)
original = (x...,)  # 11-element tuple of Float64
@inferred DynamicPPL.unflatten(original, x)  # throws: the return type is not inferred
```

Note that if you change the length of x to 10, this infers correctly.

The use of ntuple in the implementation of this method of unflatten, here, means that inference bails out beyond 10 elements.
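To see the cutoff in isolation, here is a minimal sketch that doesn't touch DynamicPPL at all. double_int and double_val are hypothetical helpers made up for this illustration. The sketch assumes that Base's ntuple(f, n::Integer) unrolls explicitly only up to n = 10, while the ntuple(f, ::Val{N}) method specialises on N for any length.

```julia
using Test

# Hypothetical helper, not DynamicPPL's unflatten. length(t) is a compile-time
# constant for a concrete tuple type, but ntuple(f, n::Integer) only unrolls up to n == 10.
double_int(t::Tuple) = ntuple(i -> 2 * t[i], length(t))

# Hypothetical helper: passing the length as Val(N) moves it into the type domain,
# so ntuple(f, ::Val{N}) can specialise on it.
double_val(t::Tuple{Vararg{Any,N}}) where {N} = ntuple(i -> 2 * t[i], Val(N))

t10 = ntuple(i -> Float64(i), 10)
t11 = ntuple(i -> Float64(i), 11)

@inferred double_int(t10)  # infers NTuple{10, Float64}
@inferred double_val(t11)  # infers NTuple{11, Float64}
# @inferred double_int(t11) is expected to throw: the n > 10 path is not unrolled.
```

Passing the length through Val is one way to stay type stable without writing a generated function yourself. It only helps where the length is already known from the types.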

I believe you'll also hit another inference bailout at length 32, when map and cumsum hit their inference heuristics.
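The map half of this can be checked directly against Base. The following is a minimal sketch. It assumes that Base's map over tuples recurses element by element for shorter tuples and switches to a Vector{Any}-based fallback once a tuple has 32 or more elements. cumsum is not covered here.

```julia
using Test

t31 = ntuple(i -> Float64(i), 31)
t32 = ntuple(i -> Float64(i), 32)

# Below the cutoff, Base's tuple map is fully inferred.
@inferred map(x -> 2x, t31)  # infers NTuple{31, Float64}

# At 32 elements the result is still correct at runtime; only inference is affected.
# @inferred map(x -> 2x, t32) is expected to throw, because the fallback builds the
# result from a Vector{Any}, so its length and element type are not known to inference.
@assert map(x -> 2x, t32) == ntuple(i -> 2.0 * i, 32)
```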

Assuming that in both instances you wish to avoid having inference bail out, you'll need to make use of generated functions. I've had to tackle this a fair bit in Tapir, and I've found that an effective strategy is to write a small number of generated functions which are higher-order, and to make use of them throughout. For example, I define a function called tuple_map, which is basically just map, but restricted to tuples, and which forces the compiler to specialise for any length of tuple. You also need cumsum to specialise, so you'll need something more than just a map replacement of course.
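Here is a minimal sketch of that strategy that relies only on Base. tuple_map and tuple_cumsum below are hypothetical stand-ins written for this illustration, not Tapir's actual definitions. Each generated function emits a fully unrolled expression for the tuple length N, so the compiler specialises on every length rather than bailing out at 10 or 32 elements.

```julia
using Test

# Hypothetical sketch: map restricted to tuples, unrolled for any length N.
# Inside the generator, x is bound to the argument's type; the returned
# expression refers to the runtime arguments f and x.
@generated function tuple_map(f::F, x::Tuple{Vararg{Any,N}}) where {F,N}
    return Expr(:tuple, (:(f(x[$i])) for i in 1:N)...)
end

# Hypothetical sketch: cumulative sum over a tuple, unrolled into
# s1 = x[1]; s2 = s1 + x[2]; ...; (s1, s2, ..., sN).
@generated function tuple_cumsum(x::Tuple{Vararg{Any,N}}) where {N}
    N == 0 && return :(())
    syms = [Symbol(:s, i) for i in 1:N]
    body = Any[:($(syms[1]) = x[1])]
    for i in 2:N
        push!(body, :($(syms[i]) = $(syms[i - 1]) + x[$i]))
    end
    return Expr(:block, body..., Expr(:tuple, syms...))
end

t = ntuple(i -> Float64(i), 40)
@inferred tuple_map(x -> 2x, t)  # well past both the 10- and 32-element cutoffs
@inferred tuple_cumsum(t)
```

The trade-off is that every distinct tuple type gets its own compiled method, and that is where the extra compile time comes from.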

I'd be happy to take a punt at this if you're interested, @torfjelde?

torfjelde commented 4 months ago

I think this is an issue that requires some discussion (preferably involving @devmotion and @yebai too). We've previously been a bit too aggressive with sprinkling @generated everywhere, so in these more recent additions, e.g. unflatten, we've been somewhat deliberate in using things like ntuple and map to avoid exploding compile times.

Now, whether that is indeed the correct approach, is unclear :upside_down_face: For example, I think maybe just 10 arguments is a bit "too few", while I'd be hard pressed to think of models where I can see us ending up with NamedTuple with more than 32 fields :grimacing:

And yeah, regarding tuple_map, etc., that's definitely the approach we'd want to take if so :+1:
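The concern here is NamedTuples with many fields, so the same idea can be carried over by unrolling over the field names instead of the indices. Below is a minimal sketch. namedtuple_map is a hypothetical helper, not part of DynamicPPL or Base.

```julia
using Test

# Hypothetical helper: apply f to every field of a NamedTuple, unrolled over the
# field names, so the result type is known for any number of fields.
@generated function namedtuple_map(f::F, nt::NamedTuple{names}) where {F,names}
    vals = [:(f(getfield(nt, $(QuoteNode(n))))) for n in names]
    return :(NamedTuple{$names}(($(vals...),)))
end

# A NamedTuple with 40 fields (x1, ..., x40), past the 32-field mark.
nt = NamedTuple{ntuple(i -> Symbol(:x, i), 40)}(ntuple(i -> Float64(i), 40))
@inferred namedtuple_map(x -> 2x, nt)
```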

I guess a super-easy approach: make the changes in a PR, and then see how it affects compile times? :shrug:

willtebbutt commented 4 months ago

> Now, whether that is indeed the correct approach, is unclear 🙃 For example, I think maybe just 10 arguments is a bit "too few", while I'd be hard pressed to think of models where I can see us ending up with NamedTuple with more than 32 fields 😬

Yeah, that's fair. I agree it would have to be a rather large model for there to be more than 32 tilde statements. That being said, it's worth asking what we would like to happen when that does occur: do we want inference to fall over, or do we want compilation to take slightly longer?

devmotion commented 4 months ago

My general stance is that NamedTuples, Tuples etc. are just not the right tool for larger collections, and working around this design limitation within a package can be unsuccessful anyway if you start composing the internal methods with external functions since likely they are not optimized for this unintended use case either. I guess, a bit similar to how you wouldn't/shouldn't use StaticArrays for larger arrays.

Of course, the concrete case is slightly different because there might be instances where the default compiler thresholds seem just a tiny bit too low. Even in such cases I try to stay away from @generated functions as much as possible though due to their limitations. I also tend to allow the compiler to use the non-generated alternative by using the if @generated ... pattern instead of writing @generated function ... directly.
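For reference, here is a minimal sketch of the if @generated ... pattern (Julia's optionally-generated functions). tuple_map_opt is a hypothetical example, not code from DynamicPPL. Inside the if @generated branch, the arguments are bound to their types and the branch returns an expression to compile. The else branch is an ordinary implementation that the compiler is free to use instead, so both branches must compute the same result.

```julia
using Test

# Hypothetical sketch of an optionally-generated function.
function tuple_map_opt(f::F, x::Tuple{Vararg{Any,N}}) where {F,N}
    if @generated
        # Here x stands for the argument's type; build (f(x[1]), ..., f(x[N])).
        ex = Expr(:tuple)
        for i in 1:N
            push!(ex.args, :(f(x[$i])))
        end
        ex
    else
        # Ordinary fallback that the compiler may use instead of the generated branch.
        map(f, x)
    end
end

t = ntuple(i -> Float64(i), 40)
@assert tuple_map_opt(x -> 2x, t) == map(x -> 2x, t)
```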

willtebbutt commented 4 months ago

> My general stance is that NamedTuples, Tuples etc. are just not the right tool for larger collections

I agree with this in general. However, I'm unaware of another approach to handling largish collections of heterogeneous data in a type stable way. Is there an alternative that I'm unaware of?

Put differently, what should SimpleVarInfo be doing if it wants type stability for models involving more than 31 tilde statements?

> and working around this design limitation within a package can be unsuccessful anyway if you start composing the internal methods with external functions since likely they are not optimized for this unintended use case either.

I agree that this is something you have to keep on top of.

> Of course, the concrete case is slightly different because there might be instances where the default compiler thresholds seem just a tiny bit too low. Even in such cases I try to stay away from @generated functions as much as possible though due to their limitations.

How should we get around using generated functions in such cases? Again, I might just be missing a good approach here...

sunxd3 commented 4 months ago

Naive question: what is the if @generated ... pattern?

willtebbutt commented 4 months ago

> Naive question: what is the if @generated ... pattern?

See here :)