Neutone / neutone_sdk

Join the community on Discord for more discussions around Neutone! https://discord.gg/VHSMzb8Wqp
GNU Lesser General Public License v2.1

[cm] Conv1dGeneral implementation #64

Closed christhetree closed 7 months ago

christhetree commented 8 months ago

This PR adds a Conv1d implementation that supports causal convolutions and cached convolutions. Cached mode can be toggled on or off at any time via the method set_cached(). When not in cached mode and causal is False, it behaves identically to torch.nn.Conv1d. When causal is True, the input is padded on the left side only. When cached and not causal, the output is delayed by get_delay_samples() samples. It is also TorchScript compatible.
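
To illustrate the causal and cached behavior described above, here is a minimal sketch, not the code in this PR. The class name CachedCausalConv1d and its attributes are hypothetical; only set_cached() is borrowed from the description, and the sketch assumes stride 1, zero padding, and a cache that starts out as silence.

```python
import torch
import torch.nn.functional as F


class CachedCausalConv1d(torch.nn.Module):
    """Hypothetical illustration only; not the SDK's Conv1dGeneral."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        # The wrapped conv does no padding itself; padding is handled in forward().
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size)
        self.left_pad = kernel_size - 1
        self.cached = False
        self.cache = torch.zeros(1, in_channels, self.left_pad)

    def set_cached(self, cached: bool) -> None:
        self.cached = cached
        # Reset the cache to zeros so streaming starts from silence.
        self.cache = torch.zeros_like(self.cache)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.cached:
            # Prepend the tail of the previous block instead of zeros.
            x = torch.cat([self.cache.expand(x.shape[0], -1, -1), x], dim=-1)
            # Keep the last `left_pad` samples for the next call.
            self.cache = x[..., x.shape[-1] - self.left_pad:].detach()
        else:
            # Causal: zero-pad on the left side only.
            x = F.pad(x, (self.left_pad, 0))
        return self.conv(x)


torch.manual_seed(0)
conv = CachedCausalConv1d(2, 3, kernel_size=5)
x = torch.randn(1, 2, 64)
with torch.no_grad():
    offline = conv(x)  # whole signal at once, left zero padding
    conv.set_cached(True)
    # Same signal fed in blocks of 16 samples.
    streamed = torch.cat([conv(block) for block in x.split(16, dim=-1)], dim=-1)
```

Under these assumptions the block-wise output should match the offline causal output up to floating-point tolerance, because the cache supplies exactly the samples that left padding would otherwise replace with zeros after the first block.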

This PR also includes tests to ensure the Conv1d behaves as expected for all padding types and in cached and/or causal modes of operation. The only aspect that may not yet be supported is strides greater than 1; this can be added in a future PR.

The code is roughly based on @francescopapaleo's Conv1dSwitching code and @hyakuchiki's streaming convolution code. A future PR will refactor the TCN library in the SDK and swap out all conv1d layers to use this implementation instead.

hyakuchiki commented 7 months ago

At least when the stride is 1, everything looks alright to me. If the stride is greater than 1, the padding can be implemented like this (probably inefficient):

  import math

  import torch

  # Methods of the streaming conv module. They rely on self.K (presumably the
  # kernel size), self.stride, self.in_channels and self.cache.

  def get_n_frames(self, input_length: int) -> float:
      # data with size L allows for (L-K)//s + 1 conv ops
      return float((input_length - self.K) / self.stride[0] + 1.0)

  def cached_pad(self, x: torch.Tensor) -> torch.Tensor:
      """
      self.cache = x[..., -padding:] doesn't work for non-"same" type convs.
      This keeps track of where the last convolution was performed so that
      we keep the samples needed to continue with the next buffer.
      """
      # x: batch, channel, L
      x = torch.cat([self.cache[: x.shape[0]], x], dim=-1)
      # clamp at 0 so an input shorter than the kernel is cached whole
      # instead of producing a negative slice start
      n_convs = max(0, math.floor(self.get_n_frames(x.shape[-1])))
      # starting position of the first conv that wasn't calculated
      next_pos = self.stride[0] * n_convs
      if next_pos < x.shape[-1]:
          # save the unconsumed tail as the new cache
          self.cache = x[..., next_pos:].detach()
      else:
          # Nothing can be cached: shouldn't happen unless the stride is
          # larger than the kernel size
          self.cache = torch.empty(
              x.shape[0], self.in_channels, 0, device=x.device, dtype=x.dtype
          )
      return x
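
The bookkeeping in the snippet above can be checked without tensors. The following is a rough sketch using plain Python lists so that the index arithmetic stays visible; the helper names n_frames, valid_conv and stream_block are hypothetical, and it assumes the stride is no larger than the kernel size.

```python
from typing import List, Tuple


def n_frames(length: int, kernel_size: int, stride: int) -> int:
    """Number of complete windows in `length` samples (0 if too short)."""
    if length < kernel_size:
        return 0
    return (length - kernel_size) // stride + 1


def valid_conv(x: List[float], kernel: List[float], stride: int) -> List[float]:
    """Unpadded strided cross-correlation over a 1D list."""
    k = len(kernel)
    return [
        sum(w * x[i * stride + j] for j, w in enumerate(kernel))
        for i in range(n_frames(len(x), k, stride))
    ]


def stream_block(
    cache: List[float], block: List[float], kernel: List[float], stride: int
) -> Tuple[List[float], List[float]]:
    """Process one block and return (outputs, new cache)."""
    x = cache + block
    out = valid_conv(x, kernel, stride)
    # Start of the first window that has not been computed yet.
    next_pos = stride * len(out)
    # Everything from next_pos onward is needed by the next block.
    return out, x[next_pos:]


kernel = [1.0, -2.0, 0.5]
signal = [float(i * i % 7) for i in range(20)]

cache: List[float] = []
streamed: List[float] = []
for start in range(0, len(signal), 6):
    out, cache = stream_block(cache, signal[start:start + 6], kernel, stride=2)
    streamed += out

full = valid_conv(signal, kernel, stride=2)
```

With a stride of 2 and a kernel of 3 samples, each block leaves a tail of unconsumed samples in the cache, and concatenating the per-block outputs gives the same frames as convolving the whole signal at once. If the stride exceeded the kernel size, the next window could start beyond the end of the current buffer, which this sketch does not handle.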

bogdanteleaga commented 7 months ago

LGTM, left a few small comments. Thank you for adding all the tests as well!