Shiro-LK opened this issue 4 years ago
Did you use apex? Can you share the script you used and the errors you got?
Hi,
Thank you for your reply. I am using torch.cuda.amp for half precision; it has been available since PyTorch 1.6, if I am not mistaken. The autocast context manager enables half-precision execution.
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from models import *
import numpy as np

mel_bins = 64
fmin = 50
fmax = 14000
window_size = 1024
hop_size = 320
sample_rate = 32000

model = Wavegram_Logmel_Cnn14(sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num=264).to("cuda:0")
model.eval()

x = torch.tensor(np.random.uniform(0, 1, (1, 32000 * 5)), dtype=torch.float32).to("cuda:0")

# Mixed-precision forward pass: NaN values appear in the output.
with torch.no_grad():
    with autocast():
        outputs = model(x)
print(outputs)

# Full-precision forward pass: the output is fine.
with torch.no_grad():
    outputs = model(x)
print(outputs)
If you add the following check in the forward method, right after x = self.logmel_extractor(x):

    if x.isnan().sum() > 0:
        print("logmel", x.isnan().sum(), x.shape)

you can see that the NaN values appear after the LogmelFilterBank layer.
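As an aside, the same check can be attached from outside the model with a forward hook, so models.py does not need to be edited. A minimal sketch, assuming the extractor is exposed as model.logmel_extractor (as in the PANNs models):

import torch

def report_nan(module, inputs, output):
    # Print a warning whenever this module's output contains NaN values.
    n = torch.isnan(output).sum().item()
    if n > 0:
        print(f"{module.__class__.__name__}: {n} NaN values, shape {tuple(output.shape)}")

# Assumption: the model stores its log-mel extractor as `logmel_extractor`.
model.logmel_extractor.register_forward_hook(report_nan)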
Hi, I guess the problem is caused by the log operation in LogmelFilterBank. To address this problem, you could try:
1) Set is_log=False in the argument of LogmelFilterBank. If the NaNs disappear, then we know the problem is caused by the log.
2) Set is_log=True and amin=1e-6 in the argument of LogmelFilterBank, so the input to the log is clamped further away from zero (a quick sketch of both options follows).
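For illustration, a minimal sketch of both options, reusing the arguments from your script (the remaining parameters keep their defaults):

from torchlibrosa.stft import LogmelFilterBank

# Option 1: disable the log entirely to confirm it is the source of the NaNs.
logmel_no_log = LogmelFilterBank(sr=32000, n_fft=1024, n_mels=64, fmin=50, fmax=14000,
    top_db=None, is_log=False, freeze_parameters=True)

# Option 2: keep the log but raise the clamping floor from the default 1e-10 to 1e-6,
# so tiny mel-energy values are clamped before the log is taken.
logmel_amin = LogmelFilterBank(sr=32000, n_fft=1024, n_mels=64, fmin=50, fmax=14000,
    top_db=None, is_log=True, amin=1e-6, freeze_parameters=True)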
See https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/stft.py#L389 for details. Please let me know if there are any questions!
Best wishes,
Qiuqiang
Hi,
Thanks for your answer. It is strange: the issue still seems to be there. I changed the code to use only the Spectrogram and LogmelFilterBank functions:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from torchlibrosa.stft import Spectrogram, LogmelFilterBank

mel_bins = 64
fmin = 50
fmax = 14000
window_size = 1024
hop_size = 320
sample_rate = 32000
window = 'hann'
center = True
pad_mode = 'reflect'
ref = 1.0
amin = 1e-10
top_db = None

spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
    win_length=window_size, window=window, center=center, pad_mode=pad_mode,
    freeze_parameters=True).to("cuda:0")

# Logmel feature extractor
logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
    n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, is_log=True,
    freeze_parameters=True).to("cuda:0")

x = torch.tensor(np.random.uniform(0, 1, (1, 32000 * 5)), dtype=torch.float32).to("cuda:0")

# Mixed precision: NaN/inf values appear after the log-mel step.
with torch.no_grad():
    with autocast():
        x1 = spectrogram_extractor(x)
        # print(torch.isinf(x1).sum())
        print(x1)
        outputs = logmel_extractor(x1)
        print(outputs)

# Full precision: both outputs are fine.
with torch.no_grad():
    x1 = spectrogram_extractor(x)
    print(x1)
    outputs = logmel_extractor(x1)
    print(outputs)
It seems the issue comes from the matmul, so I suspect an incompatibility between torchlibrosa and autocast, but I may be wrong.
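A plausible mechanism, shown as a standalone sketch (this is my assumption about the failure mode, not something confirmed here): under autocast, torch.matmul runs in float16, whose largest representable value is 65504, so large power-spectrogram bins can overflow to inf during the mel projection, and that inf then propagates through the log and the rest of the network:

import torch
from torch.cuda.amp import autocast

# A single large power-spectrogram bin and a trivial stand-in for the mel filter bank.
power = torch.tensor([[70000.0]], device="cuda:0")  # exceeds the float16 max of 65504
melW = torch.tensor([[1.0]], device="cuda:0")

with autocast():
    mel = torch.matmul(power, melW)  # autocast casts matmul inputs to float16

print(mel.dtype, mel)  # torch.float16, inf -> propagates downstream as inf/NaN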
UPDATE: I think the best way to overcome this issue is to deactivate autocast for this layer; since it has no trained weights, it should work without any problem:

    with autocast(False):
        outputs = logmel_extractor(x1)
Thanks for the solution!!
@Shiro-LK were you able to train with mixed precision?
@jonnor yes, but you need to add this part to the model's code:

    with autocast(False):
        outputs = logmel_extractor(x1)

because under half precision the logmel operation produces NaN or inf values, which otherwise makes training impossible. So you can simply disable half precision for this specific operation.
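For reference, a minimal sketch of how this could look inside the model's forward method; the extractor names follow the scripts above and the rest of the network is elided, so treat it as an illustration rather than the actual Wavegram_Logmel_Cnn14 code:

from torch.cuda.amp import autocast

def forward(self, x):
    # Force fp32 for the STFT and log-mel steps, even inside an outer autocast region.
    with autocast(enabled=False):
        x = self.spectrogram_extractor(x)  # (batch, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)       # log-mel spectrogram, computed in fp32
    # ... the convolutional layers below still run under the outer autocast ...
    return x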
Another possible solution is to use this: https://github.com/NVIDIA/apex. But I have not tried this.
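For what it's worth, the usual apex pattern is roughly the following (untested here, as noted above; opt_level "O1" casts operations on the fly, similarly to autocast):

from apex import amp

# Wrap an existing model and optimizer for mixed-precision training.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# criterion, x and y are placeholders for your loss function and a training batch.
loss = criterion(model(x), y)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # backward on the loss-scaled value
optimizer.step()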
Hello,
Thank you for sharing the weights and the experiments from your papers; it is very good work and very helpful.
I am experimenting with your Wavegram_Logmel_Cnn14 model on a custom dataset, and I have run into an issue when using mixed precision in PyTorch 1.6 with the LogmelFilterBank layer. I sometimes get NaN values in the forward output of this layer, which later produces NaN values in the loss function. I was wondering if you have an idea why? I do not have this issue when I am not using mixed precision.