ds4dm / ecole

Extensible Combinatorial Optimization Learning Environments
https://www.ecole.ai
BSD 3-Clause "New" or "Revised" License

Getting `nan` while training #324

Closed: sleepymalc closed this issue 2 years ago

sleepymalc commented 2 years ago

Describe the bug

I followed the `branching-imitation/example.ipynb` tutorial but used my own data instances, and I get `nan` while training.

Setting

I want to use this to train on TSP, and I have successfully produced the desired dataset/instances. However, after some iterations (fewer than 10), the loss becomes `nan`. I tried lowering the learning rate, but it doesn't help. I have also checked the documentation for `torch.nn.functional.cross_entropy()`, which says the input (in this case, the logits) can be unnormalized, so I haven't tried normalizing them. Are there any other tips that might help solve this problem? The code snippet that I think causes the problem is as follows:

```python
import numpy as np
import torch
import torch.nn.functional as F

# `device`, `entropy_bonus` and `pad_tensor` are defined earlier in the example notebook.


def process(policy, data_loader, top_k=[1, 3, 5, 10], optimizer=None):
    mean_loss = 0
    mean_kacc = np.zeros(len(top_k))
    mean_entropy = 0

    n_samples_processed = 0
    with torch.set_grad_enabled(optimizer is not None):
        for batch in data_loader:
            batch = batch.to(device)

            ######## potentially problematic part starts

            logits = policy(batch.constraint_features, batch.edge_index, batch.edge_attr, batch.variable_features)
            # Keep only the candidate variables, padded to the largest candidate set in the batch
            logits = pad_tensor(logits[batch.candidates], batch.nb_candidates)

            cross_entropy_loss = F.cross_entropy(logits, batch.candidate_choices, reduction='mean')

            print("cross entropy loss", cross_entropy_loss)

            entropy = (-F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)).sum(-1).mean()
            loss = cross_entropy_loss - entropy_bonus * entropy

            ######## potentially problematic part ends

            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            true_scores = pad_tensor(batch.candidate_scores, batch.nb_candidates)
            true_bestscore = true_scores.max(dim=-1, keepdim=True).values

            # Top-k accuracy: does any of the k highest-scoring predictions hit the best expert score?
            kacc = []
            for k in top_k:
                if logits.size()[-1] < k:
                    kacc.append(1.0)
                    continue
                pred_top_k = logits.topk(k).indices
                pred_top_k_true_scores = true_scores.gather(-1, pred_top_k)
                accuracy = (pred_top_k_true_scores == true_bestscore).any(dim=-1).float().mean().item()
                kacc.append(accuracy)
            kacc = np.asarray(kacc)

            mean_loss += cross_entropy_loss.item() * batch.num_graphs
            mean_entropy += entropy.item() * batch.num_graphs
            mean_kacc += kacc * batch.num_graphs
            n_samples_processed += batch.num_graphs

    mean_loss /= n_samples_processed
    mean_kacc /= n_samples_processed
    mean_entropy /= n_samples_processed
    return mean_loss, mean_kacc, mean_entropy
```

The output, in which I print three quantities, was attached as screenshots [images not preserved]. It seems that `torch.nn.functional.cross_entropy()` becomes `nan` first, and then on the next iteration everything becomes `nan` because of a single bad update. Any help is appreciated!
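The pattern described here (one `nan` loss, then everything `nan` after the next update) can at least be contained by refusing to apply an update whose loss or gradients are non-finite. The sketch below is a general PyTorch safeguard, not the fix the author eventually found (the thread does not say what that was). The helper name `guarded_step` and the `max_grad_norm=1.0` default are hypothetical choices; `torch.isfinite` and `torch.nn.utils.clip_grad_norm_` are standard PyTorch calls.

```python
import torch
import torch.nn.functional as F


def guarded_step(loss, optimizer, parameters, max_grad_norm=1.0):
    """Hypothetical helper: apply an update only if the loss and gradients are finite.

    Returns True if optimizer.step() was called, False if the update was skipped.
    """
    if not torch.isfinite(loss):
        # Non-finite loss: skip backward/step so the weights stay untouched.
        return False
    optimizer.zero_grad()
    loss.backward()
    # clip_grad_norm_ rescales gradients in place and returns the norm *before* clipping.
    total_norm = torch.nn.utils.clip_grad_norm_(parameters, max_grad_norm)
    if not torch.isfinite(total_norm):
        # Some gradient is nan/inf: discard it instead of stepping.
        optimizer.zero_grad()
        return False
    optimizer.step()
    return True
```

In the `process` loop above, this would stand in for the `if optimizer is not None:` block, e.g. `guarded_step(loss, optimizer, policy.parameters())`. It only hides the symptom; to find which operation first produces the `nan`, wrapping a failing batch in `torch.autograd.set_detect_anomaly(True)` makes the backward pass raise at the offending op.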

P.S. I'm not entirely sure whether this is useful, but I'm working on TSP, which I model as an LP, so the LP is fairly large. For instances with 15 nodes, training runs fine for 7 epochs, and if I simply skip the problematic instances (i.e., skip any instance whose entropy is `nan`), training succeeds: over 100 epochs, acc@1 rises from 33% to 66%. But with the same trick on TSP25, every instance is skipped. So I suspect this may be related to the problem size.
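If the failures grow with instance size, one thing worth ruling out is that some input features are already `nan`/`inf` (or extremely large) before they reach the network; a normalization layer will not repair a `nan` or `inf` input. A minimal sketch, assuming each sample's features can be collected into a dict of named float tensors (the names mirror the attributes used in the snippet; the helper `feature_report` is hypothetical):

```python
import torch


def feature_report(named_tensors):
    """Hypothetical helper: for each tensor, return (number of non-finite entries,
    largest absolute value among the finite entries)."""
    report = {}
    for name, t in named_tensors.items():
        finite = torch.isfinite(t)
        n_bad = int((~finite).sum())
        # Boolean-mask indexing flattens to the finite entries only.
        max_abs = float(t[finite].abs().max()) if finite.any() else float("nan")
        report[name] = (n_bad, max_abs)
    return report
```

Running this over every sample of the TSP15 and TSP25 datasets, e.g. on `{"constraint_features": s.constraint_features, "edge_attr": s.edge_attr, "variable_features": s.variable_features}`, and comparing the two would show whether the larger instances carry non-finite or much larger feature values.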

sleepymalc commented 2 years ago

OK, apparently I figured it out somehow. I'll close this issue!