tky823 / DNN-based_source_separation

A PyTorch implementation of DNN-based source separation.

change the meaning of the parameter `H` #24

Closed: tky823 closed this issue 3 years ago

tky823 commented 3 years ago

Until now, `H` has meant the total number of hidden channels in `nn.LSTM`, summed over both directions, so each direction gets `H//num_directions` channels.

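import torch.nn as nn

class Model(nn.Module):
    def __init__(self, in_channels, hidden_channels, causal=False):
        super().__init__()

        # A causal model must not look ahead in time, so it uses a unidirectional LSTM
        if causal:
            num_directions = 1
            bidirectional = False
        else:
            num_directions = 2
            bidirectional = True

        # hidden_channels is the total over both directions, split evenly between them
        self.rnn = nn.LSTM(in_channels, hidden_channels//num_directions, bidirectional=bidirectional)
        self.fc = nn.Linear(hidden_channels, hidden_channels)

    def forward(self, input):
        # input: (T, B, in_channels), since nn.LSTM defaults to batch_first=False
        x, (_, _) = self.rnn(input)
        output = self.fc(x)

        return output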
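A minimal sketch of the size bookkeeping under the old meaning of `H`, written in plain Python so it runs without PyTorch. The helper name `old_sizes` is hypothetical; it only mirrors `hidden_channels//num_directions` for the LSTM and `nn.Linear(hidden_channels, hidden_channels)` for `self.fc`. It illustrates one drawback of the old convention: for a bidirectional model with an odd `H`, the integer division drops a channel, so the LSTM emits `H - 1` features while `self.fc` expects `H`.

```python
from typing import Tuple


def old_sizes(hidden_channels: int, bidirectional: bool) -> Tuple[int, int]:
    """Return (LSTM output features, fc in_features) under the old meaning of H.

    Hypothetical helper: mirrors nn.LSTM(in_channels, hidden_channels // num_directions, ...)
    followed by nn.Linear(hidden_channels, hidden_channels).
    """
    num_directions = 2 if bidirectional else 1
    per_direction = hidden_channels // num_directions  # drops one channel when H is odd and bidirectional
    lstm_out = num_directions * per_direction          # nn.LSTM concatenates both directions
    fc_in = hidden_channels                            # what self.fc expects as input
    return lstm_out, fc_in


for H in (256, 255):
    lstm_out, fc_in = old_sizes(H, bidirectional=True)
    print(H, lstm_out, fc_in, "ok" if lstm_out == fc_in else "mismatch")
# 256 256 256 ok
# 255 254 255 mismatch
```

With the real modules, the second case would fail at `self.fc(x)` with a shape mismatch, whereas a unidirectional LSTM never hits this problem.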

From now on, `H` means the number of hidden channels in each direction, so the LSTM output has `num_directions*H` channels.

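import torch.nn as nn

class Model(nn.Module):
    def __init__(self, in_channels, hidden_channels, causal=False):
        super().__init__()

        # A causal model must not look ahead in time, so it uses a unidirectional LSTM
        if causal:
            num_directions = 1
            bidirectional = False
        else:
            num_directions = 2
            bidirectional = True

        # hidden_channels is now the size of each direction
        self.rnn = nn.LSTM(in_channels, hidden_channels, bidirectional=bidirectional)
        self.fc = nn.Linear(num_directions*hidden_channels, hidden_channels)

    def forward(self, input):
        # input: (T, B, in_channels), since nn.LSTM defaults to batch_first=False
        x, (_, _) = self.rnn(input)  # x: (T, B, num_directions*hidden_channels)
        output = self.fc(x)          # output: (T, B, hidden_channels)

        return output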
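To see what the change means for model size, here is a plain-Python parameter count that needs no PyTorch. It assumes a single-layer `nn.LSTM` with the default `bias=True`, whose parameters per direction are `weight_ih` (4H × in), `weight_hh` (4H × H), `bias_ih` and `bias_hh` (4H each), plus an `nn.Linear` with bias. It also assumes `causal=True` means a unidirectional LSTM. The function names and the `new_convention` flag are hypothetical, introduced only for this comparison.

```python
def lstm_params(in_channels, hidden_per_direction, num_directions):
    # Per direction: weight_ih (4H x in) + weight_hh (4H x H) + bias_ih (4H) + bias_hh (4H)
    h = hidden_per_direction
    return num_directions * (4 * h * (in_channels + h) + 8 * h)


def linear_params(in_features, out_features):
    # weight (out x in) + bias (out)
    return in_features * out_features + out_features


def model_params(in_channels, hidden_channels, causal, new_convention):
    # Assumption: a causal model uses a unidirectional LSTM
    num_directions = 1 if causal else 2
    if new_convention:
        per_direction = hidden_channels                     # H is the size of each direction
        fc_in = num_directions * hidden_channels
    else:
        per_direction = hidden_channels // num_directions   # H is the total over both directions
        fc_in = hidden_channels
    return lstm_params(in_channels, per_direction, num_directions) + linear_params(fc_in, hidden_channels)


print(model_params(64, 128, causal=False, new_convention=False))  # → 83072
print(model_params(64, 128, causal=False, new_convention=True))   # → 231552
print(model_params(64, 128, causal=True, new_convention=True))    # → 115840
```

For a causal (unidirectional) model both conventions give the same network; for a bidirectional one, the same value of `H` now yields twice the per-direction width and a noticeably larger model.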

Here, `H` is passed as `hidden_channels` (and `C` as `in_channels`):

    model = Model(C, H, causal=False)
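A minimal shape trace for the call above under the new meaning of `H`, in plain Python so it runs without PyTorch. The helper `trace_shapes` and the sizes `T`, `B` are hypothetical. It assumes `nn.LSTM`'s default `batch_first=False`, so tensors are laid out as `(T, B, features)`, and that `causal=False` means a bidirectional LSTM.

```python
def trace_shapes(T, B, C, H, causal=False):
    """Shapes through Model(C, H, causal) under the new meaning of H (hypothetical helper)."""
    num_directions = 1 if causal else 2          # assumption: causal -> unidirectional
    input_shape = (T, B, C)
    lstm_shape = (T, B, num_directions * H)      # both directions concatenated on the last axis
    output_shape = (T, B, H)                     # self.fc maps num_directions*H -> H
    return input_shape, lstm_shape, output_shape


C, H = 64, 128
for name, shape in zip(("input", "lstm", "output"), trace_shapes(T=100, B=4, C=C, H=H, causal=False)):
    print(name, shape)
# input (100, 4, 64)
# lstm (100, 4, 256)
# output (100, 4, 128)
```

The model's output width is still `H`; only the intermediate LSTM output doubles when the model is bidirectional.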