espnet / espnet

End-to-End Speech Processing Toolkit
https://espnet.github.io/espnet/
Apache License 2.0

BSRNN with `causal=True` is not causal #5900

Closed: philgzl closed this issue 2 days ago

philgzl commented 3 days ago

It seems to me that BSRNN with `causal=True` is not causal. Forward-passing a tensor whose last time frame contains NaNs produces an output that is entirely NaN.

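```python
import torch

from espnet2.enh.layers.bsrnn import BSRNN

net = BSRNN(causal=True)
net.eval()

# Input spectrum: (batch, frames, frequency bins, real/imag)
x = torch.randn(1, 1000, 481, 2)
# Corrupt only the last time frame
x[:, -1, :, :] = float('nan')
x = net(x)
# A causal model should leave the earlier frames finite
assert not x.isnan().all()  # fails
```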

The first non-causal operations appear to be the normalization layers in the band-split module, which compute their statistics across time frames.
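To illustrate why one NaN frame can contaminate the whole output, here is a minimal pure-Python sketch of a normalization that pools its mean and variance over all time frames. It is a simplified stand-in (one feature per frame, hypothetical function name `global_time_norm`), not the ESPnet layer itself. Any statistic computed over the full sequence lets the last frame influence every earlier output.

```python
import math


def global_time_norm(frames, eps=1e-5):
    """Normalize a 1-D sequence with mean/variance pooled over ALL frames.

    Hypothetical simplified stand-in for a normalization layer whose
    statistics span the whole time axis, which makes it non-causal.
    """
    n = len(frames)
    mean = sum(frames) / n
    var = sum((x - mean) ** 2 for x in frames) / n
    return [(x - mean) / math.sqrt(var + eps) for x in frames]


frames = [0.5, -1.0, 2.0, 0.0, math.nan]  # NaN only in the last frame
out = global_time_norm(frames)
print(all(math.isnan(y) for y in out))  # True: the NaN leaked into every frame
```

Because the NaN enters the shared mean, every normalized frame becomes NaN, which matches the all-NaN output reported above.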

sw005320 commented 3 days ago

Thanks for pointing out this issue. @Emrys365, can you look into it?

Emrys365 commented 3 days ago

Hi @philgzl, to make the model fully causal, you also need to set `norm_type` to either `"cLN"` or `"cfLN"` in addition to `causal=True`.
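Assuming `"cLN"` denotes cumulative layer normalization (the causal variant known from Conv-TasNet; the exact ESPnet implementation may differ, and `"cfLN"` is not covered here), the sketch below shows why such a layer restores causality: frame t is normalized with statistics pooled only over frames 0..t. The function name `cumulative_layer_norm` is hypothetical, and the sketch omits the learnable gain and bias a real layer would have.

```python
import math


def cumulative_layer_norm(frames, eps=1e-8):
    """Causal normalization: frame t uses statistics of frames 0..t only.

    `frames` is a list of per-frame feature lists (time x features).
    Hypothetical simplified sketch: no learnable gain/bias.
    """
    out = []
    total, total_sq, count = 0.0, 0.0, 0
    for frame in frames:
        # Running sums over every feature of every frame seen so far (no look-ahead).
        total += sum(frame)
        total_sq += sum(x * x for x in frame)
        count += len(frame)
        mean = total / count
        # Clamp tiny negative values caused by floating-point rounding.
        var = max(total_sq / count - mean * mean, 0.0)
        out.append([(x - mean) / math.sqrt(var + eps) for x in frame])
    return out


frames = [[0.5, 1.0], [-1.0, 2.0], [0.0, 0.3], [math.nan, math.nan]]
out = cumulative_layer_norm(frames)
print(any(math.isnan(y) for row in out[:-1] for y in row))  # False: earlier frames stay finite
print(all(math.isnan(y) for y in out[-1]))  # True: only the corrupted frame is NaN
```

Running sums keep each step proportional to the frame size; the one-pass variance formula is less numerically robust than a two-pass one, which is acceptable for an illustration.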

philgzl commented 3 days ago

Thanks @Emrys365. Can you confirm that the configurations marked as causal in your paper include this modification? The `norm_type` option does not seem to be specified in the configuration files provided in Emrys365/se-scaling.

Emrys365 commented 2 days ago

No. At that time the paper only used the `causal=True` argument (unfortunately, we overlooked the normalization layers), so those models are not strictly causal: every operation except the normalization layers is causal. The conclusions should still be similar, although a truly causal model would likely perform slightly worse than reported.

philgzl commented 2 days ago

Cool, thanks for clarifying!