flashlight / wav2letter

Facebook AI Research's Automatic Speech Recognition Toolkit
https://github.com/facebookresearch/wav2letter/wiki

Help porting some layers to pytorch #744

Closed lunixbochs closed 3 years ago

lunixbochs commented 4 years ago

I'm working on a pytorch loader for my model format (https://github.com/facebookresearch/wav2letter/issues/718)

I have conv_glu models working through trial and error, but I'm confused about TDS. Any help would be appreciated!

TDS has an inner view/reorder.

class View(nn.Module):
    def __init__(self, shape: Sequence[int]):
        super().__init__()
        self.shape = shape

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        shape = [n if n != 0 else input.shape[i]
                 for i, n in enumerate(self.shape)]
        # print('here', input.shape, self.shape, shape)
        return input.view(*shape)

    def __repr__(self) -> str:
        return 'View({})'.format(list(self.shape))

# arrayfire is: HWCN column major
# pytorch is:   NCHW row major
def af_to_torch_dim(dim: int) -> int:
    return 3 - dim
torch_to_af_dim = af_to_torch_dim

def af_to_torch_dims(*dims: int) -> Sequence[int]:
    return [af_to_torch_dim(dim) for dim in dims]
torch_to_af_dims = af_to_torch_dims

def af_to_torch_shape(*shape: int) -> Sequence[int]:
    return torch.Size(list(reversed(shape)))
torch_to_af_shape = af_to_torch_shape

    elif t == 'TDS':
        args = line.split(' ')
        channels, kernel_size, width, inner_linear_dim, right_pad, layer_norm_include_time = map(int, args[:3] + args[4:])
        dropout = float(args[3])

        layers = []
        conv = []
        if right_pad:
            total_pad = kernel_size - 1
            if right_pad > total_pad:
                raise ValueError('right pad exceeds SAME pad size for TDS Block')
            padding = (total_pad - right_pad, right_pad, 0, 0)
            conv.append(nn.ConstantPad2d(padding, 0))

        conv.append(nn.Conv2d(channels, channels, kernel_size=(1, kernel_size)))
        # TODO: load weights
        if not right_pad:
            conv.append(SamePad(conv[-1]))
        conv.append(nn.ReLU())
        if dropout > 0:
            conv.append(nn.Dropout(dropout))

        linear_dim = channels * width
        if inner_linear_dim == 0:
            inner_linear_dim = linear_dim

        fc = []
        fc.append(Reorder(*af_to_torch_dims(2, 1, 0, 3)))
        fc.append(View(af_to_torch_shape(linear_dim, -1, 1, 0)))

        fc.append(nn.Linear(linear_dim, inner_linear_dim))
        fc.append(nn.ReLU())
        if dropout > 0:
            fc.append(nn.Dropout(dropout))
        fc.append(nn.Linear(inner_linear_dim, linear_dim))
        fc.append(nn.ReLU())

        fc.append(View(af_to_torch_shape(channels, width, -1, 0)))
        fc.append(Reorder(*af_to_torch_dims(2, 1, 0, 3)))
        if dropout > 0:
            fc.append(nn.Dropout(dropout))

        layers += conv
        # FIXME:
        if layer_norm_include_time:
            pass # TODO: nn.LayerNorm()
        else:
            pass # TODO: nn.LayerNorm()
        layers += fc
        if layer_norm_include_time:
            pass # TODO: nn.LayerNorm()
        else:
            pass # TODO: nn.LayerNorm()
        return layers

I'm working with this arch:

V -1 60 1 0
SAUG 60 18 2 100 0.05 2
PD 0 5 3
C 1 15 10 2 0 1
R
DO 0.100000
LN  1 2 
TDS 15 9 60 0.100000 0 1 0
TDS 15 9 60 0.100000 0 1 0
PD 0 7 1
C 15 19 10 2 0 1
R
DO 0.100000
LN  1 2 
TDS 19 9 60 0.100000 0 1 0
TDS 19 9 60 0.100000 0 1 0
TDS 19 9 60 0.100000 0 1 0
PD 0 9 1
C 19 23 12 2 0 1
R
DO 0.100000
LN  1 2 
TDS 23 11 60 0.100000 0 1 0
TDS 23 11 60 0.100000 0 1 0
TDS 23 11 60 0.100000 0 1 0
TDS 23 11 60 0.100000 0 0 0
PD 0 10 0
C 23 27 11 1 0 1
R
DO 0.100000
LN  1 2 
TDS 27 11 60 0.100000 0 0 0
TDS 27 11 60 0.100000 0 0 0
TDS 27 11 60 0.100000 0 0 0
TDS 27 11 60 0.100000 0 0 0
TDS 27 11 60 0.100000 0 0 0
RO 2 1 0 3
V 1620 -1 1 0
L 1620 4998
V 4998 0 -1 1

These are the first few pytorch model layers:

Sequential(
  (0): View([0, 1, 60, -1])
  (1): ConstantPad1d(padding=[5, 3], value=0)
  (2): Conv2d(1, 15, kernel_size=(1, 10), stride=(1, 2), padding=(1, 0))
  (3): ReLU()
  (4): Dropout(p=0.1, inplace=False)
  (5): ConstantPad2d(padding=(7, 1, 0, 0), value=0)
  (6): Conv2d(15, 15, kernel_size=(1, 9), stride=(1, 1))
  (7): ReLU()
  (8): Dropout(p=0.1, inplace=False)
  (9): Reorder(1, 2, 3, 0)
  (10): View([0, 1, -1, 900])
  (11): Linear(in_features=900, out_features=900, bias=True)
  (12): ReLU()
  (13): Dropout(p=0.1, inplace=False)
  (14): Linear(in_features=900, out_features=900, bias=True)
  (15): ReLU()
  (16): View([0, -1, 60, 15])
  (17): Reorder(1, 2, 3, 0)
  (18): Dropout(p=0.1, inplace=False)
  (19): ConstantPad2d(padding=(7, 1, 0, 0), value=0)

Here's a forward pass:

[+] running forward pass
    input:  torch.Size([1, 60, 198])
[-] layer View([0, 1, 60, -1])
    output: torch.Size([1, 1, 60, 198])
[-] layer ConstantPad1d(padding=[5, 3], value=0)
    output: torch.Size([1, 1, 60, 206])
[-] layer Conv2d(1, 15, kernel_size=(1, 10), stride=(1, 2), padding=(1, 0))
    output: torch.Size([1, 15, 62, 99])
[-] layer ReLU()
    output: torch.Size([1, 15, 62, 99])
[-] layer Dropout(p=0.1, inplace=False)
    output: torch.Size([1, 15, 62, 99])
[-] layer ConstantPad2d(padding=(7, 1, 0, 0), value=0)
    output: torch.Size([1, 15, 62, 107])
[-] layer Conv2d(15, 15, kernel_size=(1, 9), stride=(1, 1))
    output: torch.Size([1, 15, 62, 99])
[-] layer ReLU()
    output: torch.Size([1, 15, 62, 99])
[-] layer Dropout(p=0.1, inplace=False)
    output: torch.Size([1, 15, 62, 99])
[-] layer Reorder(1, 2, 3, 0)
    output: torch.Size([15, 62, 99, 1])
[-] layer View([0, 1, -1, 900])

Traceback (most recent call last):
  File "model.py", line 385, in <module>
    emissions = w2l.forward(frames)
  File "model.py", line 305, in forward
    input = layer(input)
  File "/Users/aegis/Library/Python/3.7/lib/python/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "model.py", line 21, in forward
    return input.view(*shape)
RuntimeError: shape '[15, 1, -1, 900]' is invalid for input of size 92070
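
A quick check of the arithmetic behind this error, as a minimal sketch. It uses the standard convolution output-size formula, and the helper name `conv_out` is only illustrative. The printed `Conv2d` has `padding=(1, 0)`, which in PyTorch pads the height axis (60) by one on each side, so the height grows to 62 and the element count is no longer a multiple of 900. Where that padding value comes from in the loader is not shown here.

```python
def conv_out(size: int, kernel: int, stride: int = 1, pad: int = 0) -> int:
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * pad - kernel) // stride + 1

# Shapes from the trace: the input to the first Conv2d is [1, 1, 60, 206].
h = conv_out(60, kernel=1, stride=1, pad=1)    # height axis, padding=(1, 0)
w = conv_out(206, kernel=10, stride=2, pad=0)  # time axis
channels, linear_dim = 15, 15 * 60

numel = channels * h * w
print(h, w, numel, numel % linear_dim)          # 62 99 92070 270

# Without the extra height padding the View would be valid:
h_ok = conv_out(60, kernel=1, stride=1, pad=0)
print(h_ok, channels * h_ok * w % linear_dim)   # 60 0
```

With a height of 60 the tensor holds 89100 elements, which is exactly 99 frames of 900 features.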

lunixbochs commented 4 years ago

These are the shapes from a TDS block in wav2letter

tds input shape [ 111 60 15]
fc input shape [ 111 60 15] 
Reorder (2,1,0,3)
    shape [ 15 60 111]
View (900 -1 1 0)
    shape [ 900 111]
Linear (900->900) (with bias)
    shape [ 900 111]
ReLU
    shape [ 900 111]
Dropout (0.100000)
    shape [ 900 111]
Linear (900->900) (with bias)
    shape [ 900 111]
View (15 60 -1 0)
    shape [ 15 60 111]
Reorder (2,1,0,3)
    shape [ 111 60 15] 
Dropout (0.100000)
    shape [ 111 60 15]
    tds output shape [ 111 60 15]

vs my pytorch forward pass (on different audio)

[+] running forward pass
    input:  torch.Size([1, 60, 198])
[-] layer View([0, 1, 60, -1])
    output: torch.Size([1, 1, 60, 198])
[-] layer ConstantPad1d(padding=[5, 3], value=0)
    output: torch.Size([1, 1, 60, 206])
[-] layer Conv2d(1, 15, kernel_size=(1, 10), stride=(1, 2), padding=(1, 0))
    output: torch.Size([1, 15, 62, 99])
[-] layer ReLU()
    output: torch.Size([1, 15, 62, 99])
[-] layer Dropout(p=0.1, inplace=False)
    output: torch.Size([1, 15, 62, 99])
[-] layer ConstantPad2d(padding=(7, 1, 0, 0), value=0)
    output: torch.Size([1, 15, 62, 107])
[-] layer Conv2d(15, 15, kernel_size=(1, 9), stride=(1, 1))
    output: torch.Size([1, 15, 62, 99])
[-] layer ReLU()
    output: torch.Size([1, 15, 62, 99])
[-] layer Dropout(p=0.1, inplace=False)
    output: torch.Size([1, 15, 62, 99])
[-] layer Reorder(1, 2, 3, 0)
    output: torch.Size([15, 62, 99, 1])
[-] layer View([0, 1, -1, 900])

Traceback (most recent call last):
  File "model.py", line 385, in <module>
    emissions = w2l.forward(frames)
  File "model.py", line 305, in forward
    input = layer(input)
  File "/Users/aegis/Library/Python/3.7/lib/python/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "model.py", line 21, in forward
    return input.view(*shape)
RuntimeError: shape '[15, 1, -1, 900]' is invalid for input of size 92070

tlikhomanenko commented 4 years ago

Hey @lunixbochs,

  1. Dimension ordering:

For conv GLU layers with weight norm, I think this will be helpful: https://github.com/facebookresearch/wav2letter/tree/master/recipes/models/utilities/convlm_serializer (this is how we imported a trained fairseq conv GLU model into a w2l binary). See for example https://github.com/facebookresearch/wav2letter/blob/master/recipes/models/utilities/convlm_serializer/Utils.cpp#L107 and https://github.com/facebookresearch/wav2letter/blob/master/recipes/models/utilities/convlm_serializer/Utils.cpp#L114.

The thing to keep in mind is that a saved af::array is column-major. So if you have a tensor in pytorch with shape HxWxCxN in row-major ordering, then it will be loaded into arrayfire as an NxCxWxH tensor, and vice versa. Let me know if this is what you need to solve the issue.

  2. LayerNorm

LN 0 1 2 computes the normalization along all axes except batch. The thing is that pytorch uses a different ordering of input tensors for the same operations. For example, you can implement a linear layer in two ways: y = Ax or y = xA. This is one of the differences between pytorch and flashlight, so you should check the input tensor layout for each operation (linear, conv, etc.). You can keep pytorch tensors and do the same operations, and the results will be the same; you just need to transpose/reorder things properly so that the computations match. For example, to apply layer norm in TDS (I guess in pytorch it is done over the last axes for efficiency), you just need to call pytorch layer norm with input in BxHxWxC format (or any permutation of the last 3 axes) and normalize along the last 3 axes. This will be equivalent to the flashlight operation.

About the TDS block view operation: let me recheck this, I will come back to you on it soon.
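
The LayerNorm point can be sketched in PyTorch as follows. This is an illustration with made-up shapes, not the flashlight implementation: normalizing over all three non-batch axes gives the same values whichever order those axes are in, as long as the result is permuted back afterwards.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 15, 60, 7)          # N x C x W x T, illustrative sizes

# Normalize over the last three axes in the original order.
ref = F.layer_norm(x, x.shape[1:])

# Same normalization after permuting the non-batch axes to N x T x W x C.
xp = x.permute(0, 3, 2, 1).contiguous()
out = F.layer_norm(xp, xp.shape[1:]).permute(0, 3, 2, 1)

print(torch.allclose(ref, out, atol=1e-4))  # True
```

No affine parameters are passed here, so only the normalization itself is compared.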

lunixbochs commented 4 years ago

So for arrayfire TDS:

tds input shape [ 111 60 15]
fc input shape [ 111 60 15] 
Reorder (2,1,0,3)
    shape [ 15 60 111]
View (900 -1 1 0)
    shape [ 900 111]
Linear (900->900) (with bias)
    shape [ 900 111]
ReLU
    shape [ 900 111]
Dropout (0.100000)
    shape [ 900 111]
Linear (900->900) (with bias)
    shape [ 900 111]
View (15 60 -1 0)
    shape [ 15 60 111]
Reorder (2,1,0,3)
    shape [ 111 60 15] 
Dropout (0.100000)
    shape [ 111 60 15]
    tds output shape [ 111 60 15]

The weights are going in as HWCN, so

H=111 W=60 C=15 N=1
Reorder (2 1 0 3) swaps height and channel
H=15 W=60 C=111 N=1
View (900 -1 1 0) makes (H=900, C=1, N=N), and W = total / (H*C*N)
H=900 W=111 C=1 N=1
Then there's a Linear 900->inner->900 block which preserves the H size.
View (15 60 -1 0) makes (H=15 W=60 N=N) and C = total / (H*W*N)
H=15 W=60 C=111 N=1
Reorder (2,1,0,3) swaps height and channel again
H=111 W=60 C=15 N=1

and pytorch (NCHW) rough shape notes:

input: torch.Size([1, 15, 62, 99])
# at this point, H and W appear to be switched from flashlight
# also H is incorrectly 62 instead of 60 for some reason (I think accidental padding on previous layer)
N=1 C=15 H=62 W=99 
Reorder (1 2 3 0) # seems wrong, should be 0 2 1 3 to swap height and channel or 0 3 2 1 to swap W/C
N=1 C=99 H=62 W=15
View (0 1 -1 900) # H/W are swapped, should make (H=900, C=1, N=N), and W = total / (H*C*N)
N=1 C=1 H=900 W=~102
Linear 900 -> inner -> 900
View (0 -1 60 15) # H/W are swapped, should make (H=15, W=60, N=N), and C = total / (H*W*N)
N=1 C=99 H=60 W=15
Reorder (1, 2, 3, 0) # seems wrong, should be 0 2 1 3 to swap height and channel or 0 3 2 1 to swap W/C
N=1 C=15 H=60 W=99

Oh! Is row-major "width" in pytorch the same axis as column-major "height" in arrayfire? That would explain some of my confusion.
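
One way to check this is to model ArrayFire's column-major layout with NumPy's `order="F"`. This is a sketch under that assumption, not flashlight code, and `af_reorder_to_torch_permute` is a hypothetical helper. It suggests that an ArrayFire array with dims [T, W, C, N] shares its memory layout with a row-major tensor of shape [N, C, W, T], so ArrayFire's first (fastest-varying) axis lines up with PyTorch's last axis, and that `Reorder (2,1,0,3)` corresponds to `permute(0, 3, 2, 1)`.

```python
import numpy as np

T, W, C, N = 5, 60, 15, 1   # small T for the demo; W and C from the TDS block

flat = np.arange(T * W * C * N)
af = flat.reshape((T, W, C, N), order="F")   # column-major, ArrayFire-style dims
pt = flat.reshape((N, C, W, T))              # row-major, PyTorch-style shape

# Same buffer, axes listed in reverse order.
assert np.array_equal(pt, af.transpose(3, 2, 1, 0))

def af_reorder_to_torch_permute(*dims):
    """Hypothetical helper: map ArrayFire reorder dims to a row-major permute."""
    n = len(dims)
    return tuple(n - 1 - d for d in reversed(dims))

perm = af_reorder_to_torch_permute(2, 1, 0, 3)
print(perm)  # (0, 3, 2, 1)

af_out = af.transpose(2, 1, 0, 3)            # Reorder (2,1,0,3): dims [C, W, T, N]
pt_out = pt.transpose(perm)                  # shape (N, T, W, C)
assert np.array_equal(pt_out, af_out.transpose(3, 2, 1, 0))

# View (900 -1 1 0) in column-major order vs. the reversed shape in row-major.
af_view = af_out.reshape((C * W, -1, 1, N), order="F")
pt_view = pt_out.reshape((N, 1, -1, C * W))
assert np.array_equal(pt_view, af_view.transpose(3, 2, 1, 0))
print(pt_view.shape)  # (1, 1, 5, 900)
```

Under this model, both the dims of a reorder and the shape of a view need to be reversed when moving between the two frameworks, and the reorder indices also need to be flipped with `3 - d`.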

lunixbochs commented 4 years ago

Ok, I got LayerNorm working, however the fully connected section of TDS is broken with my current import strategy.

The input shape into this TDS block is torch.Size([1, 15, 60, 187]). This is what a TDS block looks like:

TDS 15 9 60 0.100000 0 1 0

Sequential(
  (0): Sequential(
    (0): ConstantPad2d(padding=(7, 1, 0, 0), value=0)
    (1): Conv2d(15, 15, kernel_size=(1, 9), stride=(1, 1))
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
  )
  (1): InnerLayerNorm()
  (2): Sequential(
    (0): Reorder(0, 3, 1, 2)
    (1): View([0, 1, -1, 900])
    (2): Linear(in_features=900, out_features=900, bias=True)
    (3): ReLU()
    (4): Dropout(p=0.1, inplace=False)
    (5): Linear(in_features=900, out_features=900, bias=True)
    (6): ReLU()
    (7): View([0, -1, 15, 60])
    (8): Reorder(0, 2, 3, 1)
    (9): Dropout(p=0.1, inplace=False)
  )
  (3): InnerLayerNorm()
)

My activations match wav2letter until the FC block inside the TDS (layer 2 in the outer sequential shown here), after which they diverge, so I'd guess my reorder/view step is wrong.
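
For reference, here is one way the reorder/view pair around the FC layers could be written for an N x C x W x T tensor. It is a sketch that assumes flashlight flattens channel fastest, then width (column-major [C, W]), which in row-major terms means permuting to N x T x W x C before flattening. The `Reorder` module below is a hypothetical stand-in for the loader's own class, and the weights are random.

```python
import torch
import torch.nn as nn

class Reorder(nn.Module):
    """Hypothetical stand-in: permute the axes of the input."""
    def __init__(self, *dims: int):
        super().__init__()
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.permute(*self.dims)

channels, width, frames = 15, 60, 187
linear_dim = channels * width
x = torch.randn(1, channels, width, frames)        # N x C x W x T

to_fc = Reorder(0, 3, 2, 1)                         # -> N x T x W x C
flat = to_fc(x).reshape(x.shape[0], 1, -1, linear_dim)
print(flat.shape)                                   # torch.Size([1, 1, 187, 900])

fc = nn.Sequential(nn.Linear(linear_dim, linear_dim), nn.ReLU(),
                   nn.Linear(linear_dim, linear_dim))
y = fc(flat)

# Undo the view and the reorder; applied to `flat` itself this is an exact round trip.
from_fc = Reorder(0, 3, 2, 1)
back = from_fc(flat.reshape(x.shape[0], -1, width, channels))
print(torch.equal(back, x))                         # True
out = from_fc(y.reshape(x.shape[0], -1, width, channels))
print(out.shape)                                    # torch.Size([1, 15, 60, 187])
```

`reshape` is used instead of `view` because the result of `permute` is not contiguous, and `view` can raise on such tensors.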

lunixbochs commented 4 years ago

I got the Linear to work by transposing the weights, but LayerNorm is definitely broken. The first couple of LayerNorms seem to match, using this as the module forward pass:

        tmp = input.permute(0, 3, 2, 1)
        shape = tmp.shape[2:]
        tmp = F.layer_norm(tmp, shape, self.weight.expand(shape), self.bias.expand(shape), self.eps)
        return tmp.permute(0, 3, 2, 1)

But after the first TDS finishes, the LayerNorm at the start of the next TDS block stops matching wav2letter. I implemented a basic LN 1 2 by hand and it doesn't seem to help:

class DimLayerNorm(nn.Module):
    def __init__(self, *dims: int, eps=1e-5):
        super().__init__()
        self.dims = tuple(dims)
        self.eps = eps
        self.weight = nn.Parameter(torch.Tensor(1))
        self.bias   = nn.Parameter(torch.Tensor(1))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input):
        print('ln', input.shape, self.dims, self.weight, self.bias)
        mean = input.mean(self.dims, keepdim=True)
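        # biased variance (divide by N), as F.layer_norm does; Tensor.var defaults to unbiased
        var = input.var(self.dims, unbiased=False, keepdim=True)
        out = (input - mean) / (var + self.eps).sqrt()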
        return out * self.weight + self.bias

    def __repr__(self):
        return 'DimLayerNorm({})'.format(self.dims)

Do you have insight into what the flashlight layernorm (which calls BatchNorm and MKL internally) might be doing differently from this? Or is this close enough I'm probably failing to notice a problem somewhere else?

tlikhomanenko commented 4 years ago

> But after the first TDS finishes, the LayerNorm at the start of the next TDS block stops matching wav2letter. I implemented a basic LN 1 2 by hand and it doesn't seem to help:

Ryan, could you post the tensor sizes you have in w2l and in pytorch, both before and at the point where they stop matching?

For example in the sota models we do LN 0 1 2 https://github.com/facebookresearch/wav2letter/blob/master/recipes/models/sota/2019/am_arch/am_tds_s2s.arch#L6.

Your by-hand implementation looks good to me. cc @vineelpratap to recheck this.

lunixbochs commented 4 years ago

I've been doing more investigation, and I think the issue is somewhere other than LayerNorm, given that my hand-rolled version roughly matches. I think LayerNorm averaging the activations may be spreading a divergence from somewhere in the middle out to the edges, where it's more visible, which is why I thought the LayerNorm was causing it. My next step is to diff the activations layer by layer between w2l and pytorch and see where they diverge, so I don't expect any help until I have that done.
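
A layer-by-layer diff like this can be sketched with a small helper. It assumes the activations from both frameworks have already been dumped as arrays in the same axis order; the function name, tolerance, and toy data are illustrative only.

```python
import numpy as np

def first_divergence(ref_acts, test_acts, atol=1e-4):
    """Return (index, max_abs_diff) of the first mismatching layer, or None."""
    for i, (ref, test) in enumerate(zip(ref_acts, test_acts)):
        if ref.shape != test.shape:
            return i, float("inf")
        diff = float(np.max(np.abs(ref - test)))
        if diff > atol:
            return i, diff
    return None

# Toy data: three "layers" where the third one is off by 0.5 in one element.
ref_acts = [np.zeros((2, 3)), np.ones((2, 3)), np.full((2, 3), 2.0)]
test_acts = [a.copy() for a in ref_acts]
test_acts[2][0, 0] += 0.5

print(first_divergence(ref_acts, test_acts))  # (2, 0.5)
```

Reporting the first mismatching layer, rather than the final output, avoids chasing a layer that only makes an earlier divergence more visible.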

vineelpratap commented 4 years ago

@lunixbochs - Here is a PyTorch implementation of TDS that you can refer to https://gist.github.com/vineelpratap/e9c030d488c5f2b804215c547d573932

Regarding LayerNorm, there are two things that you might want to take note of:

  1. w2l uses scalar values for the affine transformation.
  2. For streaming inference, we perform normalization only on the feature axis (see Section 3.1.2 in https://arxiv.org/pdf/2001.09727.pdf).
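
Both notes can be illustrated with a short PyTorch sketch. It is not the w2l code: the function name, the shapes, and the choice of which axes count as the feature axes are assumptions made for the example. The affine step uses a single scalar weight and bias instead of per-element parameters, and the axes to normalize over are passed explicitly.

```python
import torch
import torch.nn.functional as F

def scalar_affine_layer_norm(x, dims, weight, bias, eps=1e-5):
    """Normalize over `dims` (biased variance), then apply a scalar affine."""
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias

torch.manual_seed(0)
x = torch.randn(1, 15, 60, 7)                 # N x C x W x T (illustrative)
weight, bias = torch.tensor(1.5), torch.tensor(0.25)

# Normalizing over C, W, T matches F.layer_norm followed by the scalar affine.
full = scalar_affine_layer_norm(x, (1, 2, 3), weight, bias)
ref = F.layer_norm(x, x.shape[1:]) * weight + bias
print(torch.allclose(full, ref, atol=1e-4))   # True

# Normalizing only over C and W (assumed here to be the feature axes)
# leaves every time step with its own statistics.
per_frame = scalar_affine_layer_norm(x, (1, 2), weight, bias)
print(per_frame.shape)                        # torch.Size([1, 15, 60, 7])
```

The second call is meant to resemble the `LN 1 2` lines in the arch above, on the assumption that dims 1 and 2 there are width and channel.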