april-tools / cirkit

A Python framework to build, learn, and reason about probabilistic circuits and tensor networks
https://cirkit-docs.readthedocs.io/en/latest/
GNU General Public License v3.0

Sampling from a categorical folded layer with num_channels > 1 does not work #295

Open gengala opened 20 hours ago

gengala commented 20 hours ago

I tried to sample from a probabilistic circuit (PC) built for RGB data, but the returned samples have only one channel.

loreloc commented 16 hours ago

Can you please post a small piece of code reproducing the issue?

gengala commented 16 hours ago

from cirkit.templates import circuit_templates
symbolic_circuit = circuit_templates.image_data(
    (3, 28, 28),                # The shape of an RGB image, i.e., (num_channels, image_height, image_width)
    region_graph='quad-graph',  # Select the structure of the circuit to follow the QuadGraph region graph
    input_layer='categorical',  # Use Categorical distributions for the pixel values (0-255) as input layers
    num_input_units=64,         # Each input layer consists of 64 Categorical input units
    sum_product_layer='cp',     # Use CP sum-product layers, i.e., alternate dense layers with Hadamard product layers
    num_sum_units=64,           # Each dense sum layer consists of 64 sum units
    sum_weight_param=circuit_templates.Parameterization(
        activation='softmax',   # Parameterize the sum weights by using a softmax activation
        initialization='normal' # Initialize the sum weights by sampling from a standard normal distribution
    )
)

from cirkit.pipeline import compile
circuit = compile(symbolic_circuit)

from cirkit.backend.torch.queries import SamplingQuery
sampling_query = SamplingQuery(circuit)
sample = sampling_query(num_samples=1)[0]  # has shape [1, 1, 784] instead of the expected [1, 3, 784]
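
For reference, the shape the reporter expects can be illustrated without cirkit. The following is a minimal sketch in plain Python, not cirkit's sampling code: it assumes hypothetical independent categorical distributions over 256 pixel values, one per (channel, pixel) pair, and draws one value for each pair. Because the channel axis is kept, each draw has shape [num_channels, num_variables], and the batch has shape [num_samples, num_channels, num_variables]. The function names `sample_categorical_pixels` and `shape` are illustrative only.

```python
import random


def sample_categorical_pixels(probs, num_samples, seed=0):
    """Hypothetical sketch: draw samples from independent categorical
    distributions, one per (channel, variable) pair.

    probs[c][v] is a list of category probabilities for channel c of
    variable v. Returns a nested list of shape
    [num_samples][num_channels][num_variables].
    """
    rng = random.Random(seed)
    categories = list(range(len(probs[0][0])))
    samples = []
    for _ in range(num_samples):
        # Keep one row per channel instead of collapsing the channel axis.
        sample = [
            [rng.choices(categories, weights=p_v)[0] for p_v in p_c]
            for p_c in probs
        ]
        samples.append(sample)
    return samples


def shape(nested):
    """Return the shape of a regularly nested list as a tuple."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# Uniform categoricals over 256 pixel values for a 3-channel 28x28 image.
num_channels, num_variables = 3, 28 * 28
uniform = [1.0 / 256] * 256
probs = [[uniform] * num_variables for _ in range(num_channels)]
out = sample_categorical_pixels(probs, num_samples=1)
print(shape(out))  # (1, 3, 784)
```

In the report, the channel axis of the returned sample has size 1 rather than 3, which is the dimension this sketch preserves.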

gengala commented 16 hours ago

Moreover, I just noticed that if you change the shape to (3, 32, 32) and then call sampling_query, you get the following error:

ValueError: The circuit to sample from must be smooth and decomposable, but found StructuralProperties(smooth=False, decomposable=True, structured_decomposable=False, omni_compatible=False)
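The error message refers to two structural properties that, according to the message, sampling requires. A circuit is smooth if every sum unit's inputs have the same scope (set of variables), and decomposable if every product unit's inputs have pairwise disjoint scopes. The sketch below checks both on a toy circuit encoded as plain Python tuples; the encoding and the function names are hypothetical and unrelated to cirkit's internal representation or to how it computes `StructuralProperties`.

```python
def scope(node):
    """Return the set of variables a (hypothetical) toy node depends on."""
    kind = node[0]
    if kind == "input":
        return frozenset([node[1]])
    return frozenset().union(*(scope(child) for child in node[1]))


def is_smooth_and_decomposable(node):
    """Recursively check smoothness (sum inputs share one scope) and
    decomposability (product inputs have pairwise disjoint scopes)."""
    kind = node[0]
    if kind == "input":
        return True, True
    children = node[1]
    child_scopes = [scope(c) for c in children]
    if kind == "sum":
        smooth = all(s == child_scopes[0] for s in child_scopes)
        decomposable = True
    elif kind == "product":
        smooth = True
        # Disjoint scopes: no variable is counted twice across the inputs.
        total = sum(len(s) for s in child_scopes)
        decomposable = total == len(frozenset().union(*child_scopes))
    else:
        raise ValueError(f"unknown node kind: {kind}")
    for child in children:
        s, d = is_smooth_and_decomposable(child)
        smooth, decomposable = smooth and s, decomposable and d
    return smooth, decomposable


x0, x1, x2 = ("input", 0), ("input", 1), ("input", 2)
# Both sum inputs have scope {0, 1}: smooth and decomposable.
good = ("sum", [("product", [x0, x1]), ("product", [x1, x0])])
# Sum inputs have scopes {0, 1} and {0, 2}: not smooth.
bad = ("sum", [("product", [x0, x1]), ("product", [x0, x2])])
print(is_smooth_and_decomposable(good))  # (True, True)
print(is_smooth_and_decomposable(bad))   # (False, True)
```

The second toy circuit has the same pattern as the reported error (smooth=False, decomposable=True): its products are decomposable, but the sum unit mixes inputs over different scopes.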