april-tools / cirkit

A Python framework to build, learn and reason about probabilistic circuits and tensor networks
https://cirkit-docs.readthedocs.io/en/latest/
GNU General Public License v3.0

Batched marginalisation mask #318

Closed: andreasgrv closed this 1 day ago

andreasgrv commented 2 weeks ago

This PR closes #292. It allows a different marginalisation mask to be applied to each entry in a batch. A single scope can still be broadcast across all entries in a batch, so this change should be backward compatible.

More concretely, the integrate_vars argument of IntegrateQuery can now be:

  1. a torch tensor of shape (B, D), where B is the batch size and D is the number of variables in the scope of the circuit. The tensor's dtype should be torch.bool, and it should be True at the positions of the random variables to be marginalised out and False elsewhere.
    import torch
    inputs = ...  # placeholder: input tensor with batch_size=2 over 4 variables
    mar_query = IntegrateQuery(circuit)
    # Integrate out variables 1 and 3 from the first example and variable 0 from the second
    mask = torch.tensor([[False, True, False, True], [True, False, False, False]], dtype=torch.bool)
    mar_scores = mar_query(inputs, integrate_vars=mask)
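  2. a list of scopes, one per entry in the batch:
    inputs = ...  # placeholder: input tensor with batch_size=2
    mar_query = IntegrateQuery(circuit)
    # Integrate out variables 1 and 3 from the first example and variable 0 from the second
    mar_scores = mar_query(inputs, integrate_vars=[Scope(1,3), Scope(0)])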
  3. a single scope, in which case the integration mask is broadcast across the batch:
    inputs = ...  # placeholder: input tensor with any batch size
    mar_query = IntegrateQuery(circuit)
    # Integrate out variables 1 and 3 from every example in the batch
    mar_scores = mar_query(inputs, integrate_vars=Scope(1,3))
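The broadcasting in option 3 can be pictured as turning one scope into a single mask row and expanding that row over the batch. This is a minimal sketch, not cirkit's implementation: it assumes variables are identified by 0-indexed ints, uses a plain Python set in place of Scope, and the helpers `scope_to_row` and `broadcast_scope` are hypothetical names.

```python
import torch


def scope_to_row(scope, num_variables):
    """Hypothetical helper: bool row that is True at the variable ids in `scope`."""
    row = torch.zeros(num_variables, dtype=torch.bool)
    ids = torch.tensor(sorted(scope), dtype=torch.long)
    row[ids] = True
    return row


def broadcast_scope(scope, batch_size, num_variables):
    """Hypothetical helper: share one integration scope across every batch entry."""
    row = scope_to_row(scope, num_variables)
    # expand() returns a (batch_size, num_variables) view without copying the row
    return row.unsqueeze(0).expand(batch_size, num_variables)


mask = broadcast_scope({1, 3}, batch_size=3, num_variables=4)
print(mask)
```

Since expand returns a view, the broadcast mask does not allocate a separate row per batch entry; call .clone() on it before modifying it in place.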

Because each entry in the batch can now have a scope over a different number of variables, the scopes do not form a regular shape, which is an issue in PyTorch since tensors have a fixed size along each dimension. For now, we address this by using a boolean mask of shape (batch_size, num_variables), where num_variables is an upper bound on the number of variables in the scope of the circuit (see below).
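To illustrate the problem and the dense workaround: per-entry scopes such as [1, 3] and [0] cannot be stacked into a regular tensor, but they can be written into a fixed-width boolean mask. This sketch assumes scopes are given as iterables of 0-indexed variable ids; `scopes_to_mask` is a hypothetical helper, not a cirkit function.

```python
import torch


def scopes_to_mask(scopes, num_variables):
    """Hypothetical helper: expand per-entry scopes into a dense (B, D) bool mask."""
    mask = torch.zeros(len(scopes), num_variables, dtype=torch.bool)
    for b, scope in enumerate(scopes):
        ids = torch.tensor(sorted(scope), dtype=torch.long)
        mask[b, ids] = True  # mark the variables to integrate out for entry b
    return mask


scopes = [[1, 3], [0]]

# A ragged nested list has no fixed shape, so torch refuses to build a tensor from it.
try:
    torch.tensor(scopes)
    ragged_ok = True
except ValueError:
    ragged_ok = False

mask = scopes_to_mask(scopes, num_variables=4)
print(mask)
```

The resulting mask is the same as the one written by hand in option 1.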

Assumptions

We assume that the number of variables in the scope is at most max(scope) + 1, where max(scope) is the largest variable id in the scope (ids start at 0). We need this because the actual number of variables may change: some ids may be dropped, in which case len(scope) is no longer a valid bound, as highlighted by @loreloc.
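A small example of why len(scope) is not a safe mask width: if ids have been dropped from a scope, e.g. leaving {0, 2, 5}, a mask with len(scope) columns has no column for variable 5. The sketch assumes 0-indexed ids; `mask_width` is a hypothetical helper.

```python
def mask_width(scope):
    """Hypothetical helper: number of mask columns needed to index every id in `scope`."""
    # With 0-indexed ids, the largest id lives in column max(scope), hence the +1.
    return max(scope) + 1


scope = {0, 2, 5}         # ids 1, 3 and 4 have been dropped
print(len(scope))         # 3: too few columns to address variable 5
print(mask_width(scope))  # 6
```

For a contiguous scope such as {0, 1, 2, 3}, both quantities agree, so the wider bound only costs extra columns once ids are dropped.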

Future work

Deal with sparsity

We currently expand the list of scopes into a dense boolean tensor mask. If the number of variables is very large and the integration mask is sparse, it would make sense to replace the dense implementation with a sparse one, e.g. a PyTorch sparse COO tensor.
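As a rough sketch of that direction (not part of this PR), the same batched mask can be stored as a torch.sparse_coo_tensor that keeps only the (batch, variable) coordinates to integrate out. It assumes scopes are iterables of 0-indexed ids, stores float 1.0 values for simplicity, and `scopes_to_sparse_mask` is a hypothetical name.

```python
import torch


def scopes_to_sparse_mask(scopes, num_variables):
    """Hypothetical helper: store only the (batch, variable) pairs to marginalise."""
    coords = [(b, v) for b, scope in enumerate(scopes) for v in sorted(scope)]
    # indices has shape (2, nnz): row 0 holds batch ids, row 1 holds variable ids
    indices = torch.tensor(coords, dtype=torch.long).reshape(-1, 2).t()
    values = torch.ones(indices.shape[1])
    return torch.sparse_coo_tensor(indices, values, size=(len(scopes), num_variables))


sparse_mask = scopes_to_sparse_mask([[1, 3], [0]], num_variables=1_000_000)
# 3 stored entries instead of 2_000_000 dense ones
print(sparse_mask.coalesce().values().numel())
```

Whether this pays off depends on how the integration layers consume the mask, since they would need sparse-aware operations or a per-layer conversion back to dense.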