asteroid-team / asteroid-filterbanks

Asteroid's filterbanks :rocket:
https://asteroid-team.github.io/
MIT License

Add tests for STFTFB against torch.stft for #1

Closed faroit closed 3 years ago

faroit commented 3 years ago

For models that have already been trained using torch.stft, it would be nice if they could swap it for STFTFB.

faroit commented 3 years ago

@Baldwin-disso worked on this already. Can you share your progress here so we can continue to work on this together?

mpariente commented 3 years ago

That would be great indeed. I think we can use the new hooks for that!

mpariente commented 3 years ago

I guess these are the interesting snippets:

stft

Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
            const optional<int64_t> win_lengthOpt, const Tensor& window,
            const bool normalized, const bool onesided) {
...
  auto window_ = window;
  if (win_length < n_fft) {
    // pad center
    window_ = at::zeros({n_fft}, self.options());
    auto left = (n_fft - win_length) / 2;
    if (window.defined()) {
      window_.narrow(0, left, win_length).copy_(window);
    } else {
      window_.narrow(0, left, win_length).fill_(1);
    }
  }
  int64_t n_frames = 1 + (len - n_fft) / hop_length;
  // time2col
  input = input.as_strided(
    {batch, n_frames, n_fft},
    {input.stride(0), hop_length * input.stride(1), input.stride(1)}
  );
  if (window_.defined()) {
    input = input.mul(window_);
  }
  // rfft and transpose to get (batch x fft_size x num_frames)
  auto out = input.rfft(1, normalized, onesided).transpose_(1, 2);
  if (self.dim() == 1) {
    return out.squeeze_(0);
  } else {
    return out;
  }
}

where we need to match this line input.rfft(1, normalized, onesided), which shouldn't be too hard.
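To make the steps in the stft snippet concrete, here is a minimal pure-Python sketch of the same pipeline: center-pad the window, frame the signal, multiply by the window, and take a one-sided DFT. The names `pad_window_center` and `naive_stft` are hypothetical helpers, not part of PyTorch or asteroid-filterbanks. The `1 / sqrt(n_fft)` scaling for `normalized` is an assumption about what the flag does, and the output is laid out as frames first, without the final transpose in the snippet.

```python
import cmath
import math


def pad_window_center(window, n_fft):
    """Zero-pad a short window to length n_fft, centered (left pad is floored)."""
    left = (n_fft - len(window)) // 2
    right = n_fft - len(window) - left
    return [0.0] * left + list(window) + [0.0] * right


def naive_stft(signal, n_fft, hop_length, window=None, normalized=False):
    """Return spec[t][k]: one-sided DFT of each windowed frame (no centering)."""
    if window is None:
        window = [1.0] * n_fft  # rectangular window, i.e. no windowing
    window = pad_window_center(window, n_fft)
    n_frames = 1 + (len(signal) - n_fft) // hop_length
    # Assumption: `normalized` scales each frame's spectrum by 1 / sqrt(n_fft).
    scale = 1.0 / math.sqrt(n_fft) if normalized else 1.0
    spec = []
    for t in range(n_frames):
        start = t * hop_length
        frame = [signal[start + n] * window[n] for n in range(n_fft)]
        # One-sided spectrum: bins 0 .. n_fft // 2 inclusive.
        bins = [
            scale * sum(
                frame[n] * cmath.exp(-2j * cmath.pi * k * n / n_fft)
                for n in range(n_fft)
            )
            for k in range(n_fft // 2 + 1)
        ]
        spec.append(bins)
    return spec


# Usage: a constant signal only has energy in the DC bin of every frame.
spec = naive_stft([1.0] * 16, n_fft=8, hop_length=4)
```

A test of STFTFB would presumably check its output against this kind of reference (or against torch.stft directly), up to the layout of real and imaginary parts.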

and istft

Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
             const optional<int64_t> win_lengthOpt, const Tensor& window,
             const bool center, const bool normalized, const bool onesided,
             const optional<int64_t> lengthOpt) {
...
  Tensor window_tmp = window.defined() ? window : at::ones({win_length,}, options);
  if (win_length != n_fft) {
    // center window by padding zeros on right and left side
    int64_t left = (n_fft - win_length) / 2;
    window_tmp = at::constant_pad_nd(window_tmp, {left, n_fft - win_length - left}, 0);
    TORCH_INTERNAL_ASSERT(window_tmp.size(0) == n_fft);
  }

  Tensor input = self;
  if (input_dim == 3) {
    input = input.unsqueeze(0);
  }

  input = input.transpose(1, 2);  // size: (channel, n_frames, fft_size, 2)
  input = at::native::irfft(input, 1, normalized, onesided, {n_fft, });  // size: (channel, n_frames, n_fft)
  TORCH_INTERNAL_ASSERT(input.size(2) == n_fft);

  Tensor y_tmp = input * window_tmp.view({1, 1, n_fft});  // size: (channel, n_frames, n_fft)
  y_tmp = y_tmp.transpose(1, 2);  // size: (channel, n_fft, frame)

  const Tensor eye = at::native::eye(n_fft, options).unsqueeze(1);
  Tensor y = at::conv_transpose1d(y_tmp, eye,
                                  /*bias*/ Tensor(),
                                  /*stride*/ {hop_length,},
                                  /*padding*/{0,});  // size: (channel, n_frames, n_fft)
  window_tmp = window_tmp.pow(2).view({n_fft, 1}).repeat({1, n_frames}).unsqueeze(0);  // size: (1, n_fft, n_frames)
  Tensor window_envelop = at::conv_transpose1d(window_tmp, eye,
                                               /*bias*/ Tensor(),
                                               /*stride*/ {hop_length, },
                                               /*padding*/{0, });  // size: (1, 1, expected_output_signal_len)
  TORCH_INTERNAL_ASSERT(expected_output_signal_len == y.size(2));
  TORCH_INTERNAL_ASSERT(expected_output_signal_len == window_envelop.size(2));

  // We need to trim the front padding away if centered
  const auto start = center ? n_fft / 2 : 0;
  const auto end = lengthOpt.has_value()? start + lengthOpt.value() : - n_fft / 2;

  y = y.slice(2, start, end, 1);
  window_envelop = window_envelop.slice(2, start, end, 1);
  const auto window_envelop_lowest = window_envelop.abs().min().item().toDouble();
  if (window_envelop_lowest < 1e-11) {
    std::ostringstream ss;
    REPR(ss) << "window overlap add min: " << window_envelop_lowest;
    AT_ERROR(ss.str());
  }

  y = (y / window_envelop).squeeze(1);  // size: (channel, expected_output_signal_len)
  if (input_dim == 3) {
    y = y.squeeze(0);
  }
  return y;

where more things are happening, in particular my favorite bit: the division by the overlap-added (OLA) squared window :sweat_smile:
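The normalization at the end of the istft snippet can be sketched in pure Python as follows. This is only an illustration under stated assumptions: the frames are taken to be already back in the time domain (the irfft step is skipped), the same window is used for analysis and synthesis, and `frame_signal` and `overlap_add_normalized` are hypothetical helper names. Each output sample is the overlap-added windowed frames divided by the overlap-added squared window, which undoes the two window multiplications.

```python
import math


def frame_signal(signal, window, hop_length):
    """Cut the signal into overlapping frames and apply the analysis window."""
    n_fft = len(window)
    n_frames = 1 + (len(signal) - n_fft) // hop_length
    return [
        [signal[t * hop_length + n] * window[n] for n in range(n_fft)]
        for t in range(n_frames)
    ]


def overlap_add_normalized(frames, window, hop_length, eps=1e-11):
    """Overlap-add windowed frames, then divide by the OLA'd squared window."""
    n_fft = len(window)
    length = n_fft + hop_length * (len(frames) - 1)
    y = [0.0] * length
    envelope = [0.0] * length
    for t, frame in enumerate(frames):
        for n in range(n_fft):
            y[t * hop_length + n] += frame[n] * window[n]  # synthesis window
            envelope[t * hop_length + n] += window[n] ** 2
    # Same kind of guard as in the snippet: refuse to divide by ~0.
    if min(abs(e) for e in envelope) < eps:
        raise ValueError("window overlap-add envelope is (almost) zero")
    return [v / e for v, e in zip(y, envelope)]


# Usage: a periodic Hamming window is nonzero everywhere, so the envelope
# stays positive even at the edges of an uncentered signal.
n_fft, hop = 8, 2
window = [0.54 - 0.46 * math.cos(2 * math.pi * n / n_fft) for n in range(n_fft)]
signal = [math.sin(0.3 * n) for n in range(32)]
rec = overlap_add_normalized(frame_signal(signal, window, hop), window, hop)
max_err = max(abs(a - b) for a, b in zip(signal, rec))
```

With a window that starts at zero (a periodic Hann, for example) and no centering, the first sample has a zero envelope, which is the case the error check guards against.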

mpariente commented 3 years ago

BTW, this is the old version, from back in July (in case it matters, the commit I'm looking at is 9a3e16c773496b16e6c02f6e3e020be5bb485ea0). But I'm pretty sure today's version is compatible with the old one, and is easier to read.