facebookincubator / flowtorch

This library would form a permanent home for reusable components for deep probabilistic programming. It would build and harness a community of users and contributors by initially focusing on complete infrastructure and documentation for using and creating components.
https://flowtorch.ai
MIT License
300 stars, 21 forks

Fixed bug in `bijectors.ops.spline.Spline` and unit test for log(det(J)) #100

Closed: stefanwebb closed this pull request 2 years ago

stefanwebb commented 2 years ago

Motivation

While working on something else, I discovered that Neural Spline Flows were not training as expected: they were unable to learn even simple toy distributions. This PR fixes the bug and also fixes the reason the unit tests did not catch it.

Changes proposed

I flipped the sign of the log(det(J)) returned by the inverse method of `bijectors.ops.spline.Spline`. I also made sure the unit tests no longer use log(det(J)) values cached on the `BijectiveTensor` when comparing the log(det(J)) from the forward method with that from the inverse method.
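Why the sign matters: the Jacobian of the inverse map is the inverse of the forward Jacobian, so at corresponding points log(det(J)) of the inverse must equal the negation of the forward value. The rational-quadratic spline used by Neural Spline Flows is too long to reproduce here, so the sketch below illustrates the convention with a simpler, swapped-in monotone piecewise-linear spline on [0, 1]. The `LinearSpline` class is hypothetical and is not flowtorch's `Spline`.

```python
import bisect
import math


def _cumsum(values):
    """Return [0, v0, v0 + v1, ...], i.e. the knot positions of the spline."""
    out = [0.0]
    for v in values:
        out.append(out[-1] + v)
    return out


class LinearSpline:
    """Hypothetical monotone piecewise-linear bijection of [0, 1] onto [0, 1].

    `widths` and `heights` are positive bin sizes that each sum to 1.
    """

    def __init__(self, widths, heights):
        self.widths = widths
        self.heights = heights
        self.x_knots = _cumsum(widths)
        self.y_knots = _cumsum(heights)

    def _bin(self, knots, v):
        # Index k with knots[k] <= v < knots[k + 1], clamped to a valid bin.
        return min(max(bisect.bisect_right(knots, v) - 1, 0), len(self.widths) - 1)

    def forward(self, x):
        k = self._bin(self.x_knots, x)
        slope = self.heights[k] / self.widths[k]
        y = self.y_knots[k] + (x - self.x_knots[k]) * slope
        return y, math.log(slope)  # log|dy/dx|

    def inverse(self, y):
        k = self._bin(self.y_knots, y)
        slope = self.heights[k] / self.widths[k]
        x = self.x_knots[k] + (y - self.y_knots[k]) / slope
        # log|dx/dy| = -log|dy/dx|: returning +log(slope) here is the sign bug.
        return x, -math.log(slope)


spline = LinearSpline(widths=[0.2, 0.5, 0.3], heights=[0.5, 0.25, 0.25])
y, ldj_fwd = spline.forward(0.4)
x, ldj_inv = spline.inverse(y)
print(round(x, 6), round(ldj_fwd + ldj_inv, 6))  # 0.4 0.0
```

The round trip recovers x, and the two log-determinants cancel exactly. With the sign flipped in `inverse`, they would instead sum to 2 * log(0.5) in this bin, and the change-of-variables density would be wrong wherever the slope differs from 1.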

Test Plan

Run `pytest tests/` and try the Neural Spline Flow example in the theory tutorials.
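The PR attributes the missed bug to the tests comparing against a cached log(det(J)). One way to make such a check independent of any cache is to recompute both log-determinants from scratch and also compare the inverse one against a finite-difference estimate. The sketch below does this for a toy bijection y = sinh(x) standing in for the spline. The functions `forward`, `inverse`, `buggy_inverse`, and `check_inverse` are hypothetical illustrations and are not part of flowtorch or its test suite.

```python
import math


def forward(x):
    """Toy bijection y = sinh(x); returns (y, log|dy/dx|)."""
    return math.sinh(x), math.log(math.cosh(x))


def inverse(y):
    """Inverse x = asinh(y); returns (x, log|dx/dy|) = (x, -log cosh(x))."""
    x = math.asinh(y)
    return x, -math.log(math.cosh(x))


def buggy_inverse(y):
    """Same inverse, but with the sign of the log-determinant flipped."""
    x = math.asinh(y)
    return x, math.log(math.cosh(x))


def finite_diff_logdet(fn, y, eps=1e-6):
    """Central-difference estimate of log|d fn(y)/dy| (fn returns (value, ldj))."""
    deriv = (fn(y + eps)[0] - fn(y - eps)[0]) / (2 * eps)
    return math.log(abs(deriv))


def check_inverse(inverse_fn, xs, tol=1e-5):
    """Return True if inverse_fn's log-det is consistent at every x in xs.

    Both log-determinants are recomputed from scratch on each call, so no
    value cached during the forward pass can mask an error in the inverse.
    """
    for x in xs:
        y, ldj_fwd = forward(x)
        x_rec, ldj_inv = inverse_fn(y)
        if abs(x_rec - x) > tol:
            return False
        # The inverse log-det must be the negation of the forward one...
        if abs(ldj_fwd + ldj_inv) > tol:
            return False
        # ...and must agree with a numerical derivative of the inverse map.
        if abs(ldj_inv - finite_diff_logdet(inverse_fn, y)) > 1e-4:
            return False
    return True


xs = [-2.0, -0.5, 0.3, 1.7]
print(check_inverse(inverse, xs))        # True
print(check_inverse(buggy_inverse, xs))  # False
```

With the sign flipped, the check fails at the first sample point. By contrast, a comparison that reads back the forward pass's cached value never exercises the inverse's log(det(J)) formula at all.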

codecov-commenter commented 2 years ago

Codecov Report

Merging #100 (268ac28) into main (45e91ca) will increase coverage by 0.01%. The diff coverage is 100.00%.


@@            Coverage Diff             @@
##             main     #100      +/-   ##
==========================================
+ Coverage   98.23%   98.25%   +0.01%     
==========================================
  Files           6        6              
  Lines         227      229       +2     
==========================================
+ Hits          223      225       +2     
  Misses          4        4              
| Flag | Coverage Δ |
|---|---|
| unittests | 98.25% <100.00%> (+0.01%) ↑ |

Flags with carried-forward coverage are not shown.

| Impacted Files | Coverage Δ |
|---|---|
| tests/test_bijector.py | 100.00% <100.00%> (ø) |


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 45e91ca...268ac28.

facebook-github-bot commented 2 years ago

@stefanwebb has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.