stan-dev / stanc3

The Stan transpiler (from Stan to C++ and beyond).
BSD 3-Clause "New" or "Revised" License

Updates the model to have a separate log prob for reverse mode #1327

Closed SteveBronder closed 1 year ago

SteveBronder commented 1 year ago

Submission Checklist

Release notes

This adds a log_prob_impl specialization specifically for reverse mode. The main reason for a separate reverse-mode log_prob_impl is to simplify code generation for Struct of Arrays (SoA) matrices and, hopefully soon, for matrices that can live on the GPU.

Instead of generating a conditional_var_value_t<...> type trait inside log_prob to decide whether SoA can be used, the reverse-mode specialization lets us generate a stan::math::var_value<...> directly whenever a matrix is SoA. We will then be able to do the same for matrices tagged for the GPU.
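A minimal sketch of the dispatch idea, under simplifying assumptions: `var`, `require_var_t`, and `require_not_var_t` below are hypothetical stand-ins for `stan::math::var` and Stan Math's `require_st_var` / `require_not_st_var` traits, and the real generated `log_prob_impl` has a different signature. Two overloads constrained with `std::enable_if_t` let the compiler pick the reverse-mode one whenever the scalar type is `var`, so a code generator could write reverse-mode-only types into that body unconditionally.

```cpp
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for stan::math::var (NOT the real Stan Math type).
struct var {
  double val;
};

// Hypothetical stand-ins for require_st_var / require_not_st_var.
// Simplification: T itself is treated as the scalar type.
template <typename T>
using require_var_t = std::enable_if_t<std::is_same<T, var>::value>*;
template <typename T>
using require_not_var_t = std::enable_if_t<!std::is_same<T, var>::value>*;

// General overload: chosen for every scalar type that is not var.
// It only returns a tag naming the overload that was picked.
template <typename T, require_not_var_t<T> = nullptr>
std::string log_prob_impl(const std::vector<T>& params) {
  (void)params;
  return "general";
}

// Reverse-mode overload: chosen only when T is var, so this body can use
// reverse-mode-only types (e.g. SoA matrices) without a conditional trait.
template <typename T, require_var_t<T> = nullptr>
std::string log_prob_impl(const std::vector<T>& params) {
  (void)params;
  return "reverse";
}
```

If both overloads carried the same `require_var_t` constraint, as in the copy/paste slip pointed out later in the thread, the two templates would collide and a `double` call would no longer find the general overload.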

Copyright and Licensing

By submitting this pull request, the copyright holder is agreeing to license the submitted work under the BSD 3-clause license (https://opensource.org/licenses/BSD-3-Clause)

SteveBronder commented 1 year ago

@WardBrian I'll ping you when this is ready for review

codecov[bot] commented 1 year ago

Codecov Report

Merging #1327 (edc2a2b) into master (4d6664c) will increase coverage by 0.00%. The diff coverage is 97.50%.

@@           Coverage Diff           @@
##           master    #1327   +/-   ##
=======================================
  Coverage   88.79%   88.80%           
=======================================
  Files          64       64           
  Lines        9844     9860   +16     
=======================================
+ Hits         8741     8756   +15     
- Misses       1103     1104    +1

Impacted Files Coverage Δ
src/analysis_and_optimization/Memory_patterns.ml 87.76% <ø> (ø)
src/frontend/Ast_to_Mir.ml 92.25% <ø> (ø)
src/middle/Program.ml 67.53% <75.00%> (+0.40%) :arrow_up:
...c/analysis_and_optimization/Dependence_analysis.ml 100.00% <100.00%> (ø)
src/analysis_and_optimization/Optimize.ml 92.69% <100.00%> (+0.04%) :arrow_up:
src/stan_math_backend/Cpp.ml 86.12% <100.00%> (+0.18%) :arrow_up:
src/stan_math_backend/Lower_expr.ml 92.47% <100.00%> (-0.05%) :arrow_down:
src/stan_math_backend/Lower_program.ml 99.09% <100.00%> (+0.01%) :arrow_up:
src/stan_math_backend/Transform_Mir.ml 94.51% <100.00%> (+0.01%) :arrow_up:

SteveBronder commented 1 year ago

@WardBrian Alright I think this is ready!

Using an optional type for reverse_mode_log_prob led to some goofy logic branches further down in code generation. So instead I start reverse_mode_log_prob as an empty list, and in transform_mir I assign the transformed log_prob to it. From there it just passes through the optimizations and code gen.
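The empty-list approach can be sketched roughly as follows. This is a hypothetical C++ analogy, not stanc3's actual OCaml code: `Program`, `transform_mir`, `optimize`, and `optimize_program` are simplified stand-ins, and statements are plain strings. Because `reverse_mode_log_prob` is always a list (possibly empty), later passes treat it exactly like `log_prob`, with no "is it present?" branch.

```cpp
#include <string>
#include <vector>

// Hypothetical, heavily simplified program representation.
struct Program {
  std::vector<std::string> log_prob;
  // Starts empty rather than optional, so passes never check for absence.
  std::vector<std::string> reverse_mode_log_prob;
};

// Hypothetical transform step: copy the transformed log_prob over.
Program transform_mir(Program p) {
  p.reverse_mode_log_prob = p.log_prob;
  return p;
}

// Hypothetical optimization: drop no-op "skip" statements.
std::vector<std::string> optimize(const std::vector<std::string>& stmts) {
  std::vector<std::string> out;
  for (const auto& s : stmts) {
    if (s != "skip") out.push_back(s);
  }
  return out;
}

// The same pass is applied uniformly to both blocks.
Program optimize_program(Program p) {
  p.log_prob = optimize(p.log_prob);
  p.reverse_mode_log_prob = optimize(p.reverse_mode_log_prob);
  return p;
}
```

An optional field (e.g. `std::optional<std::vector<std::string>>`) would instead force every pass to check whether the block exists before touching it, which is the kind of branching the comment above describes avoiding.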

This always produces a reverse-mode specialization of log_prob, which IMO is fine. Once this is in, there are several smaller optimizations on the lir I'd like to look at. One example is typing the parameters as auto instead of concrete types, which would let Eigen matrices and vectors of parameters stay as Eigen::Map<> types that never need to be copied by the math library.
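The copy-avoidance argument can be illustrated without Eigen. In this hypothetical sketch, `VecView` plays the role of `Eigen::Map<>` (a non-owning view) and `std::vector<double>` plays the role of an owning Eigen vector; neither is the real Eigen API. A function that takes the concrete owning type forces a copy through the implicit conversion, while a template (the same mechanism a C++20 `auto` parameter uses) accepts the view as-is.

```cpp
#include <cstddef>
#include <vector>

// Counts how many times a view was copied into an owning vector.
int g_copy_count = 0;

// Hypothetical non-owning view standing in for Eigen::Map<>.
struct VecView {
  const double* ptr;
  std::size_t n;
  std::size_t size() const { return n; }
  double operator[](std::size_t i) const { return ptr[i]; }
  // Converting to an owning vector copies every element.
  operator std::vector<double>() const {
    ++g_copy_count;
    return std::vector<double>(ptr, ptr + n);
  }
};

// Concrete parameter type: passing a VecView triggers a copy.
double sum_concrete(const std::vector<double>& x) {
  double s = 0.0;
  for (double v : x) s += v;
  return s;
}

// Generic parameter type: the view is used directly, no copy.
template <typename VecT>
double sum_generic(const VecT& x) {
  double s = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) s += x[i];
  return s;
}
```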

WardBrian commented 1 year ago

It's worth grepping around for existing usages of log_prob that need to be updated. For example, the --debug-mem-patterns flag code:

https://github.com/stan-dev/stanc3/blob/0b1d1cbeb0d4e5cdc1b732bb011340081a3e6a6c/src/analysis_and_optimization/Memory_patterns.ml#L653-L662

SteveBronder commented 1 year ago

@WardBrian ty, I think I got everything! I grepped for log_prob and I think I covered every usage.

WardBrian commented 1 year ago

Seems like a copy/paste error led to both overloads now being constrained with require_st_var.

WardBrian commented 1 year ago

Looks good to me. Do you want to squash out any of the debugging commits before a merge? I'll leave it up to you