slimgroup / InvertibleNetworks.jl

A Julia framework for invertible neural networks
MIT License

add flag for computing p(x|y) in conditional HINT #27

Closed: alisiahkoohi closed this 2 years ago

alisiahkoohi commented 3 years ago

This pull request adds a flag to the forward and inverse functions of NetworkConditionalHINT and ConditionalLayerHINT that enables conditional density estimation, p(x|y), via the change-of-variables formula. Currently, only the joint density p(x, y) can be estimated directly. The conditional density instead requires the log-determinant term of the x-lane alone, so the flag makes these functions accumulate the logdet of the Jacobian with respect to x only, rather than with respect to (x, y).
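The reasoning behind the flag, written out as a worked equation. This assumes, as in the conditional HINT design, that z_y depends on y only while z_x depends on both x and y, so the Jacobian of (x, y) ↦ (z_x, z_y) is block triangular and its log-determinant splits into an x-lane term and a y-lane term; p_Z denotes the standard normal density:

$$
\begin{aligned}
\log p(x, y) &= \log p_Z(z_x) + \log p_Z(z_y) + \log\left|\det \frac{\partial z_x}{\partial x}\right| + \log\left|\det \frac{\partial z_y}{\partial y}\right|,\\
\log p(y) &= \log p_Z(z_y) + \log\left|\det \frac{\partial z_y}{\partial y}\right|,\\
\log p(x \mid y) &= \log p(x, y) - \log p(y) = \log p_Z(z_x) + \log\left|\det \frac{\partial z_x}{\partial x}\right|.
\end{aligned}
$$

Dropping the y-lane terms is exactly what accumulating the x-lane logdet alone achieves.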

After this change, given a NetworkConditionalHINT trained on pairs (x, y), the change-of-variables formula can be applied as follows to compute log-densities:

```julia
using Distributions

# Log of the joint density, log p(X, Y), for a single pair X and Y
Zx, Zy, logdet = Net.forward(X, Y)
log_joint_density = sum(logpdf.(Normal(0f0, 1f0), Zx)) + sum(logpdf.(Normal(0f0, 1f0), Zy)) + logdet

# Log of the conditional density, log p(X | Y): only the x-lane logdet is accumulated
Zx, _, logdet = Net.forward(X, Y; x_lane=true)
log_conditional_density = sum(logpdf.(Normal(0f0, 1f0), Zx)) + logdet
```

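To make the bookkeeping concrete, here is a self-contained sketch that needs no packages. It does not use InvertibleNetworks.jl: `ToyConditionalMap`, `toy_forward` and `std_normal_logpdf` are hypothetical stand-ins that share the block-triangular structure assumed for conditional HINT (Zy depends on Y only), and the `x_lane` keyword merely mirrors the flag name from this PR. The final line checks the identity log p(X, Y) = log p(X | Y) + log p(Y).

```julia
# Standard-normal log-density, written out so the sketch needs no packages
std_normal_logpdf(z) = -0.5 * z^2 - 0.5 * log(2π)

# Hypothetical toy map mimicking the structure of conditional HINT:
# Zy depends on Y only, while Zx depends on both X and Y
struct ToyConditionalMap
    sx::Vector{Float64}  # nonzero elementwise scales on the x-lane
    sy::Vector{Float64}  # nonzero elementwise scales on the y-lane
    w::Vector{Float64}   # coupling of Y into the x-lane
end

# Returns (Zx, Zy, logdet); with x_lane=true only log|det dZx/dX| is accumulated
function toy_forward(m::ToyConditionalMap, X, Y; x_lane::Bool=false)
    Zx = m.sx .* X .+ m.w .* Y
    Zy = m.sy .* Y
    logdet_x = sum(log.(abs.(m.sx)))
    logdet_y = sum(log.(abs.(m.sy)))
    return Zx, Zy, x_lane ? logdet_x : logdet_x + logdet_y
end

m = ToyConditionalMap([2.0, 0.5], [1.5, 3.0], [0.3, -1.0])
X = [0.1, -0.4]
Y = [0.7, 0.2]

# log p(X, Y): full logdet with respect to (X, Y)
Zx, Zy, logdet_xy = toy_forward(m, X, Y)
log_joint = sum(std_normal_logpdf.(Zx)) + sum(std_normal_logpdf.(Zy)) + logdet_xy

# log p(X | Y): only the x-lane logdet
Zx, _, logdet_x = toy_forward(m, X, Y; x_lane=true)
log_cond = sum(std_normal_logpdf.(Zx)) + logdet_x

# log p(Y) from the y-lane alone
log_marg_y = sum(std_normal_logpdf.(Zy)) + sum(log.(abs.(m.sy)))

println(isapprox(log_joint, log_cond + log_marg_y))  # prints true
```

Because the y-lane never sees X, omitting its logdet term amounts to dividing by p(Y); a similar identity check could plausibly serve as the basis of a unit test on a small network.
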
codecov[bot] commented 3 years ago

Codecov Report

Merging #27 (8c2a3ad) into master (e22195d) will increase coverage by 0.01%. The diff coverage is 100.00%.

```diff
@@            Coverage Diff             @@
##           master      #27      +/-   ##
==========================================
+ Coverage   84.16%   84.18%   +0.01%     
==========================================
  Files          31       31              
  Lines        2476     2479       +3     
==========================================
+ Hits         2084     2087       +3     
  Misses        392      392              
```

| Impacted Files | Coverage Δ |
|---|---|
| src/conditional_layers/conditional_layer_hint.jl | 94.85% <100.00%> (ø) |
| src/layers/invertible_layer_actnorm.jl | 87.50% <100.00%> (ø) |
| src/layers/invertible_layer_basic.jl | 90.78% <100.00%> (ø) |
| src/layers/invertible_layer_hint.jl | 92.79% <100.00%> (ø) |
| ...rc/networks/invertible_network_conditional_hint.jl | 89.38% <100.00%> (ø) |
| src/utils/parameter.jl | 64.70% <100.00%> (+1.62%) :arrow_up: |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update e22195d...8c2a3ad.

mloubout commented 3 years ago

Can we add a test for it? Other than that, it looks fine.

alisiahkoohi commented 3 years ago

Yes, I will think about a simple test.

mloubout commented 2 years ago

Just saw the above. It also has conflicts with master, so it definitely needs a rebase (`git rebase -i origin/master`, as I showed you the other time).