GFNOrg / torchgfn

A modular, easy to extend GFlowNet library
https://torchgfn.readthedocs.io/en/latest/
Other

[help wanted] Training a DiscreteEBM ends up with "Log probabilities are inf. This should not happen." #136

Closed: ermiaetemadi closed this issue 8 months ago

ermiaetemadi commented 1 year ago

Hi. I'm trying to train a GFlowNet on a DiscreteEBM environment for a square-lattice Ising model. It works fine with small grid lengths, but when I increase the grid length I get this error: RuntimeError: Log probabilities are inf. This should not happen.

By trial and error I can find a batch size that works for each grid length, but I don't understand why this happens.
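
A back-of-the-envelope check on why grid length might matter (a hypothesis, not a diagnosis from this thread): if the reward has the form exp(-alpha * E) with E proportional to the quadratic form s^T J s, the energy range grows linearly with the number of sites N = L^2, and float32 cannot represent exp(x) once x exceeds about 88.72. The sketch only does that arithmetic for the fully aligned state on the periodic lattice in the user's script; aligned_state_quadratic_form is a hypothetical helper, and it does not reproduce torchgfn's actual IsingModel or DiscreteEBM code.

```python
import math

# Largest finite float32 value; exp(x) overflows float32 once x > log(max), about 88.72.
FLOAT32_MAX = 3.4028234663852886e38
FLOAT32_LOG_MAX = math.log(FLOAT32_MAX)


def aligned_state_quadratic_form(L: int, coupling: float) -> float:
    """s^T J s for the all-up state on an L x L periodic square lattice.

    For L >= 3 every site has 4 neighbours and J is symmetric, so each bond is
    counted twice and sum_ij J_ij = 4 * L**2 (times the coupling constant).
    For J_ij >= 0 this is also the largest value s^T J s can take.
    """
    if L < 3:
        raise ValueError("the 4-neighbour count only holds for L >= 3")
    return coupling * 4 * L * L


for L in (4, 6, 8, 10, 16):
    q = aligned_state_quadratic_form(L, 0.44)
    print(f"L={L:2d}  s^T J s = {q:7.2f}  beyond float32 exp range: {q > FLOAT32_LOG_MAX}")
```

With the default J = 0.44 the threshold is crossed between L = 6 and L = 8. Whether torchgfn ever exponentiates a quantity of this size in float32 is not established here; the maintainer's own guess in this thread points at reward clipping instead.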

My code:

import torch

from tqdm import tqdm
import wandb
from argparse import ArgumentParser

from gfn.gym import DiscreteEBM
from gfn.gym.discrete_ebm import IsingModel
from gfn.gflownet import FMGFlowNet
from gfn.utils.modules import NeuralNet
from gfn.modules import DiscretePolicyEstimator
from gfn.utils.common import validate

def main(args):

    # Configs

    use_wandb = len(args.wandb_project) > 0
    if use_wandb:
        wandb.init(project=args.wandb_project)
        wandb.config.update(args)

    device = "cpu"
    torch.set_num_threads(args.n_threads)

    hidden_dim = 512
    n_hidden = 2
    acc_fn = "relu"
    lr = 0.001
    lr_Z = 0.01
    L = args.L

    validation_samples = 1000

    # Ising model parameters

    def ising_n_to_ij(L, n):

        i = n // L
        j = n - i*L

        return (i, j)

    N = L**2
    J = torch.zeros((N, N), device=torch.device(device))
    for k in range(N):
        for m in range(k):

            x1, y1 = ising_n_to_ij(L, k)
            x2, y2 = ising_n_to_ij(L, m) 
            if x1 == x2 and abs(y2 - y1) == 1:
                J[k][m] = 1
                J[m][k] = 1
            elif y1 == y2 and abs(x2 - x1) == 1:
                J[k][m] = 1
                J[m][k] = 1

    for k in range(L):

        J[k*L][(k+1)*L - 1] = 1
        J[(k+1)*L - 1][k*L] = 1
        J[k][k+N-L] = 1
        J[k+N-L][k] = 1

    J = args.J * J

    # Ising model env

    ising_energy = IsingModel(J)
    ising_env = DiscreteEBM(N, alpha=1, energy=ising_energy, device_str=device)

    # Parametrization and losses

    pf_module = NeuralNet(
                    input_dim=ising_env.preprocessor.output_dim,
                    output_dim=ising_env.n_actions,
                    hidden_dim=hidden_dim,
                    n_hidden_layers=n_hidden,
                    activation_fn=acc_fn
                )

    pf_estimator = DiscretePolicyEstimator(env=ising_env, module=pf_module, forward=True)

    gflownet = FMGFlowNet(pf_estimator)

    # Optimizer

    params = [
            {
                "params": [
                    v for k, v in dict(gflownet.named_parameters()).items() if k != "logZ"
                ],
                "lr": lr,
            }
        ]

    if "logZ" in dict(gflownet.named_parameters()):
            params.append(
                {
                    "params": [dict(gflownet.named_parameters())["logZ"]],
                    "lr": lr_Z,
                }
            )

    optimizer = torch.optim.Adam(params)

    # Learning

    visited_terminating_states = ising_env.States.from_batch_shape((0,))

    states_visited = 0

    for i in (pbar := tqdm(range(10000))):
        trajectories = gflownet.sample_trajectories(n_samples=8)
        training_samples = gflownet.to_training_samples(trajectories)
        optimizer.zero_grad()
        loss = gflownet.loss(training_samples)
        loss.backward()
        optimizer.step()

        states_visited += len(trajectories)
        to_log = {"loss": loss.item(), "states_visited": states_visited}

        if i % 25 == 0:
            tqdm.write(f"{i}: {to_log}")

if __name__ == "__main__":

    # Command-line arguments
    parser = ArgumentParser()

    parser.add_argument(
            "--n_threads",
            type=int,
            default=4,
            help="Number of threads used by PyTorch",
        )

    parser.add_argument(
            "-L",
            type=int,
            default=16,
            help="Length of the grid",
        )

    parser.add_argument(
            "-J",
            type=float,
            default=0.44,
            help="J (Magnetic coupling constant)",
        )

    parser.add_argument(
            "--wandb_project",
            type=str,
            default="",
            help="Name of the wandb project. If empty, don't use wandb",
        )

    args = parser.parse_args()

    main(args)

I ran this with L=10 and got the error on the 844th iteration.
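
The nested loop that builds J visits all N^2/2 site pairs in Python, which gets slow as L grows. A vectorized sketch of the same periodic nearest-neighbour coupling matrix using torch.roll is below. periodic_lattice_couplings is a hypothetical helper name (not part of torchgfn), and it assumes the script's row-major indexing n = i * L + j from ising_n_to_ij.

```python
import torch


def periodic_lattice_couplings(L: int, coupling: float) -> torch.Tensor:
    """Dense N x N nearest-neighbour coupling matrix (N = L**2) for an L x L
    square lattice with periodic boundaries.

    Sites use the row-major index n = i * L + j, i.e. the inverse of the
    script's ising_n_to_ij helper.
    """
    N = L * L
    idx = torch.arange(N).reshape(L, L)          # idx[i, j] = i * L + j
    right = torch.roll(idx, shifts=-1, dims=1)   # site (i, (j + 1) mod L)
    down = torch.roll(idx, shifts=-1, dims=0)    # site ((i + 1) mod L, j)
    J = torch.zeros(N, N)
    src = idx.flatten()
    for nbr in (right.flatten(), down.flatten()):
        J[src, nbr] = 1.0  # bond from each site to its right / lower neighbour
        J[nbr, src] = 1.0  # mirror entry keeps J symmetric
    return coupling * J


J = periodic_lattice_couplings(10, 0.44)
print(J.shape, torch.equal(J, J.T), J.sum(dim=1).unique())
```

Swapping it in as J = periodic_lattice_couplings(L, args.J) should give the same matrix as the loop while avoiding the O(N^2) Python iteration; the matrix itself is still dense, so memory still grows as N^2 = L^4.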

josephdviviano commented 1 year ago

Thanks. We have a few outstanding bugs that were introduced recently, and I'm not sure whether this one came in alongside them. I'm looking into it now. Sorry for the lag!

josephdviviano commented 1 year ago

Sorry for the lag getting back to you on this!

Do you still get this behaviour with https://github.com/GFNOrg/torchgfn/pull/149?

I've changed a lot of the logic around automatic reward clipping, which makes these kinds of silent bugs (e.g., when the user forgets to pass the correct kwarg) less likely. If this is still an issue, it will be the next thing I work on.

josephdviviano commented 9 months ago

All fixes from #147 and #149 are now merged into master. Please let us know if you still face this issue. Thank you!

josephdviviano commented 9 months ago

Hey @ermiaetemadi,

I've updated your example to work with the current state of the codebase. On my machine, I get much further than you did using the default options (i.e., L=16):

1600: {'loss': 259.5113525390625, 'states_visited': 12808}
 16%|██████████████████                                                                                              | 1610/10000 [2:12:31<11:16:52,  4.84s/it]

I suspect the issue was resolved in one of the previous PRs. If I had to guess, it's because we removed reward clipping by default, but I can't be sure.
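
The thread does not show torchgfn's clipping code, the name of the kwarg, or its default floor, so the following is only a generic illustration of what lower-bound clipping of log-rewards does. It assumes the usual energy-based setup log R(x) = -alpha * E(x); log_rewards_from_energy and the floor of -100 are hypothetical, and the energies are illustrative values on roughly the scale a 16 x 16 lattice with J = 0.44 could reach.

```python
from typing import Optional

import torch


def log_rewards_from_energy(
    energy: torch.Tensor, alpha: float = 1.0, min_log_reward: Optional[float] = None
) -> torch.Tensor:
    """Log-rewards log R(x) = -alpha * E(x), optionally clipped from below.

    With min_log_reward set, every state whose log-reward falls under the floor
    is clamped to it, so all of those states end up with the same reward.
    """
    log_r = -alpha * energy
    if min_log_reward is not None:
        log_r = torch.clamp(log_r, min=min_log_reward)
    return log_r


# Illustrative energies; |E| grows with the number of lattice sites.
energy = torch.tensor([-450.0, -50.0, 150.0, 450.0])
print(log_rewards_from_energy(energy))                          # unclipped
print(log_rewards_from_energy(energy, min_log_reward=-100.0))   # hypothetical floor
```

With the floor in place the last two states become indistinguishable to the learner, and the larger the lattice, the more states fall below any fixed floor. Whether this is what produced the inf log-probabilities here is not confirmed in the thread.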

Hopefully this is helpful and you find the library useful!

import torch

from tqdm import tqdm
import wandb
from argparse import ArgumentParser

from gfn.gym import DiscreteEBM
from gfn.gym.discrete_ebm import IsingModel
from gfn.gflownet import FMGFlowNet
from gfn.utils.modules import NeuralNet
from gfn.modules import DiscretePolicyEstimator
from gfn.utils.training import validate

def main(args):

    # Configs

    use_wandb = len(args.wandb_project) > 0
    if use_wandb:
        wandb.init(project=args.wandb_project)
        wandb.config.update(args)

    device = "cpu"
    torch.set_num_threads(args.n_threads)

    hidden_dim = 512
    n_hidden = 2
    acc_fn = "relu"
    lr = 0.001
    lr_Z = 0.01
    validation_samples = 1000

    def make_J(L, coupling_constant):
        """Ising model parameters."""
        def ising_n_to_ij(L, n):
            i = n // L
            j = n - i * L
            return (i, j)

        N = L**2
        J = torch.zeros((N, N), device=torch.device(device))
        for k in range(N):
            for m in range(k):

                x1, y1 = ising_n_to_ij(L, k)
                x2, y2 = ising_n_to_ij(L, m)
                if x1 == x2 and abs(y2 - y1) == 1:
                    J[k][m] = 1
                    J[m][k] = 1
                elif y1 == y2 and abs(x2 - x1) == 1:
                    J[k][m] = 1
                    J[m][k] = 1

        for k in range(L):

            J[k*L][(k+1)*L - 1] = 1
            J[(k+1)*L - 1][k*L] = 1
            J[k][k+N-L] = 1
            J[k+N-L][k] = 1

        return coupling_constant * J

    # Ising model env
    N = args.L ** 2
    J = make_J(args.L, args.J)
    ising_energy = IsingModel(J)
    env = DiscreteEBM(N, alpha=1, energy=ising_energy, device_str=device)

    # Parametrization and losses
    pf_module = NeuralNet(
                    input_dim=env.preprocessor.output_dim,
                    output_dim=env.n_actions,
                    hidden_dim=hidden_dim,
                    n_hidden_layers=n_hidden,
                    activation_fn=acc_fn
                )

    pf_estimator = DiscretePolicyEstimator(pf_module, env.n_actions, env.preprocessor, is_backward=False)
    gflownet = FMGFlowNet(pf_estimator)
    optimizer = torch.optim.Adam(gflownet.parameters(), lr=lr)

    # Learning
    visited_terminating_states = env.States.from_batch_shape((0,))
    states_visited = 0
    for i in (pbar := tqdm(range(10000))):
        trajectories = gflownet.sample_trajectories(env, n_samples=8, off_policy=False)
        training_samples = gflownet.to_training_samples(trajectories)
        optimizer.zero_grad()
        loss = gflownet.loss(env, training_samples)
        loss.backward()
        optimizer.step()

        states_visited += len(trajectories)
        to_log = {"loss": loss.item(), "states_visited": states_visited}

        if i % 25 == 0:
            tqdm.write(f"{i}: {to_log}")

if __name__ == "__main__":

    # Command-line arguments
    parser = ArgumentParser()

    parser.add_argument(
            "--n_threads",
            type=int,
            default=4,
            help="Number of threads used by PyTorch",
        )

    parser.add_argument(
            "-L",
            type=int,
            default=16,
            help="Length of the grid",
        )

    parser.add_argument(
            "-J",
            type=float,
            default=0.44,
            help="J (Magnetic coupling constant)",
        )

    parser.add_argument(
            "--wandb_project",
            type=str,
            default="",
            help="Name of the wandb project. If empty, don't use wandb",
        )

    args = parser.parse_args()
    main(args)
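
When a run dies hundreds of iterations in (iteration 844 in the original report), it helps to stop at the first non-finite loss: a bad value that reaches optimizer.step() may corrupt the parameters, and the failure can then surface later, at sampling time. The sketch below is a plain PyTorch check; assert_finite is a hypothetical helper, not part of torchgfn.

```python
import torch


def assert_finite(name: str, value: torch.Tensor, step: int) -> None:
    """Raise as soon as `value` contains inf or NaN, reporting the training step.

    Calling this on the loss (or any tensor of log-quantities) right after it is
    computed pinpoints the first bad iteration instead of failing later.
    """
    bad = ~torch.isfinite(value)
    if bad.any():
        raise FloatingPointError(
            f"step {step}: {name} has {int(bad.sum())} non-finite value(s) "
            f"out of {value.numel()}"
        )


assert_finite("loss", torch.tensor(259.51), step=1600)  # finite, passes silently
try:
    assert_finite("loss", torch.tensor([1.0, float("inf"), float("nan")]), step=844)
except FloatingPointError as err:
    print(err)  # step 844: loss has 2 non-finite value(s) out of 3
```

In the training loop above, the call would sit right after loss = gflownet.loss(env, training_samples), for example assert_finite("loss", loss.detach(), i), before loss.backward() and optimizer.step() run.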

ermiaetemadi commented 8 months ago

Sorry for my late response.

I ran the script with the latest version and it seems that the bug is fixed. I'm closing this issue as a result.

Thanks

josephdviviano commented 8 months ago

Happy to hear it! Please let me know if you run into future issues or have suggestions for the project :)

Joseph (Mobile)

(In reply to Ermia Etemadi's comment of Tue, Apr 2, 2024 at 17:06.)