jishengpeng / WavTokenizer

SOTA discrete acoustic codec models with 40 tokens per second for audio language modeling
MIT License

Question about streaming inference #56

Open VJJJJJJ1 opened 6 days ago

VJJJJJJ1 commented 6 days ago

Hi, I am trying to implement a streaming version of WavTokenizer. I set causal = True in the encoder with no other changes, and replaced every nn.Conv1d in the decoder with SConv1d. For example, in WavTokenizer/decoder/modules.py, I changed self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) to self.dwconv = SConv1d(dim, dim, kernel_size=7, groups=dim, causal=True). In the AttenBlock, after multiplying q and k, I apply a causal mask as follows:

# compute attention
b, c, h = q.shape        # batch, channels, time steps
q = q.permute(0, 2, 1)   # (b, h, c)
w_ = torch.bmm(q, k)     # (b, h, h): w_[b,i,j] = sum_c q[b,i,c] * k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))

# Apply causal mask: query i may attend only to keys j <= i
mask = torch.tril(torch.ones(h, h, dtype=torch.bool, device=w_.device))
w_ = w_.masked_fill(~mask, float('-inf'))  # block future positions before softmax
w_ = torch.nn.functional.softmax(w_, dim=2)
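One way to sanity-check the masking step in isolation is to perturb only the future frames of the input and confirm that the outputs for earlier frames do not move. The following is a minimal sketch under stated assumptions, not the repository's code: causal_self_attention and max_change_before are hypothetical helper names, and q, k and v are taken to be the same tensor, with no projection or normalization layers. It covers the masking step only; any other layer in the block that pools over the time axis would need the same kind of check.

```python
import torch


def causal_self_attention(q, k, v):
    """Masked attention on (batch, channels, time) tensors; returns the same shape."""
    b, c, t = q.shape
    # scores[b, i, j] = <q[b, :, i], k[b, :, j]> / sqrt(c)
    scores = torch.bmm(q.permute(0, 2, 1), k) * (c ** -0.5)
    # True above the diagonal, i.e. where key j lies in the future of query i
    future = torch.triu(torch.ones(t, t, dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    weights = torch.softmax(scores, dim=2)
    # out[b, :, i] = sum_j v[b, :, j] * weights[b, i, j]
    return torch.bmm(v, weights.permute(0, 2, 1))


def max_change_before(fn, x, t0):
    """Perturb frames >= t0 and return the largest output change on frames < t0."""
    x_perturbed = x.clone()
    x_perturbed[..., t0:] += torch.randn_like(x_perturbed[..., t0:])
    return (fn(x) - fn(x_perturbed))[..., :t0].abs().max().item()


torch.manual_seed(0)
x = torch.randn(2, 8, 16)
leak = max_change_before(lambda z: causal_self_attention(z, z, z), x, t0=10)
print(f"max change on past frames: {leak:.2e}")
```

A value at floating-point noise level suggests the mask is doing its job. A clearly nonzero value would mean that some frame still sees the future.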

Is this modification correct? Unfortunately, in my experiments the output audio is distorted toward the end.
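For the depthwise convolution change described above, the property to check is that the output at frame t depends only on inputs up to frame t. The sketch below compares the original symmetric padding (padding=3) with left-only padding of kernel_size - 1 frames. It is an illustration under assumptions: CausalDepthwiseConv1d is a hypothetical stand-in that uses zero padding, and it is not the repository's SConv1d, whose padding mode and defaults may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalDepthwiseConv1d(nn.Module):
    """Hypothetical causal depthwise conv (illustration only, not SConv1d)."""

    def __init__(self, dim, kernel_size=7):
        super().__init__()
        self.left_pad = kernel_size - 1  # pad only on the past side
        self.conv = nn.Conv1d(dim, dim, kernel_size, padding=0, groups=dim)

    def forward(self, x):  # x: (batch, dim, time)
        return self.conv(F.pad(x, (self.left_pad, 0)))


torch.manual_seed(0)
dim, t0 = 4, 20
symmetric = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
causal = CausalDepthwiseConv1d(dim, kernel_size=7)

x = torch.randn(1, dim, 32)
x_future = x.clone()
x_future[..., t0:] += 1.0  # change only frames >= t0

with torch.no_grad():
    # Largest change on frames < t0 caused by editing frames >= t0
    leak_symmetric = (symmetric(x) - symmetric(x_future))[..., :t0].abs().max().item()
    leak_causal = (causal(x) - causal(x_future))[..., :t0].abs().max().item()
    same_length = causal(x).shape == x.shape

print(f"symmetric padding leak: {leak_symmetric:.3e}")
print(f"causal padding leak:    {leak_causal:.3e}")
```

With symmetric padding, the last three frames before t0 change because the kernel reaches three frames ahead. With left-only padding they should not change. Running the same perturbation test on the full modified decoder, layer by layer, is one way to find a component that still looks ahead.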

Thank you for your reply!

keepingitneil commented 6 days ago

@VJJJJJJ1 I'm working on the same thing - want to collaborate?