stan-dev / stanc3

The Stan transpiler (from Stan to C++ and beyond).
BSD 3-Clause "New" or "Revised" License

[FR] Allow SoA for UDFs #1237

Open SteveBronder opened 2 years ago

SteveBronder commented 2 years ago

Is your feature request related to a problem? Please describe.

It would be nice to extend the Struct of Arrays (SoA) framework to support user-defined functions (UDFs) so that SoA types can be used in reduce_sum and other higher-order functions.

Right now, if a user calls a higher-order function, we have to demote every matrix / vector passed to that function to Array of Structs (AoS). This is unfortunate, since reduce_sum is very powerful for large independent blocks of data and parameters.
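For readers less familiar with the two layouts, here is a rough Python sketch of the difference and of what "demoting" means. It is only an illustration, not Stan Math code: `Var`, `SoAVector`, and `demote` are hypothetical names made up for this example.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Var:
    """One autodiff scalar: a value paired with its adjoint."""
    value: float
    adjoint: float = 0.0


@dataclass
class SoAVector:
    """Struct of Arrays: one container of values, one of adjoints."""
    values: List[float]
    adjoints: List[float]


def demote(v: SoAVector) -> List[Var]:
    """Convert SoA to AoS: build one Var per element (hypothetical helper)."""
    return [Var(x, a) for x, a in zip(v.values, v.adjoints)]


soa = SoAVector(values=[1.0, 2.0], adjoints=[0.0, 0.0])
aos = demote(soa)  # a list of per-element (value, adjoint) structs
```

The AoS form interleaves values and adjoints element by element, while the SoA form keeps each in its own contiguous block.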

Describe the solution you'd like

I think we can do this with the following steps:

  1. During the SoA optimization pass, when the optimizer hits a UDF or a higher-order function, it starts a sub-call of the SoA optimization for the UDF. The sub-call just returns the list of inputs that cannot be SoA, and the larger optimization pass then continues.
  2. At the end of the SoA optimization pass, the compiler runs another pass over the program to collect which matrices are SoA. When it reaches a UDF call, it looks at the memory type (either SoA or AoS) of each argument at that call and appends that set of argument memory types to a list in the UDF's meta record. Each UDF defined in the functions block then knows which combinations of AoS and SoA arguments it needs to generate.
  3. When the compiler starts printing out the C++, it goes through each UDF's list of memory patterns and generates a signature and body for each.
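The three steps above can be sketched roughly as follows. This is a toy model in Python, not stanc3 code: the `Udf` record, the `NO_SOA_OPS` set, and all function names are hypothetical, and a real pass would likely need to iterate to a fixed point rather than make a single sweep.

```python
from dataclasses import dataclass, field

SOA, AOS = "SoA", "AoS"

# Assumption for illustration: operations that lack SoA support.
NO_SOA_OPS = {"single_index_loop"}


@dataclass
class Udf:
    name: str
    params: list                                # parameter names, in order
    body: list                                  # (operation, parameter) pairs
    patterns: set = field(default_factory=set)  # the "meta record" of step 2


def blocked_params(udf):
    """Step 1 sub-call: positions of parameters that cannot be SoA."""
    bad = {p for op, p in udf.body if op in NO_SOA_OPS}
    return [i for i, p in enumerate(udf.params) if p in bad]


def optimize(udfs, calls, candidates):
    """calls: (udf name, caller variable names); candidates: SoA hopefuls."""
    soa_vars = set(candidates)
    # Step 1: demote caller variables passed to parameters that block SoA.
    for fname, args in calls:
        for i in blocked_params(udfs[fname]):
            soa_vars.discard(args[i])
    # Step 2: record each call's argument memory pattern on the UDF.
    for fname, args in calls:
        pattern = tuple(SOA if a in soa_vars else AOS for a in args)
        udfs[fname].patterns.add(pattern)
    return soa_vars


def emit_signatures(udf):
    """Step 3: one signature (and, in a real compiler, body) per pattern."""
    return [
        f"{udf.name}({', '.join(f'{m} {p}' for m, p in zip(pat, udf.params))})"
        for pat in sorted(udf.patterns)
    ]


f = Udf("f", ["x", "y"], [("multiply", "x"), ("single_index_loop", "y")])
udfs = {"f": f}
calls = [("f", ["A", "B"]), ("f", ["D", "B"])]
soa_vars = optimize(udfs, calls, candidates={"A", "B"})
signatures = emit_signatures(f)
```

Here `B` is demoted because `y` is used in an operation assumed not to support SoA, and `f` ends up with two overloads, one for each argument pattern actually seen at its call sites.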

I think the above will work? It sounds like it's only 3 steps but there's a lot of little things to do in all of those.

WardBrian commented 2 years ago

This seems reasonable. Part 3 shouldn't be too bad, since we already have code generation for overloads working. Something like #1233 would make it even easier, I think.

Since we have the inliner working, could we cut down on the amount of work the optimizer needs to do by only running step 1 for functions that are used in higher-order functions?
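A rough sketch of that filtering idea, in Python rather than stanc3's own code. It assumes that directly called UDFs are handled by the inliner, and that the function being applied is the first argument of the higher-order call, as with reduce_sum and map_rect. The `HIGHER_ORDER` set and the helper name are hypothetical.

```python
# Assumption: only a sample of Stan's higher-order functions is listed here.
HIGHER_ORDER = {"reduce_sum", "map_rect"}


def udfs_needing_analysis(calls):
    """Return the UDF names passed as the first argument of a
    higher-order call; only these would need the step 1 sub-call.

    calls: list of (callee name, list of argument names).
    """
    return {
        args[0]
        for callee, args in calls
        if callee in HIGHER_ORDER and args
    }


calls = [
    ("reduce_sum", ["partial_sum", "y", "grainsize", "mu"]),
    ("my_helper", ["x"]),  # direct call: assumed to be inlined
    ("map_rect", ["shard_lp", "phi", "thetas", "x_rs", "x_is"]),
]
needed = udfs_needing_analysis(calls)
```

With this filter, `my_helper` is skipped and only `partial_sum` and `shard_lp` are analyzed.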