Currently `TensorizedPC` and all layers only accept inputs with a fixed number of dimensions, with exactly one batch dimension.
However, I find it useful to expose an interface that accepts any number of dimensions, similar to `torch.nn.Linear`, where the transformation is applied to specific dimension(s) and all batch dimensions are kept as-is.
The implementation can be as simple as saving the batch dims, flattening them into a single batch dim before calling the layer, and restoring them on the result. (Users can also do this themselves, but it's more convenient to have it in the library.)
(Or is there a smarter way to implement this? And how does this interact with integration/derivation?)
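A minimal sketch of the flatten-and-restore approach, written as a generic PyTorch wrapper. `FlattenBatchDims` and `num_event_dims` are hypothetical names, not existing library API. The sketch assumes the wrapped module maps `(B, *event_shape)` to `(B, *out_shape)`, i.e. it has exactly one leading batch dim. It does not address integration/derivation.

```python
import torch
from torch import Tensor, nn


class FlattenBatchDims(nn.Module):
    """Hypothetical wrapper (not library API): lets a module that only accepts
    (B, *event_shape) take (*batch_shape, *event_shape), like torch.nn.Linear."""

    def __init__(self, module: nn.Module, num_event_dims: int) -> None:
        super().__init__()
        self.module = module
        # Number of trailing dims the wrapped module actually transforms.
        self.num_event_dims = num_event_dims

    def forward(self, x: Tensor) -> Tensor:
        split = x.dim() - self.num_event_dims
        if split < 0:
            raise ValueError("input has fewer dims than num_event_dims")
        # Save the leading batch dims; keep the trailing event dims.
        batch_shape, event_shape = x.shape[:split], x.shape[split:]
        # Collapse all batch dims into one (reshape copies only if needed).
        flat = x.reshape(-1, *event_shape)
        out = self.module(flat)
        # Restore the saved batch dims in front of the module's output dims.
        return out.reshape(*batch_shape, *out.shape[1:])


# Stand-in for a layer with one batch dim: (B, 4) -> (B, 3).
inner = nn.Linear(4, 3)
wrapped = FlattenBatchDims(inner, num_event_dims=1)
x = torch.randn(2, 5, 4)
y = wrapped(x)
print(y.shape)  # torch.Size([2, 5, 3])
```

`nn.Linear` is used only as a convenient stand-in here since it is easy to check against; a real circuit or layer would take its place.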