RamiMatar / Chroma-BSRNN

eval overlap smoothing #3

Open deyituo opened 1 year ago

deyituo commented 1 year ago

It seems it would be better to smooth the overlapping chunks in eval.py, like this:

    def forward_overlap(self, mix, segment_samples, overlap=0.75):
        """Separate `mix` chunk by chunk and average the overlapping outputs.

        :param mix: input mixture, shape (C, S)
        :param segment_samples: chunk length in samples
        :param overlap: fraction of each chunk shared with the next, in [0, 1)
        :return: output with the same shape as the input, (C, S)
        """
        chunk_size = segment_samples
        mix = mix.unsqueeze(0)
        # mix: (1, C, S)
        # Hop between chunk starts; at least 1 so the loop always advances.
        step = max(1, int(chunk_size * (1 - overlap)))
        result = torch.zeros(mix.shape, dtype=mix.dtype, device=mix.device)
        divider = torch.zeros(mix.shape, dtype=mix.dtype, device=mix.device)
        # At least one chunk, so inputs shorter than one step are still processed
        # instead of leaving `divider` at zero.
        chunk_num = max(1, int(mix.shape[-1] / step))
        for i in range(chunk_num):
            start = i * step
            if i == (chunk_num - 1):
                # The last chunk runs to the end of the signal.
                end = mix.shape[-1]
            else:
                # Clip so chunks near the end do not run past the signal.
                end = min(start + chunk_size, mix.shape[-1])
            mix_part = mix[:, :, start:end]
            # mix_part: (1, C, chunk_samples)
            sep_part = self.model(mix_part)
            # sep_part: (1, C, chunk_samples)
            result[:, :, start:end] += sep_part
            divider[:, :, start:end] += 1
        # Average the overlapping predictions; every sample is covered at least once.
        sources = result / divider
        sources = sources.squeeze(0)
        # sources: (C, S)
        return sources

RamiMatar commented 1 year ago

That’s an interesting idea. I think it could improve performance at the edges of the song. I’ll give it a try soon and let you know!
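
For reference, the snippet in the first comment gives every chunk equal weight when averaging. One possible variant, sketched here with plain Python lists instead of tensors so it runs without dependencies, weights each chunk with a triangular window so that samples near a chunk's edges contribute less to the blend. The names `triangular_window`, `weighted_overlap_add`, and `process` are hypothetical: `process` only stands in for the model call, and none of this is code from the repository.

```python
def triangular_window(n):
    """Weights that peak mid-chunk and stay strictly positive at the edges."""
    # The small offset keeps every weight > 0 so the normalisation never divides by zero.
    return [1.0 - abs((i + 0.5) / n * 2.0 - 1.0) + 1e-3 for i in range(n)]


def weighted_overlap_add(signal, chunk_size, overlap, process):
    """Process `signal` in overlapping chunks and blend them with window weights.

    `process` is a stand-in for the separation model (an assumption): it takes
    a list of samples and returns a list of the same length.
    """
    step = max(1, int(chunk_size * (1 - overlap)))  # hop between chunk starts
    total = len(signal)
    result = [0.0] * total      # weighted sum of chunk outputs
    weight_sum = [0.0] * total  # sum of the weights applied at each sample
    start = 0
    while True:
        end = min(start + chunk_size, total)  # clip the final chunk
        out = process(signal[start:end])
        window = triangular_window(end - start)
        for k, (y, w) in enumerate(zip(out, window)):
            result[start + k] += w * y
            weight_sum[start + k] += w
        if end == total:
            break
        start += step
    # Normalise by the accumulated weights (a weighted average per sample).
    return [r / w for r, w in zip(result, weight_sum)]


if __name__ == "__main__":
    ramp = [float(i) for i in range(37)]
    # With an identity "model", blending must reproduce the input.
    blended = weighted_overlap_add(ramp, chunk_size=8, overlap=0.75, process=list)
    print(max(abs(a - b) for a, b in zip(ramp, blended)) < 1e-9)
```

Dividing by the accumulated weights rather than a chunk count is what lets an arbitrary positive window be used. Whether this helps separation quality over plain averaging would need to be measured.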