april-tools / cirkit

a python framework to build, learn and reason about probabilistic circuits and tensor networks
https://cirkit-docs.readthedocs.io/en/latest/
GNU General Public License v3.0

Default initialisations can produce nan loss #279

Open andreasgrv opened 1 week ago

andreasgrv commented 1 week ago

Code to reproduce:

import random
import numpy as np
import torch

from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

from cirkit.symbolic.circuit import Circuit
from cirkit.templates.region_graph import RandomBinaryTree
from cirkit.symbolic.layers import CategoricalLayer
from cirkit.templates.circuit_templates._factories import name_to_parameter_factory, name_to_initializer
from cirkit.pipeline import compile

NUM_INPUT_UNITS = 64
NUM_SUM_UNITS = 64
PIXEL_RANGE = 255

# Load the MNIST data set and data loaders
transform = transforms.Compose([
    transforms.ToTensor(),
    # Set pixel values in the [0-255] range
    transforms.Lambda(lambda x: (PIXEL_RANGE * x).long())
])

def define_circuit_from_rg(rg):

    # Here is where Overparametrisation comes in
    input_factory = lambda x, y, z: CategoricalLayer(scope=x,
                                                     num_categories=PIXEL_RANGE+1,
                                                     num_channels=1, # These are grayscale images
                                                     num_output_units=NUM_INPUT_UNITS # Overparametrisation
                                                    )

    ### =========== With init below model trains fine ===================================
    #  sum_weight_init = name_to_initializer('normal')
    #  sum_weight_params = name_to_parameter_factory('softmax', initializer=sum_weight_init)
    ### ========== but if no init - as below, we get nan loss ===========================
    sum_weight_params = None   # This line leads to nan loss

    circuit = Circuit.from_region_graph(rg,
                                        input_factory=input_factory,
                                        sum_weight_factory=sum_weight_params,
                                        num_sum_units=NUM_SUM_UNITS,
                                        sum_product='cp')
    return circuit

def train_circuit(cc):

    # Set some seeds
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    # torch.cuda.manual_seed(42)

    # Set the torch device to use
    device = torch.device('cuda')

    # Compile the circuit
    circuit = compile(cc)

    # Move the circuit to chosen device
    circuit = circuit.to(device)

    num_epochs = 5
    step_idx = 0
    running_loss = 0.0

    # Initialize a torch optimizer of your choice,
    #  e.g., Adam, by passing the parameters of the circuit
    optimizer = optim.Adam(circuit.parameters(), lr=0.01)

    for epoch_idx in range(num_epochs):
        for i, (batch, _) in enumerate(train_dataloader):
            # The circuit expects an input of shape (batch_dim, num_channels, num_variables),
            # so we unsqueeze a dimension for the channel.
            BS = batch.shape[0]
            batch = batch.view(BS, 1, -1).to(device)

            # Compute the log-likelihoods of the batch, by evaluating the circuit
            log_likelihoods = circuit(batch)

            # We take the negated average log-likelihood as loss
            loss = -torch.mean(log_likelihoods)
            loss.backward()
            # Update the parameters of the circuits, as any other model in PyTorch
            optimizer.step()
            optimizer.zero_grad()
            running_loss += loss.detach() * len(batch)
            step_idx += 1
            if step_idx % 100 == 0:
                print(f"Step {step_idx}: Average NLL: {running_loss / (100 * len(batch)):.3f}")
                running_loss = 0.0

data_train = datasets.MNIST('datasets', train=True, download=True, transform=transform)
train_dataloader = DataLoader(data_train, shuffle=True, batch_size=256)

# We can also specify depth and number of repetitions
# depth=None means maximum possible
rnd = RandomBinaryTree(28*28, depth=None, num_repetitions=1)

circuit = define_circuit_from_rg(rnd)

train_circuit(circuit)

In the above code, when the sum weight parameterisation is not specified, the loss becomes NaN during training. This may be confusing for somebody not familiar with the internals of the library - is there a way to avoid this?

andreasgrv commented 1 week ago

Output for:

    sum_weight_params = None   # This line leads to nan loss

    circuit = Circuit.from_region_graph(rg,
                                        input_factory=input_factory,
                                        sum_weight_factory=sum_weight_params,
                                        num_sum_units=NUM_SUM_UNITS,
                                        sum_product='cp')

    python example.py
    Step 100: Average NLL: nan
    Step 200: Average NLL: nan
    Step 300: Average NLL: nan

On the other hand, if:

    sum_weight_init = name_to_initializer('normal')
    sum_weight_params = name_to_parameter_factory('softmax', initializer=sum_weight_init)

    circuit = Circuit.from_region_graph(rg,
                                        input_factory=input_factory,
                                        sum_weight_factory=sum_weight_params,
                                        num_sum_units=NUM_SUM_UNITS,
                                        sum_product='cp')

    python example.py
    Step 100: Average NLL: 3422.423
    Step 200: Average NLL: 1614.733
    Step 300: Average NLL: 1013.035

lkct commented 1 week ago

This is due to the sum weights being initialised from a Normal distribution by default: they are expected to be positive in "common" circuits, and negative weights produce NaN in the log-sum-exp.
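
To make this concrete, here is a minimal plain-PyTorch sketch (illustrative only, not cirkit internals) of a sum unit evaluated in the lse-sum semiring with Normal-initialised weights:

    import torch

    # A sum unit in log-space computes log(sum_i w_i * exp(x_i)) = logsumexp(x_i + log(w_i)).
    # With a Normal init some w_i are negative, so log(w_i) is NaN and the NaN propagates.
    torch.manual_seed(0)
    w = torch.randn(4)   # Normal-initialised sum weights, some of them negative
    x = torch.randn(4)   # child outputs in log-space
    print(torch.logsumexp(x + torch.log(w), dim=0))  # tensor(nan) whenever any w_i < 0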

However, we also have many projects that use negative weights (with the sum-product or complex-lse-sum semiring), so it makes sense to use a Normal init.
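
By contrast, a sketch of the same unit evaluated in a sum-product style (again plain PyTorch, illustrative only), where negative weights stay finite:

    import torch

    # Under the sum-product semiring the sum unit works on probabilities directly,
    # sum_i w_i * p_i, so negative weights never pass through a log() and no NaN appears.
    torch.manual_seed(0)
    w = torch.randn(4)   # Normal-initialised sum weights, some of them negative
    p = torch.rand(4)    # child outputs in probability space
    print(torch.sum(w * p))  # finite (possibly negative), no NaN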

This may be confusing for somebody not familiar with the internals of the library - is there a way to avoid this?

Considering this, I would agree to change the default init for sum weights.

But in any case, we should properly document the default init for layers and tell users when they should NOT rely on the default.