april-tools / cirkit

a python framework to build, learn and reason about probabilistic circuits and tensor networks
https://cirkit-docs.readthedocs.io/en/latest/
GNU General Public License v3.0

Refactor sum layers + consistency checks + bug fixes #314

Closed: loreloc closed this 3 days ago

loreloc commented 2 weeks ago

As discussed, and to avoid confusion, this PR merges the dense and mixing layers into a single, more general sum layer, as formalized in https://arxiv.org/abs/2409.07953.
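To illustrate why one layer type can cover both cases, here is a minimal, dependency-free sketch. It assumes, following the paper's formulation, that a sum layer of arity H takes H input vectors of K_in units each, concatenates them, and multiplies the result by a weight matrix of shape (K_out, H·K_in). The function names `sum_layer` and `mixing_as_sum_weight` are hypothetical illustrations, not cirkit's API. Under this assumption a dense layer is the H = 1 case, and a mixing layer is the case where output unit k reads only unit k of each input, i.e. a sparse, structured weight matrix.

```python
from typing import List

Matrix = List[List[float]]


def sum_layer(weight: Matrix, inputs: List[List[float]]) -> List[float]:
    """Evaluate a (hypothetical) general sum layer in linear space.

    weight has shape (K_out, H * K_in); inputs holds H vectors of length K_in.
    The H inputs are concatenated and multiplied by the weight matrix.
    """
    x = [v for vec in inputs for v in vec]  # concatenation, length H * K_in
    if any(len(row) != len(x) for row in weight):
        raise ValueError("weight must have shape (K_out, H * K_in)")
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def mixing_as_sum_weight(mix: Matrix) -> Matrix:
    """Embed mixing weights mix[k][h] (shape K x H) into a (K, H * K) weight.

    Output unit k only reads unit k of each of the H inputs, so every
    other entry of the embedded matrix is zero.
    """
    num_units, arity = len(mix), len(mix[0])
    weight = [[0.0] * (arity * num_units) for _ in range(num_units)]
    for k in range(num_units):
        for h in range(arity):
            weight[k][h * num_units + k] = mix[k][h]
    return weight


# Dense layer: arity H = 1 with an arbitrary (K_out, K_in) weight matrix.
print(sum_layer([[0.5, 0.5], [0.25, 0.75]], [[1.0, 3.0]]))  # [2.0, 2.5]

# Mixing layer: H = 2 inputs of K = 2 units, out[k] = sum_h mix[k][h] * in[h][k].
mix = [[0.25, 0.75], [0.5, 0.5]]
print(sum_layer(mixing_as_sum_weight(mix), [[1.0, 2.0], [3.0, 4.0]]))  # [2.5, 3.0]
```

Expressing the mixing layer as a zero-padded weight matrix is only for exposition; an actual implementation would presumably avoid materializing those zeros.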

Moreover, this PR adds sanity checks on the symbolic representation (symbolic circuits, parameterizations and initializations), so that bugs such as shape mismatches are caught early instead of surfacing only after compilation and optimization.
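A check of this kind can be sketched as validation at construction time of a symbolic layer. The class `SymbolicSumLayer` and its fields below are hypothetical and do not reproduce cirkit's symbolic classes; the expected weight shape (K_out, arity · K_in) is an assumption carried over from the sum-layer formulation above.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class SymbolicSumLayer:
    """Hypothetical symbolic description of a sum layer (not cirkit's class)."""

    num_input_units: int
    num_output_units: int
    arity: int
    weight_shape: Tuple[int, ...]

    def __post_init__(self) -> None:
        # Validate eagerly, when the symbolic layer is built, rather than
        # letting a mismatch surface later inside compiled tensor operations.
        for name in ("num_input_units", "num_output_units", "arity"):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be positive")
        expected = (self.num_output_units, self.arity * self.num_input_units)
        if tuple(self.weight_shape) != expected:
            raise ValueError(
                f"expected weight shape {expected}, got {tuple(self.weight_shape)}"
            )


SymbolicSumLayer(4, 3, arity=2, weight_shape=(3, 8))  # accepted
try:
    SymbolicSumLayer(4, 3, arity=2, weight_shape=(3, 4))
except ValueError as err:
    print(err)  # expected weight shape (3, 8), got (3, 4)
```

Failing at this point gives an error that names the offending layer description, which is usually easier to debug than a broadcasting error raised deep inside an optimized, compiled circuit.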

Finally, since I was already changing the torch implementation of the sum layer, this PR also fixes the bug that caused sum layers to output NaNs when the inputs to a sum unit are all ±inf.
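The PR text does not spell out the fix. A common source of such NaNs, assuming the sum layer is evaluated in log space with the usual max-shift trick, is that subtracting a non-finite maximum computes inf - inf = nan. The sketch below (plain Python, hypothetical function names, strictly positive weights assumed) reproduces the failure and one standard remedy: shift by 0 whenever the maximum is not finite.

```python
import math
from typing import List


def _log(s: float) -> float:
    # math.log(0.0) raises ValueError; return -inf instead, as torch.log would.
    return -math.inf if s == 0.0 else math.log(s)


def naive_weighted_lse(weights: List[float], xs: List[float]) -> float:
    """log(sum_j w_j * exp(x_j)) with the standard max-shift, no infinity guard."""
    m = max(xs)
    return m + _log(sum(w * math.exp(x - m) for w, x in zip(weights, xs)))


def safe_weighted_lse(weights: List[float], xs: List[float]) -> float:
    """Same quantity, but a non-finite maximum is replaced by 0 before shifting.

    If every x_j is -inf, or some x_j is +inf, then x_j - m evaluates
    inf - inf = nan; shifting by 0 lets exp/log propagate the infinities.
    Assumes strictly positive weights (0 * inf would itself be nan).
    """
    m = max(xs)
    if not math.isfinite(m):
        m = 0.0
    return m + _log(sum(w * math.exp(x - m) for w, x in zip(weights, xs)))


w = [0.25, 0.75]
print(naive_weighted_lse(w, [-math.inf, -math.inf]))  # nan
print(safe_weighted_lse(w, [-math.inf, -math.inf]))   # -inf
print(safe_weighted_lse(w, [math.inf, math.inf]))     # inf
```

For finite inputs the guard never fires, so the result is identical to the usual max-shifted computation and keeps its numerical stability.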

Closes #309 #316 #319

Detailed changes:

codecov[bot] commented 2 weeks ago

Codecov Report

Attention: Patch coverage is 57.47126% with 37 lines in your changes missing coverage. Please review.

Project coverage is 68.25%. Comparing base (3e51b42) to head (9a48049).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| cirkit/templates/circuit_templates/utils.py | 0.00% | 15 Missing :warning: |
| cirkit/backend/torch/layers/inner.py | 45.45% | 6 Missing :warning: |
| cirkit/symbolic/circuit.py | 50.00% | 4 Missing and 1 partial :warning: |
| cirkit/backend/torch/parameters/pic.py | 0.00% | 3 Missing :warning: |
| cirkit/templates/circuit_templates/data.py | 0.00% | 3 Missing :warning: |
| cirkit/backend/torch/compiler.py | 50.00% | 1 Missing and 1 partial :warning: |
| cirkit/backend/torch/optimization/registry.py | 60.00% | 2 Missing :warning: |
| cirkit/backend/torch/optimization/layers.py | 94.73% | 1 Missing :warning: |
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #314      +/-   ##
==========================================
+ Coverage   68.18%   68.25%   +0.07%
==========================================
  Files          51       51
  Lines        5368     5314      -54
  Branches      616      614       -2
==========================================
- Hits         3660     3627      -33
+ Misses       1492     1473      -19
+ Partials      216      214       -2
```
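The reported percentages can be reproduced from the raw counts, assuming Codecov's coverage ratio hits / (hits + misses + partials), in which partially covered lines count as not covered. The patch size of 87 lines (50 hit, 35 missed, 2 partial) is an inference: it is the only patch size consistent with 57.47126% and the 37 uncovered lines (including the 2 partials listed per file); the report itself does not state it.

```python
def coverage_pct(hits: int, misses: int, partials: int) -> float:
    """Coverage ratio in percent; partially covered lines count as not covered."""
    return 100.0 * hits / (hits + misses + partials)


# Project totals from the Coverage Diff (base 3e51b42 vs. head 9a48049).
print(f"{coverage_pct(3660, 1492, 216):.2f}%")  # 68.18%
print(f"{coverage_pct(3627, 1473, 214):.2f}%")  # 68.25%

# Patch: 37 lines not fully covered (35 missing + 2 partial); 50 hits is inferred.
print(f"{coverage_pct(50, 35, 2):.5f}%")  # 57.47126%
```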
