lawmurray / Birch

A probabilistic programming language that combines automatic differentiation, automatic marginalization, and automatic conditioning within Monte Carlo methods.
https://birch.sh
Apache License 2.0

Use Welford's online algorithm in tests #20

Closed devmotion closed 2 years ago

devmotion commented 2 years ago

While working on and debugging https://github.com/lawmurray/Birch/pull/18, I noticed that the location and spread of the Student's t proposals in test_z are computed with the "naive" algorithm, which is prone to catastrophic cancellation. Initially I thought this could be the reason for the test errors in #18, but it turned out that it was not (rather, NaN was not handled correctly in the initial commits). Nevertheless, I thought it could be useful to improve the numerical stability of the tests, so I put these local changes in a separate PR. Additionally, the PR uses expm1 in the computation of the deviation δ to improve numerical stability there as well.
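For readers unfamiliar with the two numerical issues mentioned above, here is a minimal Python sketch. It is not Birch code and not the actual test_z implementation. It contrasts the naive one-pass variance formula with Welford's online update, and it compares exp(x) - 1 with expm1(x) for a small x. The function names and the sample data are illustrative assumptions. The exact form of δ in test_z is not reproduced here; the sketch only assumes that δ involves exp(·) - 1 of a small argument.

```python
import math

def naive_mean_var(xs):
    # "Naive" one-pass formula: Var = E[x^2] - (E[x])^2.
    # When the mean is large relative to the spread, both terms are huge and
    # nearly equal, so their difference suffers catastrophic cancellation.
    n = len(xs)
    mean = sum(xs) / n
    mean_sq = sum(x * x for x in xs) / n
    return mean, mean_sq - mean * mean

def welford_mean_var(xs):
    # Welford's online algorithm: update the running mean and the running sum
    # of squared deviations (m2) one observation at a time.
    n = 0
    mean = 0.0
    m2 = 0.0
    for x in xs:
        n += 1
        delta = x - mean          # deviation from the old mean
        mean += delta / n
        m2 += delta * (x - mean)  # old deviation times new deviation
    # Population variance; use m2 / (n - 1) for the unbiased sample variance.
    return mean, m2 / n

# Data with a large offset and a small spread (true population variance 22.5).
data = [1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16]
print(naive_mean_var(data))    # variance is not 22.5 because of cancellation
print(welford_mean_var(data))  # (1000000010.0, 22.5)

# expm1(x) computes exp(x) - 1 without first rounding exp(x) to a number near 1.
x = 1e-10
print(math.exp(x) - 1.0)  # only about 7 correct significant digits
print(math.expm1(x))      # accurate to nearly full double precision
```

The data set is the classic large-offset example: every value shares the offset 1e9. In the naive formula, the squares are near 1e18, where adjacent doubles are 128 apart. As a result, the small variance cannot be represented in the difference.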

codecov[bot] commented 2 years ago

Codecov Report

Merging #20 (8a3b2d0) into numeric (92269ac) will decrease coverage by 0.13%. The diff coverage is 100.00%.

@@             Coverage Diff             @@
##           numeric      #20      +/-   ##
===========================================
- Coverage    81.24%   81.11%   -0.13%     
===========================================
  Files          446      435      -11     
  Lines        18272    17855     -417     
===========================================
- Hits         14845    14483     -362     
+ Misses        3427     3372      -55     
Impacted Files Coverage Δ
tests/Test/src/test_z.birch 100.00% <100.00%> (ø)
...braries/Standard/src/io/MatrixBufferIterator.birch 0.00% <0.00%> (-100.00%) ↓
birch/src/statement/Factor.cpp 14.28% <0.00%> (-85.72%) ↓
...braries/Standard/src/container/EmptyIterator.birch 0.00% <0.00%> (-66.67%) ↓
libraries/Standard/src/primitive/filesystem.birch 63.33% <0.00%> (-20.01%) ↓
libraries/Standard/src/container/Array.birch 58.90% <0.00%> (-19.18%) ↓
libraries/Standard/src/primitive/matrix.birch 85.71% <0.00%> (-14.29%) ↓
libraries/Standard/src/container/RaggedArray.birch 62.50% <0.00%> (-11.03%) ↓
libraries/Standard/src/event/handle.birch 91.42% <0.00%> (-8.58%) ↓
libraries/Standard/src/io/YAMLWriter.birch 64.23% <0.00%> (-6.57%) ↓
... and 24 more


lawmurray commented 2 years ago

Incidentally, on this theme, are there parallel versions of this and the log_sum_exp algorithm (e.g. using prefix scans)? It seems there should be. They would be useful for the CUDA backend of the new NumBirch library, if that's of interest.

devmotion commented 2 years ago

Both algorithms can be parallelized by processing subsets in parallel and merging the partial results. Wikipedia describes the parallel version of Welford's algorithm. Similarly, the parallel version of the log_sum_exp algorithm is actually a generalization of the sequential one. The initial commits in #18 used transform_reduce with the more general operations (e.g. https://github.com/lawmurray/Birch/pull/18/commits/d206a76b3e1e137415b491c3c54d40c80c743a93). I later changed this to a for loop, since it seemed that transform_reduce would use a for loop anyway (https://github.com/lawmurray/Birch/blob/92269ac3be4c45b9574df4bc2eb0c9c163c1df35/libraries/Standard/src/primitive/primitive.birch#L230-L237).
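The merge-based approach described above can be sketched in Python. This is an illustration under stated assumptions, not Birch or NumBirch code.

Each algorithm keeps a small partial state, and a single observation maps to a state. An associative merge then combines states computed on disjoint chunks. The Welford merge uses the pairwise update of Chan et al., as given on Wikipedia's variance page. The (running max, rescaled sum) state for log_sum_exp is one common formulation; it is assumed here rather than taken from #18. The names welford_merge, lse_merge, and chunked_reduce are hypothetical.

```python
import math
from functools import reduce
from itertools import accumulate

# --- Welford: a partial state is (count, mean, m2) for one subset of the data ---
WELFORD_EMPTY = (0, 0.0, 0.0)

def welford_state(x):
    return (1, float(x), 0.0)

def welford_merge(a, b):
    # Combine the statistics of two disjoint subsets (Chan et al.'s pairwise
    # update). Associative up to floating-point rounding.
    na, mean_a, m2a = a
    nb, mean_b, m2b = b
    if na == 0:
        return b
    if nb == 0:
        return a
    n = na + nb
    delta = mean_b - mean_a
    mean = mean_a + delta * nb / n
    m2 = m2a + m2b + delta * delta * na * nb / n
    return (n, mean, m2)

# --- log_sum_exp: a partial state is (m, s), representing m + log(s) ---
LSE_EMPTY = (-math.inf, 0.0)

def lse_state(x):
    return (float(x), 1.0)

def lse_merge(a, b):
    # Rescale both partial sums to the larger maximum before adding, so that
    # exp() is only ever called on non-positive arguments.
    ma, sa = a
    mb, sb = b
    if ma == -math.inf:
        return b
    if mb == -math.inf:
        return a
    m = max(ma, mb)
    return (m, sa * math.exp(ma - m) + sb * math.exp(mb - m))

def lse_value(state):
    m, s = state
    return -math.inf if s == 0.0 else m + math.log(s)

def chunked_reduce(xs, make_state, merge, empty, chunks=4):
    # Sequential stand-in for the parallel pattern: each chunk could be reduced
    # by a separate worker, and the partial states merged afterwards.
    size = max(1, math.ceil(len(xs) / chunks))
    partials = [reduce(merge, map(make_state, xs[i:i + size]), empty)
                for i in range(0, len(xs), size)]
    return reduce(merge, partials, empty)

data = [1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16]
n, mean, m2 = chunked_reduce(data, welford_state, welford_merge, WELFORD_EMPTY, chunks=2)
print(mean, m2 / n)  # 1000000010.0 22.5

big = [1000.0, 1000.0]  # math.exp(1000.0) on its own would raise OverflowError
print(lse_value(chunked_reduce(big, lse_state, lse_merge, LSE_EMPTY)))  # 1000 + log(2)

# Because the merge is associative, it also defines an inclusive prefix scan.
running = [lse_value(s) for s in accumulate(map(lse_state, [0.0, 1.0, 2.0]), lse_merge)]
print(running)
```

The sequential log_sum_exp update is simply lse_merge(state, lse_state(x)). This is the sense in which the parallel version generalizes the sequential one. Since both merges are associative up to rounding, the same operators could in principle drive a tree reduction or a parallel prefix scan on a GPU. Here, itertools.accumulate only demonstrates the scan semantics sequentially.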

lawmurray commented 2 years ago

Thanks @devmotion, merged to numeric and back to master.