april-tools / cirkit

A Python framework to build, learn, and reason about probabilistic circuits and tensor networks
https://cirkit-docs.readthedocs.io/en/latest/
GNU General Public License v3.0

Allow any number of batch dimensions for `TensorizedPC` #132

Closed: lkct closed this issue 11 months ago

lkct commented 1 year ago

Currently, TensorizedPC and all layers only accept inputs with a fixed number of dimensions, including exactly one batch dimension. However, I find it useful to expose an interface that accepts any number of batch dimensions, similar to torch.nn.Linear, which transforms only specific dimension(s) and keeps all batch dimensions as-is.
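For reference, this is the torch.nn.Linear behaviour being described: an input of shape (*, in_features) becomes (*, out_features), so only the last dimension is transformed and any leading batch dimensions pass through unchanged. The sizes below are arbitrary.

```python
import torch

# torch.nn.Linear transforms only the last dimension; all leading dims are batch dims.
linear = torch.nn.Linear(in_features=4, out_features=3)

x1 = torch.randn(8, 4)        # one batch dimension
x2 = torch.randn(2, 5, 8, 4)  # three batch dimensions

y1 = linear(x1)
y2 = linear(x2)
print(y1.shape)  # torch.Size([8, 3])
print(y2.shape)  # torch.Size([2, 5, 8, 3])
```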

The implementation could be as simple as saving the batch dims, flattening them into a single batch dim before the forward pass, and restoring them in the results. (Users can also do this themselves, but it is more convenient if the library handles it.) (Or is there a smarter way to implement this? And how does this interact with integration/derivation?)
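A minimal sketch of the flatten-and-restore approach, written as a generic wrapper rather than as cirkit code. The names with_batch_dims, toy_layer and num_event_dims are hypothetical, and toy_layer is only a stand-in for a layer that insists on exactly one batch dimension of shape (B, F). The sketch assumes the batch dimensions are the leading ones; it does not address the integration/derivation question.

```python
from typing import Callable

import torch


def with_batch_dims(fn: Callable[[torch.Tensor], torch.Tensor],
                    x: torch.Tensor, num_event_dims: int) -> torch.Tensor:
    """Apply fn (which expects exactly one leading batch dim) to x with any number of batch dims.

    Hypothetical helper: the last `num_event_dims` dims of x are treated as non-batch dims.
    """
    split = x.dim() - num_event_dims
    batch_shape = x.shape[:split]                # save the batch dims
    flat = x.reshape(-1, *x.shape[split:])       # flatten them into one batch dim
    out = fn(flat)                               # shape (B, *out_event_shape)
    return out.reshape(*batch_shape, *out.shape[1:])  # restore the batch dims


W = torch.randn(3, 4)


def toy_layer(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a layer that only accepts inputs of shape (B, F).
    assert t.dim() == 2, "expects exactly one batch dimension"
    return t @ W.T  # (B, 3)


x = torch.randn(2, 5, 4)
y = with_batch_dims(toy_layer, x, num_event_dims=1)
print(y.shape)  # torch.Size([2, 5, 3])
```

The reshape calls keep the wrapper independent of the layer, but they may copy non-contiguous inputs, which is one reason a layer-level solution can be preferable.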

lkct commented 12 months ago

With torch.einsum, an ellipsis (...) can be used to implement this. The solution could be as simple as replacing every batch-dimension letter in the einsum equations with an ellipsis.
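As an illustration of the ellipsis idea, the sketch below uses a plain linear contraction, not an actual cirkit layer (real layers may use different equations, e.g. in log-space). It assumes the batch dims are leading; layer_fixed and layer_any_batch are hypothetical names. Replacing the batch letter b with ... lets the same equation accept any number of batch dimensions.

```python
import torch

# Hypothetical weights for a dense contraction: inputs (..., I), weights (O, I).
W = torch.randn(3, 4)


def layer_fixed(x: torch.Tensor) -> torch.Tensor:
    # Explicit batch letter "b": only accepts inputs of shape (B, I).
    return torch.einsum("bi,oi->bo", x, W)


def layer_any_batch(x: torch.Tensor) -> torch.Tensor:
    # "..." matches any number of leading batch dims, which are kept in the output.
    return torch.einsum("...i,oi->...o", x, W)


x2 = torch.randn(8, 4)
x4 = torch.randn(2, 5, 8, 4)
print(layer_any_batch(x2).shape)  # torch.Size([8, 3])
print(layer_any_batch(x4).shape)  # torch.Size([2, 5, 8, 3])
```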