slimgroup / InvertibleNetworks.jl

A Julia framework for invertible neural networks
MIT License

Add simple MNIST conditional example. #64

Closed: rafaelorozco closed this pull request 1 month ago.

rafaelorozco commented 1 year ago

This is meant to be a bare-minimum conditional sampling example on images that still gives reasonable results. It runs on CPU in 3 minutes and produces OK results. Users can train for more epochs and with more training samples to get better results.

I also added an even more minimal example to the README so that users can see the basic workflow of training a conditional flow and sampling from it.

To keep things clean, I made some changes to the conditional Glow network. The complexity of reshaping the conditioning variable is now hidden from the user, so sampling reads exactly like the math of a conditional INN (C-INN):

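```julia
num_samples = 64                                       # number of posterior samples to draw
Y_repeat = repeat(y, 1, 1, 1, num_samples)             # tile the observation y along the batch (4th) dimension
ZX_noise = randn(Float32, nx, ny, n_chan, num_samples) # one standard-normal latent per sample
X_post = G.inverse(ZX_noise, Y_repeat)                 # posterior samples of x given y
```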
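In C-INN terms, a trained network maps data to latent space, z = G(x; y), so posterior samples are obtained as x_i = G^{-1}(z_i; y) with each z_i drawn from a standard normal. The sketch below illustrates that convention end to end without the library: `ToyConditionalFlow`, `forward`, and `inverse` are hypothetical names for a toy elementwise affine flow standing in for the trained conditional Glow network `G`, and the 28×28×1 MNIST-like shape is an assumption.

```julia
using Random

# Hypothetical stand-in for a trained conditional flow z = G(x; y).
# This is a toy elementwise affine map whose shift and scale depend on y;
# it is NOT the InvertibleNetworks.jl conditional Glow network.
struct ToyConditionalFlow
    a::Float32  # hypothetical weight: shift = a * y
    b::Float32  # hypothetical weight: log-scale = b * y
end

# Forward pass: data x -> latent z, conditioned on y (same shape as x here).
forward(G::ToyConditionalFlow, X, Y) = (X .- G.a .* Y) ./ exp.(G.b .* Y)

# Inverse pass: latent z -> data x, conditioned on y.
inverse(G::ToyConditionalFlow, Z, Y) = Z .* exp.(G.b .* Y) .+ G.a .* Y

nx, ny, n_chan = 28, 28, 1   # assumed MNIST-like image shape
num_samples = 64
G = ToyConditionalFlow(0.5f0, 0.1f0)

y = randn(Float32, nx, ny, n_chan, 1)                   # a single observation (batch size 1)
Y_repeat = repeat(y, 1, 1, 1, num_samples)              # same y paired with every noise draw
ZX_noise = randn(Float32, nx, ny, n_chan, num_samples)  # z_i ~ N(0, I)
X_post = inverse(G, ZX_noise, Y_repeat)                 # x_i = G^{-1}(z_i; y)
```

Because the toy map is invertible for every `y`, running `forward` on `X_post` with the same `Y_repeat` recovers the noise. With the trained network from this PR, the tiling and noise-drawing steps stay the same and the last line becomes the `G.inverse` call from the snippet above.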
codecov[bot] commented 1 year ago

Codecov Report

Base: 88.06% // Head: 35.69% // Project coverage decreases by 52.37 percentage points :warning:

Coverage data is based on head (6f30f61) compared to base (6d2fba0). Patch coverage: 6.15% of modified lines in the pull request are covered.

:exclamation: Current head 6f30f61 differs from the pull request's most recent head 36a3688. Consider uploading reports for commit 36a3688 to get more accurate results.

Additional details and impacted files

```diff
@@             Coverage Diff             @@
##           master      #64       +/-   ##
===========================================
- Coverage   88.06%   35.69%   -52.38%
===========================================
  Files          33       33
  Lines        2439     2443        +4
===========================================
- Hits         2148      872     -1276
- Misses        291     1571     +1280
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| src/layers/invertible\_layer\_glow.jl | `0.00% <0.00%> (-97.30%)` | :arrow_down: |
| ...rc/networks/invertible\_network\_conditional\_glow.jl | `0.00% <0.00%> (-82.09%)` | :arrow_down: |
| src/networks/invertible\_network\_glow.jl | `0.00% <0.00%> (-89.00%)` | :arrow_down: |
| src/layers/layer\_residual\_block.jl | `88.31% <100.00%> (-10.39%)` | :arrow_down: |
| src/utils/activation\_functions.jl | `72.30% <100.00%> (-16.93%)` | :arrow_down: |
| ...itional\_layers/conditional\_layer\_residual\_block.jl | `0.00% <0.00%> (-100.00%)` | :arrow_down: |
| src/conditional\_layers/conditional\_layer\_hint.jl | `0.00% <0.00%> (-99.19%)` | :arrow_down: |
| src/layers/invertible\_layer\_irim.jl | `0.00% <0.00%> (-98.12%)` | :arrow_down: |
| src/conditional\_layers/conditional\_layer\_glow.jl | `0.00% <0.00%> (-97.73%)` | :arrow_down: |
| src/layers/layer\_affine.jl | `0.00% <0.00%> (-97.23%)` | :arrow_down: |
| ... and 21 more | | |
