iancovert / Neural-GC

Granger causality discovery for neural networks.
MIT License

cLSTM train_model_ista() training loop #14

Open lndip opened 1 month ago

lndip commented 1 month ago

Hi @iancovert, thank you for your work and also the code!

I am going through the code in cLSTM.py and have some questions about how train_model_ista() is written. Why is the smooth error first calculated outside the loop for it in range(max_iter):? Would it be equivalent to write the function as follows?

def train_model_ista(clstm, X, context, lr, max_iter, lam=0, lam_ridge=0,
                     lookback=5, check_every=50, verbose=1):
    p = X.shape[-1]
    loss_fn = nn.MSELoss(reduction='mean')
    train_loss_list = []

    # Set up data.
    X, Y = zip(*[arrange_input(x, context) for x in X])
    X = torch.cat(X, dim=0)
    Y = torch.cat(Y, dim=0)

    # For early stopping.
    best_it = None
    best_loss = np.inf
    best_model = None

    for it in range(max_iter):
        # Calculate smooth error.
        pred = [clstm.networks[i](X)[0] for i in range(p)]
        loss = sum([loss_fn(pred[i][:, :, 0], Y[:, :, i]) for i in range(p)])
        ridge = sum([ridge_regularize(net, lam_ridge) for net in clstm.networks])
        smooth = loss + ridge

        # Take gradient step.
        smooth.backward()
        for param in clstm.parameters():
            param.data -= lr * param.grad

        # Take prox step.
        if lam > 0:
            for net in clstm.networks:
                prox_update(net, lam, lr)

        clstm.zero_grad()

        # Check progress.
        if (it + 1) % check_every == 0:
            # Add nonsmooth penalty.
            nonsmooth = sum([regularize(net, lam) for net in clstm.networks])
            mean_loss = (smooth + nonsmooth) / p
            train_loss_list.append(mean_loss.detach())

            if verbose > 0:
                print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1))
                print('Loss = %f' % mean_loss)
                print('Variable usage = %.2f%%'
                      % (100 * torch.mean(clstm.GC().float())))

            # Check for early stopping.
            if mean_loss < best_loss:
                best_loss = mean_loss
                best_it = it
                best_model = deepcopy(clstm)
            elif (it - best_it) == lookback * check_every:
                if verbose:
                    print('Stopping early')
                break

    # Restore best model.
    restore_parameters(clstm, best_model)

    return train_loss_list

lndip commented 4 weeks ago

And for minibatch implementation, do you think that this loop is plausible?

    for it in range(max_iter):
        train_loss_it = []

        for batch in train_dataloader:
            X,Y = batch
            X = X.float().to(device)
            Y = Y.float().to(device)

            clstm.zero_grad()

            # Calculate smooth error.
            pred = [clstm.networks[i](X)[0] for i in range(p_out)]
            loss = sum([loss_fn(pred[i][:, :, 0], Y[:, :, i]) for i in range(p_out)])
            ridge = sum([ridge_regularize(net, lam_ridge) for net in clstm.networks])
            smooth = loss + ridge

            # Take gradient step.
            smooth.backward()
            for param in clstm.parameters():
                param.data -= lr * param.grad

            # Take prox step.
            if lam > 0:
                for net in clstm.networks:
                    prox_update(net, lam, lr)

            nonsmooth = sum([regularize(net, lam) for net in clstm.networks])
            mean_loss = (smooth + nonsmooth) / p_out
            train_loss_it.append(mean_loss.detach())

        # Log epoch loss (stack first so this also works for tensors on a GPU).
        mean_train_loss = torch.stack(train_loss_it).mean().item()
        train_loss_list.append(mean_train_loss)

        # Check progress.
        if (it + 1) % check_every == 0:
            if verbose > 0:
                print(('-' * 10 + 'Iter = %d' + '-' * 10) % (it + 1))
                print('Loss = %f' % mean_train_loss)
                print('Variable usage = %.2f%%'
                      % (100 * torch.mean(clstm.GC(threshold=0).float())))

iancovert commented 4 weeks ago

Hi, thanks for checking out the code. For your first question: the initial calculation of smooth is a small efficiency hack. Each step needs this component of the error both for the backward pass and for logging the error a couple of lines later, so I figured we might as well reuse the error calculated for each log in the subsequent backward pass. The first calculation of smooth is there so we have it ready for the first backward pass, but I agree this is a bit unusual. One alternative would be to log the error before each step; another would be to report the error on a held-out validation set, but we didn't have that for most of our experiments. Anyway, what you've shown above is a bit off because it adds the smooth error from before the update step to the non-smooth error from after the step.
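A minimal sketch of the reuse pattern described above, on a toy lasso problem in plain Python rather than the repository's cLSTM code. The names `smooth_and_grad`, `soft_threshold`, and `ista` are hypothetical, and the element-wise L1 penalty is only a stand-in for the repository's penalty. The point it illustrates is that each logged loss adds the smooth and non-smooth terms at the same parameters, and the smooth error computed for logging is reused for the next gradient step.

```python
def smooth_and_grad(w, X, y):
    """Mean squared error of a linear model and its gradient w.r.t. w."""
    n = len(X)
    resid = [sum(wj * xj for wj, xj in zip(w, row)) - yi
             for row, yi in zip(X, y)]
    loss = sum(r * r for r in resid) / n
    grad = [2.0 / n * sum(r * row[j] for r, row in zip(resid, X))
            for j in range(len(w))]
    return loss, grad


def soft_threshold(v, t):
    """Proximal operator of t * |v| (element-wise stand-in for a group prox)."""
    if v > t:
        return v - t
    if v < -t:
        return v + t
    return 0.0


def ista(X, y, lam, lr, max_iter):
    w = [0.0] * len(X[0])
    # Initial smooth error, so it is ready for the first gradient step.
    smooth, grad = smooth_and_grad(w, X, y)
    history = []
    for _ in range(max_iter):
        # Gradient step on the smooth part, then prox step on the penalty.
        w = [soft_threshold(wj - lr * gj, lr * lam) for wj, gj in zip(w, grad)]
        # Recompute the smooth error at the updated parameters. It is logged
        # here and reused for the next gradient step.
        smooth, grad = smooth_and_grad(w, X, y)
        nonsmooth = lam * sum(abs(wj) for wj in w)
        history.append(smooth + nonsmooth)
    return w, history


# Toy data (made up for illustration): y depends only on the first feature.
X = [[1.0, 0.5], [2.0, -1.0], [3.0, 0.3], [4.0, -0.2]]
y = [2.0, 4.0, 6.0, 8.0]
w, history = ista(X, y, lam=0.1, lr=0.05, max_iter=200)
```

Because both terms in `history` are evaluated after the update, the logged objective is consistent, and here it should not increase from one iteration to the next for a small enough step size.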

For your second question: what you've shown for minibatch optimization seems reasonable; you can just sample X, Y for each train loss calculation. For the loss you're reporting at each progress check, the mean loss over the course of the epoch is one approach; another would be to do a separate pass over the train set or a held-out validation set.
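The two reporting options can be sketched side by side on a one-parameter toy problem. This is an illustration under assumptions, not the repository's training loop: `batch_loss_and_grad`, `prox_l1`, and `train` are hypothetical names, the data are made up, and a scalar L1 penalty replaces the group penalty.

```python
import random


def batch_loss_and_grad(w, batch):
    """Mean squared error of y ~ w * x on a batch, and its derivative in w."""
    n = len(batch)
    loss = sum((w * x - y) ** 2 for x, y in batch) / n
    grad = sum(2.0 * (w * x - y) * x for x, y in batch) / n
    return loss, grad


def prox_l1(v, t):
    """Soft-thresholding: proximal operator of t * |v|."""
    return max(abs(v) - t, 0.0) * (1.0 if v >= 0 else -1.0)


def train(data, lam, lr, epochs, batch_size, seed=0):
    rng = random.Random(seed)
    w = 0.0
    running_log, full_pass_log = [], []
    for _ in range(epochs):
        order = data[:]
        rng.shuffle(order)
        batch_losses = []
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            smooth, grad = batch_loss_and_grad(w, batch)
            # Option 1: running mean over the epoch. Both terms are taken at
            # the parameters from before this step's update.
            batch_losses.append(smooth + lam * abs(w))
            w = prox_l1(w - lr * grad, lr * lam)
        running_log.append(sum(batch_losses) / len(batch_losses))
        # Option 2: a separate pass over the full train set after the epoch.
        full_smooth, _ = batch_loss_and_grad(w, data)
        full_pass_log.append(full_smooth + lam * abs(w))
    return w, running_log, full_pass_log


data = [(float(x), 2.0 * x) for x in range(1, 9)]  # toy data, y = 2x
w, running_log, full_pass_log = train(
    data, lam=0.1, lr=0.01, epochs=30, batch_size=4)
```

The running mean is cheaper but averages losses from parameters that changed during the epoch, so it lags behind the separate pass, which reflects only the parameters at the end of the epoch.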

lndip commented 4 weeks ago

Thank you for your answer! I now see the mismatch between smooth and nonsmooth in the snippet I posted. Also, may I ask what criteria were used to choose the hyperparameters (lam, lam_ridge, or GC_threshold in Adam optimization)? Were they based on the experimental results?

iancovert commented 3 weeks ago

Tuning those hyperparameters can be a bit tricky, and we took a pretty simple approach in the paper. In our experiments, we didn't focus on finding a single best setting and instead based our evaluation on the results we got with increasingly strong sparsity penalties (specifically for lam). So we fixed lam_ridge to a single small value (1e-2, which we didn't tune, for simplicity), and we fit models with a range of lam values. We manually tuned the lam range so that no features were selected for the largest value and all features were selected for the lowest value. Our AUROC/AUPR evaluations are based on the confusion matrix of positives and negatives observed for each lam value, meaning that each model becomes a single point on the ROC/PR curves.
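As a rough illustration of how a lam sweep turns into ROC points, here is a sketch in plain Python. The function names, the ground-truth matrix, and the estimated matrices are all made up for the example; counting every matrix entry (including the diagonal) and using trapezoidal integration are assumptions, not details confirmed by the paper.

```python
def confusion_rates(gc_true, gc_est):
    """False and true positive rates over all entries of two 0/1 matrices."""
    tp = fp = fn = tn = 0
    for row_true, row_est in zip(gc_true, gc_est):
        for t, e in zip(row_true, row_est):
            if t and e:
                tp += 1
            elif t and not e:
                fn += 1
            elif e:
                fp += 1
            else:
                tn += 1
    tpr = tp / (tp + fn) if tp + fn else 0.0
    fpr = fp / (fp + tn) if fp + tn else 0.0
    return fpr, tpr


def roc_points(gc_true, gc_by_lam):
    """One (FPR, TPR) point per fitted model, sorted along the curve."""
    return sorted(confusion_rates(gc_true, est) for est in gc_by_lam.values())


def auroc(points):
    """Trapezoidal area under the piecewise-linear curve through the points."""
    return sum((x1 - x0) * (y0 + y1) / 2.0
               for (x0, y0), (x1, y1) in zip(points, points[1:]))


gc_true = [[1, 0, 0], [1, 1, 0], [0, 1, 1]]
gc_by_lam = {  # hypothetical estimates, from strongest to weakest penalty
    1.0: [[0, 0, 0], [0, 0, 0], [0, 0, 0]],    # nothing selected
    0.1: [[1, 0, 0], [1, 1, 0], [0, 0, 1]],
    0.05: [[1, 0, 1], [1, 1, 0], [0, 1, 1]],
    0.001: [[1, 1, 1], [1, 1, 1], [1, 1, 1]],  # everything selected
}
points = roc_points(gc_true, gc_by_lam)
area = auroc(points)
```

Choosing the lam range so that the extremes select nothing and everything pins the curve to the corners (0, 0) and (1, 1), which is what makes the area well defined.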

As for GC_threshold, we didn't tune this either - in the few experiments we did with Adam we just fixed it to a small value (we apparently didn't put it in the paper, but I believe it was 1e-2 or 1e-3).
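A threshold matters with Adam because plain gradient updates rarely drive weights to exactly zero, whereas the prox step can. The sketch below shows the idea in plain Python; `group_norms` and `gc_row` are hypothetical helpers and the weight values are invented, so this is not the repository's GC() implementation.

```python
import math


def group_norms(weights):
    """L2 norm of each column of a (hidden units x inputs) weight matrix."""
    n_inputs = len(weights[0])
    return [math.sqrt(sum(row[j] ** 2 for row in weights))
            for j in range(n_inputs)]


def gc_row(weights, threshold=0.0):
    """Mark input j as Granger-causal if its weight group norm exceeds threshold."""
    return [int(norm > threshold) for norm in group_norms(weights)]


# Made-up input weights for one target series: input 0 is clearly used,
# input 1 is tiny but nonzero, input 2 is exactly zero.
weights = [[0.8, 0.004, 0.0],
           [-0.3, -0.003, 0.0]]
exact = gc_row(weights)                        # nonzero test only
thresholded = gc_row(weights, threshold=1e-2)  # small cutoff
```

With a zero threshold the tiny second group still counts as selected; a cutoff such as 1e-2 removes it.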