JuliaLang / julia

The Julia Programming Language
https://julialang.org/
MIT License

Index sets and constant propagation on Array shapes #43642

Open ChrisRackauckas opened 2 years ago

ChrisRackauckas commented 2 years ago

It's not too uncommon for type computations to be determined by the size of arrays. One case is ForwardDiff chunksize computations. Another is the algorithm defaults in DifferentialEquations.jl. While these do not necessarily affect runtime performance because the computations lie behind function barriers, they seem to greatly affect compile-time performance and thus were the impetus for compile-time features such as https://github.com/JuliaLang/julia/pull/43370.

Such usage of max_methods=1 and function barriers could be completely eliminated if constant propagation could perform shape inference, moving chunk size computations to compile time. Going further, it has been shown that having index set information as part of the type can be very powerful for optimizations, parallelization, and automatic differentiation (https://arxiv.org/abs/2104.05372). That said, extending the type parameters of Array would be breaking and thus require a v2.0, while adding constant propagation to array shape calculations would not necessarily be breaking, so I am leaning towards the latter. Of course, it is possible that shape calculations are expensive enough that they would hurt compile time in that case as well.
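A minimal sketch of the value-to-type pattern described above. `MyChunk`, `pick_chunk`, `chunk_dynamic`, and `chunk_static` are hypothetical names for illustration only (not ForwardDiff's API), and the cap of 12 is an assumption.

```julia
# Hypothetical stand-ins; not ForwardDiff's actual types or heuristics.
struct MyChunk{N} end

# Assumed heuristic: use the input length, capped at 12.
pick_chunk(n::Integer) = min(n, 12)

# The length of a Vector is a runtime value, so the type parameter of the
# result is not determined by the argument type alone.
chunk_dynamic(x::AbstractVector) = MyChunk{pick_chunk(length(x))}()

# When the length is part of the type (here an NTuple), the same computation
# depends only on type information.
chunk_static(::NTuple{N,T}) where {N,T} = MyChunk{pick_chunk(N)}()

c1 = chunk_dynamic([1.0, 2.0, 3.0])
c2 = chunk_static((1.0, 2.0, 3.0))
```

Callers of something like `chunk_dynamic` usually pass the result through a function barrier. Shape-aware constant propagation would aim to make the first case behave like the second whenever the array size is known at the call site.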

There are probably other places where shape inference could help as well, such as improving GC performance or making odd compilation targets (like static compilation) easier for shape-inferable code. So I was curious whether this would or could be done as part of the language itself, or whether it would be left to AbstractInterpreter plugins.

oscardssmith commented 2 years ago

Isn't the proper way to do this just to make a new StridedArray subtype?

ChrisRackauckas commented 2 years ago

It would defeat the purpose if it doesn't apply to standard user code. That could be useful for prototyping, but it doesn't solve the real issue.

oscardssmith commented 2 years ago

It would, at the very least, let you know what the effect on compile time was.

jpsamaroo commented 2 years ago

I believe the start of this is https://github.com/JuliaLang/julia/pull/43487 (although I'm not sure if it calculates total shape; it seems like it might just be ndims for now). From there, it's a matter of ensuring that functions like similar, and operations like array slicing, also propagate the sizes of their returns, which is probably rather simple to determine by forwarding allocation shape to length/size calls (which are often used to determine the size of the new output array to allocate).

vchuravy commented 2 years ago

@pchintalapudi and I were discussing this. The reason we are doing it on the LLVM level is that we didn't want to introduce a new lattice type. Effectively the specialness of Array is working against it here. If it were just a normal struct we would get shape propagation for "free" with PartialConst.

AriMKatz commented 2 years ago

make odd compilation targets easier for shape-inferable code

I think particularly for TPUs

@pchintalapudi and I were discussing this. The reason we are doing it on the LLVM level is that we didn't want to introduce a new lattice type. Effectively the specialness of Array is working against it here. If it were just a normal struct we would get shape propagation for "free" with PartialConst.

What's the downside of having a new lattice type?

If we have https://github.com/JuliaLang/julia/issues/40992 then things like that could be explored in user-space.

Is that where https://github.com/JuliaLang/julia/pull/42596 gets us @aviatesk ?

  1. experiment new lattice design with adding new lattice properties (maybe as another PRs)

pchintalapudi commented 2 years ago

I believe the start of this is #43487 (although I'm not sure if it calculates total shape; it seems like it might just be ndims for now)

#43487 calculates and pushes individual dimensions of non-escaping array allocations, as well as the overall array length, to consumers of those values, in addition to a few other array struct fields (maxsize and offset for 1d arrays). There are also a couple of codegen enhancements in that PR to avoid generating any of those accesses for 2d and 3d arrays. However, all of them run after the Julia-level optimization passes and strictly operate on LLVM IR, so the PR won't help inference prior to codegen, nor will it enhance compilation of targets that don't pass through the Julia-LLVM optimization pipeline.

AriMKatz commented 2 years ago

This would be helpful for escape analysis:

When compared to object field analysis, where an access to object field can be analyzed trivially using type information derived by inference, array dimension isn't encoded as type information and so we need an additional analysis to derive that information. EscapeAnalysis at this moment first does an additional simple linear scan to analyze dimensions of allocated arrays before firing up the main analysis routine so that the succeeding escape analysis can precisely analyze operations on those arrays.

However, such precise "per-element" alias analysis is often hard. Essentially, the main difficulty inherent to arrays is that array dimensions and indices are often non-constant:

  1. loops often produce loop-variant, non-constant array indices
  2. (specific to vectors) array resizing changes the array dimension and invalidates its constant-ness

In order to address these difficulties, we need inference to be aware of array dimensions and propagate array dimensions in a flow-sensitive way[ArrayDimension], as well as come up with nice representation of loop-variant values.

EscapeAnalysis at this moment quickly switches to the more imprecise analysis that doesn't track precise index information in cases when array dimensions or indices are trivially non constant.

In Dex, index sets go much further than having concrete shape information for static arrays in the type, so they are not just analogous to StaticArrays.jl. See more here: https://www.youtube.com/watch?v=npDCCVIaSVQ&t=2588s

Tokazama commented 2 years ago

Has anyone proposed types for index sets yet? Even if Array isn't a normal struct we can probably patch it into some system that everyone else can more easily conform to.

AriMKatz commented 2 years ago

@Tokazama Can Julia's type system express index sets (and if it can, without heavy specialization cost)? ...Dex has a very particular design that relies on function types, sum types and the duality between memoized functions and arrays, among other things. It also does algebraic code rewriting.

Check out the examples here: https://discourse.julialang.org/t/is-it-ever-logical-to-have-arrays-without-indexing-diag-a-seems-to-be-such-a-case-logical-conclusion-of-generic-programming/81471/5

In particular

Here is the helper that builds the dynamic program table. Dex's flexible index sets let us encode the fact that the table is 1 larger in each dimension than the inputs. By capturing the relationship statically we avoid both programmer off-by-one errors and runtime array bounds checks.

def levenshtein_table
    {n m a}
    [Eq a]
    (xs: n=>a) (ys: m=>a)
    : (Post n => Post m => Int) =
  yield_state (for _ _. -1) \tab.
    for i:(Post n). tab!i!first_ix := ordinal i
    for j:(Post m). tab!first_ix!j := ordinal j
    for i:n j:m.
      subst_cost = if xs.i == ys.j then 0 else 1
      d_subst  = get tab!(left_post  i)!(left_post  j) + subst_cost
      d_delete = get tab!(left_post  i)!(right_post j) + 1
      d_insert = get tab!(right_post i)!(left_post  j) + 1
      tab!(right_post i)!(right_post j) :=
        minimum [d_subst, d_delete, d_insert]
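For comparison, a plain Julia sketch of the same table (an illustration written for this thread, not a translation of Dex semantics). Here the "1 larger in each dimension" relationship exists only as runtime arithmetic (`n + 1`, `m + 1`), which is the information Dex's `Post n` carries in the type.

```julia
# Build the (n+1) x (m+1) Levenshtein table for two indexable sequences.
function levenshtein_table(xs, ys)
    n, m = length(xs), length(ys)
    tab = Matrix{Int}(undef, n + 1, m + 1)
    tab[:, 1] .= 0:n   # first column: cost of deleting i elements
    tab[1, :] .= 0:m   # first row: cost of inserting j elements
    for j in 1:m, i in 1:n
        subst_cost = xs[i] == ys[j] ? 0 : 1
        d_subst  = tab[i, j] + subst_cost
        d_delete = tab[i, j + 1] + 1
        d_insert = tab[i + 1, j] + 1
        tab[i + 1, j + 1] = min(d_subst, d_delete, d_insert)
    end
    return tab
end

# `collect` turns the strings into Vector{Char} so integer indexing is safe.
tab = levenshtein_table(collect("kitten"), collect("sitting"))
```

Nothing in the type of `tab` records that its size is tied to the inputs, so the `i + 1` and `j + 1` offsets are checked only at runtime.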


also

As of https://github.com/google-research/dex-lang/pull/876, tables themselves can be used as index sets, letting us build index sets whose dimension is determined at runtime, but still tracked by the compiler. For instance, the type (Fin D)=>letter represents the set of all D-letter strings. And the table for i:((Fin D)=>letter). i instantiates all those strings. For another example, ((Fin D)=>(Fin s))=>Float is a table of Floats with s^D elements. Note that the parentheses are important! (Fin a)=>(Fin b)=>c has a * b elements, but ((Fin a)=>(Fin b))=>c has b ^ a elements.
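The two element counts in the quote can be checked with a small Julia sketch. It only counts indices using `Iterators.product`; it does not model Dex's types, and the values of `a` and `b` are arbitrary.

```julia
a, b = 3, 2

# (Fin a)=>(Fin b)=>c : one entry per pair (i, j), so a * b entries.
pair_indices = Iterators.product(1:a, 1:b)

# ((Fin a)=>(Fin b))=>c : one entry per table of length a with values in 1:b,
# i.e. per function from Fin a to Fin b, so b^a entries.
table_indices = Iterators.product(ntuple(_ -> 1:b, a)...)

n_pairs  = length(pair_indices)    # a * b
n_tables = length(table_indices)   # b ^ a
```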

Tokazama commented 2 years ago

@Tokazama Can Julia's type system express index sets (and if it can, without heavy specialization cost)? ...Dex has a very particular design that relies on function types, sum types and the duality between memoized functions and arrays, among other things. It also does algebraic code rewriting.

I'm doing my best to understand the documentation for Dex (still not confident I'm understanding everything I read) but it really seems like what we've been working towards for a while in ArrayInterface already. The whole point of StaticInt was to enable that dichotomy between static and dynamic. If we returned a strictly limited set of types from to_indices on top of this we would be most of the way there. The index interface they have is very similar to what I historically pushed for in the fancy indexing and AxisArrays future conversations (admittedly I lost steam trying to convince others of this though and have stopped pushing for it).

It appears Dex formally converts things to index types with @. We could always have something like @to_indices A[i1, i2, in...], and the macro could dispatch on whether the indices are Int or another expression so users never have to see StaticInt (or whatever other static type is deemed acceptable). So far we've been very focused on providing the infrastructure to support LoopVectorization in the most succinct and flexible way possible (not super concerned about pretty syntax up to this point). It's used elsewhere too but I haven't been keeping track of that lately (I believe @cscherrer has similarly been using it to support the infrastructure of Tilde.jl).
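A rough sketch of what such a macro could look like, under heavy assumptions: `SInt`, `@static_ref`, and `index_tuple` are hypothetical names made up for illustration, `SInt` is only a stand-in for a static integer type such as Static.jl's `StaticInt`, and the sketch handles only literal integers and plain expressions (not `end`, ranges, or splatting).

```julia
# Hypothetical static integer: the value lives in the type parameter.
struct SInt{N} end

# Hypothetical sink that just returns the processed indices for inspection.
index_tuple(A, idxs...) = idxs

# Rewrite A[i, j, ...] so that literal integers become SInt{literal}()
# while every other index expression is passed through unchanged.
macro static_ref(ex)
    ex isa Expr && ex.head === :ref || error("expected A[i, j, ...]")
    A = esc(ex.args[1])
    idxs = map(ex.args[2:end]) do i
        i isa Integer ? :(SInt{$i}()) : esc(i)
    end
    return :(index_tuple($A, $(idxs...)))
end

A = rand(3, 3)
k = 2
t = @static_ref A[1, k]   # literal 1 becomes static, k stays dynamic
```

The user writes ordinary indexing syntax, and the split between static and dynamic indices happens at macro expansion time, so the static type never appears in user code.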

In summary, I think we can have robust index types and if I've correctly understood the Dex prelude and tutorial we already have the basic building blocks. We just need to fill in the gaps with some types and think about a clean user interface.

cscherrer commented 2 years ago

I believe @cscherrer has similarly been using it to support the infrastructure of Tilde.jl

Not Tilde so much, but MeasureTheory uses Static.jl quite a bit. I could go into this, but a lot of it is off-topic from the OP.

It seems like we're basically talking about SizedArrays, or maybe something similar where we only know some subset of the axes. Is that right? I think this is an important direction.

One place I see this coming up... Say I have a vector, and I want to say "give me something to hold a vector of these". If the size is known statically, I might want this represented using ArraysOfArrays, so under the hood it's just a matrix. But that's a different type than Vector{<:Vector}, so for dispatch to work properly I'd need the size information for a single vector to be part of its type.
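The dispatch problem can be sketched as follows. `make_holder` is a hypothetical helper, and an `NTuple` stands in for "a vector whose length is part of its type"; the sketch uses plain `Matrix` storage rather than ArraysOfArrays.

```julia
# Length known from the type: back `k` elements with one dense matrix,
# one column per element.
make_holder(::NTuple{N,T}, k::Integer) where {N,T} = Matrix{T}(undef, N, k)

# Length only known at runtime: fall back to a vector of vectors.
make_holder(x::Vector{T}, k::Integer) where {T} = [similar(x) for _ in 1:k]

h_static  = make_holder((1.0, 2.0, 3.0), 4)
h_dynamic = make_holder([1.0, 2.0, 3.0], 4)
```

The two methods can only be told apart because the tuple carries its length in its type; with a plain `Vector` on both sides there would be nothing for dispatch to select on.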

AriMKatz commented 2 years ago

@Tokazama Great! Hopefully we can see some settling in of those ideas. By the way, if you ever have any questions, feel free to ask Dougal or Adam. They are super responsive on Twitter (https://twitter.com/apaszke https://twitter.com/DougalMaclaurin) and in GitHub discussions.

@cscherrer Yeah, but I think this is far more general than sized arrays. It can express statically known or unknown sizes, statically sized named dims (with the names in the type, so type-checked table ops), dims indexed by other arrays, static jagged arrays, nested arrays, etc.

Generator type things of all of the above but expressed in the type domain.

I'm not sure how they can get away without tons and tons of codegen. I think they just know more about the possible codepaths?

If the size is known statically, I might want this represented using ArraysOfArrays, so under the hood it's just a matrix.

Exactly! Despite all this abstraction, they claim that knowing the structure in the type allows the compiler to represent all that as a dense contiguous memory region

Edit: it's also not just the array shape, but the array shape is defined by the indices which are in the type