TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Replace NonlinearSolve with Roots and fix the ReverseDiff adjoint of `find_alpha` #202

Closed: devmotion closed this 2 years ago

devmotion commented 2 years ago

This PR replaces NonlinearSolve with Roots for solving the root-finding problem in the inverse of the planar layer.
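
For context, here is a minimal sketch of where the root-finding problem comes from. It assumes the usual planar-layer parameterisation y = z + û tanh(wᵀz + b) with wᵀû ≥ -1. Projecting onto w gives a scalar equation wᵀy = α + wᵀû tanh(α + b) for α = wᵀz, after which z = y - û tanh(α + b). The names `find_alpha_sketch`, `planar`, and `planar_inverse_sketch` are hypothetical and not Bijectors' API. The sketch uses a hand-written bisection to stay dependency-free, whereas the PR itself calls Roots.

```julia
# Hypothetical helper (not Bijectors' API): solve
#     wt_y = α + wt_u_hat * tanh(α + b)
# for α by bisection. Assumes wt_u_hat >= -1, so the right-hand side is
# monotone in α and the root is unique.
function find_alpha_sketch(wt_y::Float64, wt_u_hat::Float64, b::Float64;
                           tol::Float64 = 1e-12, maxiter::Int = 200)
    g(α) = α + wt_u_hat * tanh(α + b) - wt_y
    # Since |tanh| <= 1, the root lies in [wt_y - |wt_u_hat|, wt_y + |wt_u_hat|].
    # Doubling the half-width puts both endpoints strictly on either side of it.
    Δ = 2 * abs(wt_u_hat)
    lo, hi = wt_y - Δ, wt_y + Δ
    lo == hi && return lo          # wt_u_hat == 0, so α == wt_y exactly
    glo = g(lo)                    # negative by construction
    for _ in 1:maxiter
        mid = (lo + hi) / 2
        gmid = g(mid)
        gmid == 0 && return mid
        if sign(gmid) == sign(glo)
            lo, glo = mid, gmid    # root is in the upper half
        else
            hi = mid               # root is in the lower half
        end
        hi - lo <= tol && break
    end
    return (lo + hi) / 2
end

# Planar layer y = z + u * tanh(w'z + b); u plays the role of û here.
planar(z, w, u, b) = z .+ u .* tanh(w' * z + b)

# Inverse: project onto w to get a scalar problem for α = w'z,
# then recover z = y - u * tanh(α + b).
function planar_inverse_sketch(y, w, u, b)
    α = find_alpha_sketch(w' * y, w' * u, b)
    return y .- u .* tanh(α + b)
end

z = [1.0, -2.0]; w = [0.5, -0.3]; u = [0.8, 0.2]; b = 0.1   # w'u = 0.34 >= -1
y = planar(z, w, u, b)
println(planar_inverse_sketch(y, w, u, b))   # close to [1.0, -2.0]
```

With Roots.jl, the hand-written loop would presumably become a call along the lines of `Roots.find_zero(g, (lo, hi), Roots.Bisection())`. A bracketing method fits well here because |tanh| ≤ 1 yields a bracket that is guaranteed to contain the root.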

Some weeks ago I fixed some type inference problems in Roots (https://github.com/JuliaMath/Roots.jl/pull/245). Earlier, https://github.com/TuringLang/Bijectors.jl/pull/126 had already fixed the convergence and AD issues of this algorithm by using an improved initial bracket and custom adjoints. (By the way, the ReverseDiff adjoint was never used and was broken; this PR fixes it as well.) So the motivation and reasons for https://github.com/TuringLang/Bijectors.jl/pull/155 (moving from Roots to NonlinearSolve) no longer apply.
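
Custom adjoints for a root solve are typically derived with the implicit function theorem. With that approach, the derivative of the root α with respect to the inputs follows from the residual alone, so AD never has to differentiate through the solver's iterations. The sketch below assumes that approach. It also assumes the planar-layer inverse reduces to the scalar equation y = α + c tanh(α + b), where y, c, and b stand for wᵀy, wᵀû, and the bias, and c ≥ -1. The names `solve_alpha`, `alpha_pullback`, and `fd_gradient` are hypothetical. The sketch does not reproduce how Bijectors registers its rules with ChainRules or ReverseDiff. It only checks the implicit-function-theorem gradients against finite differences through a plain bisection solve.

```julia
# Dependency-free stand-in for the root solve (hypothetical name): returns α with
#     y = α + c * tanh(α + b),   assuming c >= -1 so the residual is increasing in α.
function solve_alpha(y::Float64, c::Float64, b::Float64)
    g(α) = α + c * tanh(α + b) - y
    lo, hi = y - 2abs(c), y + 2abs(c)   # g(lo) < 0 < g(hi) whenever c != 0
    lo == hi && return lo
    for _ in 1:200
        mid = (lo + hi) / 2
        g(mid) < 0 ? (lo = mid) : (hi = mid)
        hi - lo <= 1e-13 && break
    end
    return (lo + hi) / 2
end

# Hypothetical pullback for α(y, c, b). With F(α; y, c, b) = α + c*tanh(α + b) - y = 0,
# the implicit function theorem gives dα/dθ = -(∂F/∂θ) / (∂F/∂α), where
#     ∂F/∂α = 1 + c*sech²(α + b),  ∂F/∂y = -1,  ∂F/∂c = tanh(α + b),  ∂F/∂b = c*sech²(α + b).
# Only the solution α is needed, not the iterates of the solver.
function alpha_pullback(Δα::Float64, α::Float64, c::Float64, b::Float64)
    t = tanh(α + b)
    s = 1 - t^2                  # sech²(α + b)
    D = 1 + c * s                # ∂F/∂α, positive whenever c > -1
    return (Δα / D, -Δα * t / D, -Δα * c * s / D)   # cotangents for (y, c, b)
end

# Central finite differences through the solver, for comparison.
function fd_gradient(y, c, b; h = 1e-5)
    return ((solve_alpha(y + h, c, b) - solve_alpha(y - h, c, b)) / (2h),
            (solve_alpha(y, c + h, b) - solve_alpha(y, c - h, b)) / (2h),
            (solve_alpha(y, c, b + h) - solve_alpha(y, c, b - h)) / (2h))
end

y, c, b = 0.7, 0.34, 0.1
α = solve_alpha(y, c, b)
println(alpha_pullback(1.0, α, c, b))   # implicit-function-theorem gradient
println(fd_gradient(y, c, b))           # finite-difference gradient; should agree closely
```

One consequence of this design is that the adjoint costs a single `tanh` evaluation, regardless of how many bisection steps the forward solve took.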

Apart from the ReverseDiff bugfix, the main motivation for this PR is that it reduces the dependencies of Bijectors substantially and hence makes the package more lightweight (https://github.com/TuringLang/Bijectors.jl/issues/199). On the master branch, removing NonlinearSolve removes 48 entries from the Manifest. Adding Roots then adds back only 6: Roots itself plus five indirect dependencies that were already present because of NonlinearSolve:

(Bijectors) pkg> rm NonlinearSolve
    Updating `~/.julia/dev/Bijectors/Project.toml`
  [8913a72c] - NonlinearSolve v0.3.11
    Updating `~/.julia/dev/Bijectors/Manifest.toml`
  [79e6a3ab] - Adapt v3.3.1
  [4fba245c] - ArrayInterface v3.1.33
  [62783981] - BitTwiddlingConvenienceFunctions v0.1.0
  [2a0fbf3d] - CPUSummary v0.1.5
  [fb6a15b2] - CloseOpenIntervals v0.1.2
  [38540f10] - CommonSolve v0.2.0
  [bbf7d656] - CommonSubexpressions v0.3.0
  [187b0558] - ConstructionBase v1.3.0
  [e2d170a0] - DataValueInterfaces v1.0.0
  [163ba53b] - DiffResults v1.0.3
  [b552c78f] - DiffRules v1.3.1
  [6a86dc24] - FiniteDiff v2.8.1
  [f6369f11] - ForwardDiff v0.10.21
  [3e5b6fbb] - HostCPUFeatures v0.1.4
  [0e44f5e4] - Hwloc v2.0.0
  [615f187c] - IfElse v0.1.0
  [42fd0dbc] - IterativeSolvers v0.9.1
  [82899510] - IteratorInterfaceExtensions v1.0.0
  [10f19ff3] - LayoutPointers v0.1.3
  [bdcacae8] - LoopVectorization v0.12.85
  [1914dd2f] - MacroTools v0.5.8
  [d125e4d3] - ManualMemory v0.1.6
  [77ba4419] - NaNMath v0.3.5
  [8913a72c] - NonlinearSolve v0.3.11
  [6fe1bfb0] - OffsetArrays v1.10.7
  [f517fe37] - Polyester v0.5.3
  [1d0040c9] - PolyesterWeave v0.1.0
  [3cdcf5f2] - RecipesBase v1.1.2
  [731186ca] - RecursiveArrayTools v2.19.1
  [f2c3362d] - RecursiveFactorization v0.2.4
  [3cdde19b] - SIMDDualNumbers v0.1.0
  [94e857df] - SIMDTypes v0.1.0
  [476501e8] - SLEEFPirates v0.6.27
  [0bca4576] - SciMLBase v1.19.2
  [efcf1570] - Setfield v0.8.0
  [aedffcd0] - Static v0.3.3
  [90137ffa] - StaticArrays v1.2.13
  [7792a7ef] - StrideArraysCore v0.2.5
  [3783bdb8] - TableTraits v1.0.1
  [bd369af6] - Tables v1.6.0
  [8290d209] - ThreadingUtilities v0.4.6
  [a2a6695c] - TreeViews v0.3.0
  [d5829a12] - TriangularSolve v0.1.6
  [3a884ed6] - UnPack v1.0.2
  [3d5dd08c] - VectorizationBase v0.21.13
  [700de1a5] - ZygoteRules v0.2.2
  [e33a78d0] - Hwloc_jll v2.5.0+0
  [9fa8497b] - Future

(Bijectors) pkg> add Roots
   Resolving package versions...
    Updating `~/.julia/dev/Bijectors/Project.toml`
  [f2b01f46] + Roots v1.3.5
    Updating `~/.julia/dev/Bijectors/Manifest.toml`
  [38540f10] + CommonSolve v0.2.0
  [187b0558] + ConstructionBase v1.3.0
  [1914dd2f] + MacroTools v0.5.8
  [f2b01f46] + Roots v1.3.5
  [efcf1570] + Setfield v0.8.0
  [9fa8497b] + Future

torfjelde commented 2 years ago

Eeehh, I should have let you merge in case you weren't done with it (I realized you hadn't requested a review yet), though it looks good to me. I'll hold off on the release until you give me a thumbs-up that it's all good 👍

devmotion commented 2 years ago

No worries, all good 👍

devmotion commented 2 years ago

https://github.com/JuliaRegistries/General/pull/46848