rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

Ignore a single broken gradient #1568

Open JackTemaki opened 1 week ago

JackTemaki commented 1 week ago

In my current language model training I sometimes get "nan" gradients, which break the training. Surprisingly, just restarting the training from the last checkpoint is often enough uncertainty to resume training.

Here people discussed something like:

        # assumes `import torch` and that this runs inside an nn.Module method (self)
        valid_gradients = True
        for name, param in self.named_parameters():
            if param.grad is not None:
                # any NaN or Inf in this parameter's gradient invalidates the whole update
                valid_gradients = not (torch.isnan(param.grad).any() or torch.isinf(param.grad).any())
                if not valid_gradients:
                    break
        if not valid_gradients:
            print("detected inf or nan values in gradients. not updating model parameters")
            self.zero_grad()

I think it would be a good idea to have this as a configurable option for the updater, preferably with a "limit" so that training still crashes after e.g. 5 broken updates.
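
A rough sketch of how such an option could look, assuming a plain PyTorch optimizer; the function and option names (`step_with_skip_limit`, `max_skips`) are invented here for illustration and are not existing RETURNN options:

    import torch

    def step_with_skip_limit(model, optimizer, state, max_skips=5):
        """Run optimizer.step() only if all gradients are finite; crash after too many skips.

        ``state`` is a mutable dict that keeps the skip counter across update steps.
        """
        grads_finite = all(
            torch.isfinite(p.grad).all()
            for p in model.parameters()
            if p.grad is not None
        )
        if grads_finite:
            optimizer.step()
        else:
            state["num_skipped"] = state.get("num_skipped", 0) + 1
            print("detected non-finite gradients, skipping update (%i skipped so far)" % state["num_skipped"])
            if state["num_skipped"] > max_skips:
                raise RuntimeError("too many updates with non-finite gradients, stopping training")
        optimizer.zero_grad()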

albertz commented 1 week ago

I sometimes get "nan" gradients, which break the training.

You mean that after that, the model parameters themselves become nan, i.e. the model is broken?

Surprisingly, just restarting the training from the last checkpoint is often enough uncertainty to resume training.

Uncertainty? You mean that restarting from the last checkpoint solves the issue, i.e. getting nan is non-deterministic and rare, and after such a restart you are usually lucky and don't get nan anymore, or only much later?

We could also implement such automatic restart logic. It would be a different approach than what you suggest, i.e. always checking for non-finite grads.

I'm not sure which approach would be better. It probably also depends on how often you get this. (E.g. I personally have never run into this problem so far.)
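
For reference, a rough sketch of what such restart logic could look like as an external wrapper (the command line and restart limit are made up for illustration; nothing like this exists in RETURNN right now):

    # Re-launch the training command when it exits with an error, up to a limit.
    # RETURNN would then resume from the last saved checkpoint on each restart.
    import subprocess
    import sys

    MAX_RESTARTS = 3  # illustrative limit
    cmd = sys.argv[1:]  # e.g. ["python3", "rnn.py", "my_config.config"]

    for attempt in range(MAX_RESTARTS + 1):
        returncode = subprocess.call(cmd)
        if returncode == 0:
            break  # training finished normally
        print("training crashed with exit code %i (attempt %i)" % (returncode, attempt + 1))
    else:
        raise RuntimeError("training kept crashing after %i restarts" % MAX_RESTARTS)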

JackTemaki commented 1 week ago

I do not want such automatic restart logic, because it wastes computation. For testing, I have now implemented here that the update is skipped when the result of the gradient clipping is NaN or Inf. I do this for grad_clip_norm_ right now, and I am not sure how grad_clip_value_ behaves with Inf/NaN values (maybe it already corrects them automatically).
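
In plain PyTorch terms, the check is roughly the following sketch (not the actual RETURNN change): `clip_grad_norm_` returns the total gradient norm, and the step is skipped when that norm is not finite. As far as I know, `clip_grad_value_` just clamps the values and does not remove NaN, so it would need a separate check.

    import torch

    def clip_and_maybe_step(model, optimizer, max_norm=1.0):
        # clip_grad_norm_ returns the total norm of all gradients (before clipping)
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        if torch.isfinite(total_norm):
            optimizer.step()
        else:
            print("detected non-finite gradient norm %r, skipping update" % total_norm)
        optimizer.zero_grad()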

So far, this specific problem has only appeared for me in LSTM-LM training, not anywhere else (ASR, TTS).

In my current training, after 2 full epochs, I now get a skip every 40k update steps. Before that, it did not happen at all.