rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

RF torch `lstm` fails with torch amp option. #1529

Closed LucaG1 closed 3 months ago

LucaG1 commented 3 months ago

Hi, I'm currently trying to train a transducer model using RF. I use the torch_amp="bfloat16" option from previous setups. In the predictor I use an rf.LayerNorm followed by an rf.LSTM. I think this fails because the LayerNorm output is float32 while the LSTM runs in float16. The lstm function of the RF PyTorch backend fails here: https://github.com/rwth-i6/returnn/blob/bc0d5bf818d894c2b5d4e1676037228ef0f28041/returnn/torch/frontend/_backend.py#L2021

This is the error message: ValueError: tensor_raw_tensor_setter: tensor.dtype != raw_tensor.dtype, from tensor dtype 'float32' and raw_tensor dtype 'float16'

When I disable torch mixed precision, I get past this error. However, I'm not sure how this should be handled correctly. Is there a missing cast? Or can mixed precision simply not be used in this case?
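
For what it's worth, the check that fires seems to be the generic dtype-consistency check when assigning a raw torch tensor to an RF tensor, not anything LSTM-specific. A minimal sketch of what appears to trigger the same error (assuming the `returnn.tensor` API; dim and tensor names are made up):

```python
import torch
from returnn.tensor import Tensor, Dim

# The RF tensor declares float32, but under AMP the raw torch tensor
# comes back as float16, so the raw_tensor setter rejects the assignment.
time_dim = Dim(10, name="time")
feat_dim = Dim(64, name="feat")
out = Tensor("lstm_out", dims=[time_dim, feat_dim], dtype="float32")
out.raw_tensor = torch.zeros(10, 64, dtype=torch.float16)  # ValueError: dtype mismatch
```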

The config I used can be found here: `/u/luca.gaudino/setups/2023-08-10--rf-librispeech/debug/train_rnnt_rf_ls960.config`

You can run it via `/u/luca.gaudino/bin/returnn_launcher_nocuda.sh /u/luca.gaudino/setups/2023-08-10--rf-librispeech/returnn/rnn.py /u/luca.gaudino/setups/2023-08-10--rf-librispeech/debug/train_rnnt_rf_ls960.config` on a 24 GB GPU node.

JackTemaki commented 3 months ago

This problem is not specific to returnn_frontend; it also happens with pure PyTorch. I solved it by wrapping the LSTM cell call in `with torch.autocast(device_type="cuda", enabled=False):`.
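
A minimal pure-PyTorch sketch of that workaround (module and tensor names are made up, not the actual model code): disable autocast just around the cell call and cast inputs and state back to float32, so the cell runs entirely in float32 even under AMP.

```python
import torch

# assumes a CUDA device is available
lstm_cell = torch.nn.LSTMCell(64, 64).cuda()
x = torch.randn(3, 64, device="cuda")
h = torch.zeros(3, 64, device="cuda")
c = torch.zeros(3, 64, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # ... other layers run under AMP here ...
    with torch.autocast(device_type="cuda", enabled=False):
        # run the cell itself outside AMP, with float32 inputs/state
        h, c = lstm_cell(x.float(), (h.float(), c.float()))
```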

JackTemaki commented 3 months ago

Correction: this is only a problem when using the cell-type LSTM, not the sequence one, so in your case this really might just be a check problem in returnn_frontend.

albertz commented 3 months ago

Yeah, RF assumes float32 here but got float16. In a couple of other cases, I just overwrite the dtype with whatever Torch has returned, i.e. basically removing the check; see e.g. the softmax function. Basically, just add this before you assign `out.raw_tensor`:

out.dtype = TorchBackend.get_dtype_name_raw(out_raw)
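
A standalone sketch of that pattern (assuming the `returnn.tensor` API and the `TorchBackend` from the file linked above; tensor/dim names are made up):

```python
import torch
from returnn.tensor import Tensor, Dim
from returnn.torch.frontend._backend import TorchBackend

feat_dim = Dim(64, name="feat")
out = Tensor("lstm_out", dims=[feat_dim], dtype="float32")
out_raw = torch.zeros(64, dtype=torch.float16)  # what torch may return under AMP

# Adopt whatever dtype torch actually produced, then assign the raw tensor,
# so the dtype check in the raw_tensor setter no longer fails.
out.dtype = TorchBackend.get_dtype_name_raw(out_raw)
out.raw_tensor = out_raw
```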

Do you want to commit this or should I?

albertz commented 3 months ago

I just pushed it now.

LucaG1 commented 3 months ago

Thanks. I think the same thing has to be done for the states. Should I just commit this, or do you want to amend the previous commit?
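
Roughly, that would mean applying the same dtype overwrite to the returned state tensors before their raw-tensor assignment (a hypothetical placement sketch with assumed variable names, not the actual `_backend.py` code):

```python
# Inside the backend lstm function (sketch): the new hidden/cell state raw
# tensors may also come back in the AMP dtype, so adopt their dtype as well.
new_state_h.dtype = TorchBackend.get_dtype_name_raw(new_state_h_raw)
new_state_h.raw_tensor = new_state_h_raw
new_state_c.dtype = TorchBackend.get_dtype_name_raw(new_state_c_raw)
new_state_c.raw_tensor = new_state_c_raw
```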

albertz commented 3 months ago

Why did you reopen this? It's not fixed yet? You mean it also needs to be done for the states? Yes, just add another commit for this directly to master then.