Describe the bug

I followed the branching-imitation/example.ipynb while creating my own data instances, and I get nan while training.

Setting

I want to use this to train on TSP, and I successfully produced the desired dataset/instances. But after a few iterations (fewer than 10), the loss becomes nan. I tried lowering the learning rate, but it doesn't help. I have also checked the documentation for torch.nn.functional.cross_entropy(), which says the input (in this case, the logits) can be unnormalized, so I haven't tried normalizing them. Are there any other tips that might help solve this problem? The code snippets that (I think) cause the problem are as follows:
And the output is as follows, where I print the following three things:
It seems like torch.nn.functional.cross_entropy() first becomes nan, and then on the next iteration everything becomes nan due to one bad update. Any help is appreciated!
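For reference, this is the kind of check I mean when I say the loss "becomes nan" (a generic PyTorch debugging sketch, not the notebook's code; logits and labels are placeholder names):

```python
import torch
import torch.nn.functional as F

# Make autograd report which op first produces nan/inf in backward
# (slow; enable only while debugging).
torch.autograd.set_detect_anomaly(True)

def checked_cross_entropy(logits, labels):
    # cross_entropy does accept unnormalized logits, but it still
    # returns nan if the logits themselves already contain nan/inf.
    if not torch.isfinite(logits).all():
        raise RuntimeError("non-finite logits before cross_entropy")
    loss = F.cross_entropy(logits, labels)
    if not torch.isfinite(loss):
        raise RuntimeError("non-finite loss")
    return loss
```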
P.S. I'm not entirely sure whether this information is useful, but I'm working on TSP, which I model as an LP, and that makes the LP size pretty big. For an instance with 15 nodes it works fine for 7 epochs, and if I simply skip the problematic instances (if the entropy is nan, just skip that instance), then it trains successfully (reaching 100 epochs, with acc@1 boosted from 33% to 66%). But if I use the same trick for tsp25, everything is skipped... So I think this may be related to the problem size? The skip trick looks roughly like this (a sketch only; model, optimizer, and loader are placeholders for my actual training loop):
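```python
import torch
import torch.nn.functional as F

def train_epoch(model, optimizer, loader):
    """One epoch with the skip-on-nan trick; all names are placeholders."""
    skipped = 0
    for features, labels in loader:
        optimizer.zero_grad()
        logits = model(features)
        loss = F.cross_entropy(logits, labels)
        if not torch.isfinite(loss):  # nan/inf loss: drop this instance
            skipped += 1
            continue
        loss.backward()
        optimizer.step()
    return skipped  # on tsp25 this ends up being every instance
```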