lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

padding issue for CausalConv1d #166

Closed: YoungloLee closed this issue 1 year ago

YoungloLee commented 1 year ago

https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/soundstream.py#L303-L314 In the CausalConv1d class, the amount of padding should be dilation * (kernel_size - 1) + 1 - stride, not just dilation * (kernel_size - 1). Without this correction, the SoundStream model does not converge. Hope this helps.

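```python
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, **kwargs):
        super().__init__()
        dilation = kwargs.get('dilation', 1)
        stride = kwargs.get('stride', 1)
        # Left-only padding, so each output frame sees only current and past samples.
        # The "+ 1 - stride" term aligns output frame t with the input window
        # [t * stride, (t + 1) * stride - 1]; it is 0 when stride == 1.
        self.causal_padding = dilation * (kernel_size - 1) + 1 - stride
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **kwargs)

    def forward(self, x):
        # x: (batch, channels, time). Reflect mode requires causal_padding < time.
        x = F.pad(x, (self.causal_padding, 0), mode='reflect')
        return self.conv(x)
```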
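A quick way to see what the extra "+ 1 - stride" changes is plain index arithmetic. The sketch below is illustrative only (all helper names are hypothetical, not from audiolm-pytorch): it uses the torch.nn.Conv1d output-length formula with all padding on the left and tracks the newest input sample each output frame reads. The old padding yields ceil(L / stride) frames and the corrected one floor(L / stride), so lengths agree whenever L is a multiple of the stride. What differs is alignment: with the old padding frame t ends at sample t * stride, while with the corrected padding it ends at (t + 1) * stride - 1, covering its whole stride window. The layer stack is an assumed SoundStream-like configuration with kernel_size = 2 * stride.

```python
# Hypothetical helpers for checking CausalConv1d padding arithmetic (not from audiolm-pytorch).

def conv1d_out_len(l_in, kernel_size, stride=1, dilation=1, left_pad=0):
    """Output length of nn.Conv1d when all padding is applied on the left."""
    return (l_in + left_pad - dilation * (kernel_size - 1) - 1) // stride + 1


def old_padding(kernel_size, stride, dilation=1):
    return dilation * (kernel_size - 1)


def new_padding(kernel_size, stride, dilation=1):
    return dilation * (kernel_size - 1) + 1 - stride


def last_input_index(t, kernel_size, stride, dilation, left_pad):
    """Newest original-input index read by output frame t."""
    # Frame t reads padded positions t*stride .. t*stride + dilation*(kernel_size-1);
    # subtracting left_pad maps them back to indices of the unpadded input.
    return t * stride + dilation * (kernel_size - 1) - left_pad


def stacked_last_index(u, layers, pad_fn):
    """Newest waveform sample seen by frame u after a stack of strided causal convs."""
    idx = u
    # Walk from the deepest layer back towards the waveform (dilation 1 assumed).
    for kernel_size, stride in reversed(layers):
        idx = last_input_index(idx, kernel_size, stride, 1, pad_fn(kernel_size, stride))
    return idx


if __name__ == "__main__":
    k, s = 4, 2
    for name, pad_fn in (("old", old_padding), ("new", new_padding)):
        pad = pad_fn(k, s)
        frames = [last_input_index(t, k, s, 1, pad) for t in range(3)]
        print(name, pad, conv1d_out_len(16, k, s, left_pad=pad), frames)
    # old 3 8 [0, 2, 4]
    # new 2 8 [1, 3, 5]

    # Assumed SoundStream-like encoder strides with kernel_size = 2 * stride (hop = 320).
    layers = [(4, 2), (8, 4), (10, 5), (16, 8)]
    print(stacked_last_index(0, layers, old_padding), stacked_last_index(0, layers, new_padding))
    # 0 319
```

Across the assumed stack (total hop 2 * 4 * 5 * 8 = 320), the per-layer offsets add up to hop - 1 = 319 samples, which is one way to read the time-alignment mismatch mentioned later in the thread. With stride = 1 the two formulas coincide, so only downsampling layers are affected.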
amitaie commented 1 year ago

Hey, can you share your training graphs/results with and without your correction?

Also, why did you use mode='reflect'?

YoungloLee commented 1 year ago

Without this modification, the SI-SDR metric does not drop below 30 dB because of the time-alignment mismatch that appears whenever downsampling (stride > 1) occurs. Since the scale of the time-domain reconstruction loss (L1 + L2) is hard to interpret, take a look at my training SI-SDR curve instead (batch size = 32, training segment length = 1 s, very early training stage, ~12k).

[Figure: training SI-SDR curve]
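Because SI-SDR compares waveforms sample by sample, a pure time offset between input and reconstruction limits it even when nothing else is wrong. The sketch below is a minimal pure-Python SI-SDR using the common scale-invariant definition with mean removal; si_sdr, delay, and the 440 Hz / 16 kHz test tone are hypothetical choices for illustration, not the repo's training or evaluation code. For this tone, a one-sample delay already brings SI-SDR down to roughly 15 dB.

```python
import math


def si_sdr(estimate, reference, eps=1e-12):
    """Scale-invariant SDR in dB (common definition, signals centered first)."""
    est_mean = sum(estimate) / len(estimate)
    ref_mean = sum(reference) / len(reference)
    est = [e - est_mean for e in estimate]
    ref = [r - ref_mean for r in reference]
    # Project the estimate onto the reference to get the optimally scaled target.
    alpha = sum(e * r for e, r in zip(est, ref)) / (sum(r * r for r in ref) + eps)
    target = [alpha * r for r in ref]
    noise = [e - t for e, t in zip(est, target)]
    target_energy = sum(t * t for t in target)
    noise_energy = sum(n * n for n in noise)
    return 10 * math.log10((target_energy + eps) / (noise_energy + eps))


def delay(signal, samples):
    """Shift right by `samples`, zero-filling the start: a pure time misalignment."""
    return [0.0] * samples + list(signal[: len(signal) - samples])


if __name__ == "__main__":
    sr, freq = 16000, 440.0  # assumed sample rate and test tone
    ref = [math.sin(2 * math.pi * freq * n / sr) for n in range(sr)]  # 1 s of audio
    for shift in (0, 1, 2):
        print(shift, round(si_sdr(delay(ref, shift), ref), 1))
```

The score for a perfect copy is limited only by eps, while each extra sample of delay lowers it further, so a systematic offset introduced by the padding shows up directly in this metric.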

Also, I set the padding mode to 'reflect' following FAIR's EnCodec implementation (https://github.com/facebookresearch/encodec/blob/6e8d7eda6fff5b0d589d64f063610c7f6044963e/encodec/modules/seanet.py#L95).
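For reference, left-only 'reflect' padding mirrors the signal around its first sample without repeating it, while 'constant' padding inserts zeros. The helper below is a hypothetical pure-Python stand-in for F.pad(x, (pad, 0), mode=...) on a single 1-D signal, written only to show the values each mode produces. PyTorch's reflect mode also requires the padding to be smaller than the input length, which is mirrored here as an error.

```python
def left_pad(x, pad, mode="constant"):
    """Pad a 1-D list on the left only, mimicking F.pad(x, (pad, 0), mode=mode)."""
    if pad == 0:
        return list(x)
    if mode == "constant":
        return [0.0] * pad + list(x)
    if mode == "reflect":
        if pad >= len(x):
            raise ValueError("reflect padding must be smaller than the input length")
        # Mirror around x[0] without repeating it: x[pad], ..., x[1], then x itself.
        return [x[i] for i in range(pad, 0, -1)] + list(x)
    raise ValueError(f"unsupported mode: {mode}")


print(left_pad([1.0, 2.0, 3.0, 4.0], 2, "constant"))  # [0.0, 0.0, 1.0, 2.0, 3.0, 4.0]
print(left_pad([1.0, 2.0, 3.0, 4.0], 2, "reflect"))   # [3.0, 2.0, 1.0, 2.0, 3.0, 4.0]
```

A practical consequence for the corrected layer is that reflect mode fails on inputs that are not longer than causal_padding samples, whereas zero padding has no such limit.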

lucidrains commented 1 year ago

@YoungloLee thank you dearly for this! I believe you are correct, and this is a huge misstep on my part :pray:

amitaie commented 1 year ago

> Also, i just set the padding mode to be 'reflect' following FAIR's EnCodec implementation

Maybe I'm totally wrong, but doesn't reflect padding make it non-streaming?

YoungloLee commented 1 year ago

I think it does not matter for streaming.
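One way to make the streaming question concrete is to list which input samples each output frame of the corrected layer reads. In the hypothetical helpers below, reflect padding stores x[-i] at padded position i < 0, so the first few frames read a handful of samples past the end of their own stride window; from some frame onward the dependency is purely causal, and zero padding never looks ahead. This sketches the padding arithmetic only and is an assumption-laden aid: it does not model how the repo (or EnCodec) actually streams, for example whether padding is applied once per stream or per chunk.

```python
def causal_padding(kernel_size, stride, dilation=1):
    # Corrected padding proposed in this issue.
    return dilation * (kernel_size - 1) + 1 - stride


def dependencies(t, kernel_size, stride=1, dilation=1, mode="reflect"):
    """Original-input indices read by output frame t of the corrected causal conv."""
    pad = causal_padding(kernel_size, stride, dilation)
    deps = set()
    for j in range(kernel_size):
        i = t * stride + j * dilation - pad  # tap position relative to the unpadded input
        if i >= 0:
            deps.add(i)
        elif mode == "reflect":
            deps.add(-i)  # reflect padding stores x[-i] at padded position i < 0
        # with zero ("constant") padding, taps at i < 0 read no input sample
    return deps


def lookahead(t, kernel_size, stride=1, dilation=1, mode="reflect"):
    """Samples needed beyond the end of frame t's own window, t*stride + stride - 1."""
    newest = max(dependencies(t, kernel_size, stride, dilation, mode))
    return max(0, newest - (t * stride + stride - 1))


# Hypothetical layer: kernel_size=7, stride=1 (so causal_padding = 6).
print([lookahead(t, 7, mode="reflect") for t in range(4)])   # [6, 4, 2, 0]
print([lookahead(t, 7, mode="constant") for t in range(4)])  # [0, 0, 0, 0]
```

Under these assumptions, reflect padding only costs a one-time lookahead of at most causal_padding samples at the very start of the signal, which is consistent with it mattering little once a stream is running.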

Dannynis commented 1 year ago

Hey, if so, doesn't this affect CausalTransposedConv1d as well?
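The thread does not settle whether CausalTransposedConv1d needs a matching fix. As a hedged aid for checking it, the sketch below uses the output-length formula from the torch.nn.ConvTranspose1d documentation and assumes (this is not taken from the repo) that causality is obtained by dropping the last kernel_size - stride output samples, with dilation 1 and kernel_size >= stride. Under that assumption the output has exactly L * stride samples, and the samples in [t * stride, (t + 1) * stride - 1] depend on frame t and earlier frames only, which matches the window the corrected encoder padding assigns to frame t. All helper names are hypothetical.

```python
def conv_transpose1d_out_len(l_in, kernel_size, stride, padding=0, output_padding=0, dilation=1):
    """Output length formula from the torch.nn.ConvTranspose1d documentation."""
    return (l_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


def causal_trim_len(l_in, kernel_size, stride):
    """Length after dropping the last kernel_size - stride samples (assumed causal trim)."""
    return conv_transpose1d_out_len(l_in, kernel_size, stride) - (kernel_size - stride)


def input_frames_for_output(o, kernel_size, stride):
    """Input frames i that write into output sample o, i.e. o == i * stride + j."""
    frames = set()
    for j in range(kernel_size):  # dilation 1 assumed
        offset = o - j
        if offset >= 0 and offset % stride == 0:
            frames.add(offset // stride)
    return frames


l_in, k, s = 10, 4, 2  # hypothetical decoder layer with kernel_size = 2 * stride
n_out = causal_trim_len(l_in, k, s)
print(conv_transpose1d_out_len(l_in, k, s), n_out)  # 22 20
# Every kept sample depends only on frames up to o // stride, including frame o // stride.
print(all(max(input_frames_for_output(o, k, s)) == o // s for o in range(n_out)))  # True
```

If the repo's transposed layer trims a different number of samples, or trims from the left, the same index check can be rerun with that choice to see whether frames and sample windows still line up.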