Open urimerhav opened 3 years ago
Alright, another data point to add here: I've confirmed that the only way the model looks 900 samples into the future is via the STFT-ISTFT round trip.
The ISTFT's contribution to leaking information from the future is potentially very limited, but it's definitely an unintended leak, and it shows that the model as-is does depend on future samples up to 900 samples ahead. I'd suggest implementing the ISTFT in a strictly causal way to confirm the model actually meets the 37.5 msec limitation with equivalent results.
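For illustration, a strictly causal ISTFT can be sketched as streaming overlap-add that only emits a sample once every frame overlapping it has arrived. This is a hypothetical sketch (none of these names are from the repo), assuming a 400-sample Hann analysis/synthesis window and a 100-sample hop:

```python
# Hypothetical sketch of a causal ISTFT: weighted overlap-add that emits only
# the `hop` samples that no future frame can still modify.
# Assumptions: 400-sample Hann window, 100-sample hop; not code from the repo.
import numpy as np

class CausalISTFT:
    def __init__(self, win_len=400, hop=100):
        self.win = np.hanning(win_len)
        self.win_len, self.hop = win_len, hop
        self.buf = np.zeros(win_len)    # pending overlap-add tail
        self.norm = np.zeros(win_len)   # pending window-power tail

    def push(self, spec):
        """Feed one frame's spectrum; return the `hop` samples now final."""
        frame = np.fft.irfft(spec, n=self.win_len) * self.win
        self.buf += frame
        self.norm += self.win ** 2
        # the first `hop` samples of the buffer are covered by no later frame
        out = self.buf[:self.hop] / np.maximum(self.norm[:self.hop], 1e-8)
        # slide the buffers forward by one hop
        self.buf = np.concatenate([self.buf[self.hop:], np.zeros(self.hop)])
        self.norm = np.concatenate([self.norm[self.hop:], np.zeros(self.hop)])
        return out
```

Output at time t then depends only on frames already received, so the round trip adds latency (one window of buffering) but no anti-causal reach.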
It's easy enough to show that the model's output at time t is affected by samples from t+900 (implying 56.25 msec of anti-causality).
Simply put, if we feed in a signal like [1,1,1,...,1,inf,inf,inf], where the first inf comes at sample N, the output becomes invalid at sample N-900, which implies a 56.25 msec anti-causal reach.
```python
import numpy as np
import torch

leading_n = 900  # number of trailing samples to corrupt (value not given in the original)

net = DCCRN(rnn_units=256, masking_mode='E', use_clstm=False,
            kernel_num=[32, 64, 128, 256, 256, 256])
canary_input = torch.ones([1, 16000 * 2]).clamp_(-1, 1) * 0.5
canary_input[0, -leading_n:] = np.inf
out_canary = net(canary_input)[1].detach().numpy()
first_invalid = np.where(np.isnan(out_canary))[1][0]
future_effect_size = canary_input[0].shape[0] - leading_n - first_invalid
print('model samples into the future is', future_effect_size)
```
I've done some digging and it seems this has to do with the STFT working in windows of 400 samples with a hop of 100.
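The framing alone already explains the mechanism: one non-finite sample makes the FFT of every frame containing it all-NaN, and overlap-add then smears those NaNs back to the start of the earliest such frame. A plain NumPy round trip with no network in between (a sketch assuming the 400/100 window/hop above; it gives a smaller reach of win − hop = 300 samples, so the model's 900-sample figure presumably adds further padding or centering inside its conv-based STFT layers):

```python
# A bare STFT-ISTFT round trip in NumPy -- no network in between.
# Assumptions: 400-sample Hann window, 100-sample hop, as discussed above.
import numpy as np

def stft_istft_roundtrip(x, win_len=400, hop=100):
    win = np.hanning(win_len)
    out = np.zeros(len(x))
    norm = np.zeros(len(x))
    n_frames = (len(x) - win_len) // hop + 1
    for i in range(n_frames):
        s = i * hop
        spec = np.fft.rfft(x[s:s + win_len] * win)  # one inf -> all-NaN spectrum
        rec = np.fft.irfft(spec, n=win_len)         # -> the whole frame is NaN
        out[s:s + win_len] += rec * win             # overlap-add smears it back
        norm[s:s + win_len] += win ** 2
    return out / np.maximum(norm, 1e-8)

x = np.ones(16000)
n_inf = 14000                     # first corrupted input sample
x[n_inf:] = np.inf
y = stft_istft_roundtrip(x)
first_bad = np.flatnonzero(~np.isfinite(y))[0]
leak = n_inf - first_bad          # samples *before* the inf that got corrupted
print('round trip alone reaches', leak, 'samples into the future')
```

With these parameters the round trip alone corrupts win − hop = 300 samples before the first inf input.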
Hi, I used the code you provided to calculate how many future samples this network uses. I set leading_n=1 and got future_effect_size=31999. It seems the whole output of the network is invalid. Do you have any idea what causes this?
I think you perhaps need to make sure your model is in eval mode, otherwise the normalization layers would be updated by the inf values. The other possibility is that you are using an utterance-level normalization method.
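A quick standalone demo of that first point (plain BatchNorm, not DCCRN code): in train mode a normalization layer folds the batch statistics, inf included, into its running stats, which then poisons every later output; eval mode leaves the running stats untouched.

```python
# Standalone demo (not DCCRN code): inf poisons BatchNorm running stats in
# train mode, but leaves them untouched in eval mode.
import torch
import torch.nn as nn

x = torch.ones(1, 1, 100)
x[0, 0, -1] = float('inf')

bn = nn.BatchNorm1d(1)
bn.train()
bn(x)  # batch mean is inf -> folded into running_mean
train_poisoned = not torch.isfinite(bn.running_mean).all()

bn = nn.BatchNorm1d(1)  # fresh layer
bn.eval()
y = bn(x)               # running stats untouched; only the inf sample is bad
eval_ok = torch.isfinite(bn.running_mean).all()
```

In eval mode only the single inf input sample produces a non-finite output, which is exactly what the canary test relies on.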
Yes, I hadn't put my model in eval mode; it's solved now. Thank you very much!