lucidrains / BS-RoFormer

Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs
MIT License
384 stars, 13 forks

[bad assertion] strange bottleneck performance #9

Closed: dorpxam closed this issue 10 months ago

dorpxam commented 10 months ago

Hi, here is a small benchmark of the model running on the GPU (device='cuda'):

    ---------- Training ----------
    Time: 0.019369 seconds in "STFT" stage.
    Time: 0.704290 seconds in "Band Split" stage.
    Time: 0.124404 seconds in "Transformer" stage.
    Time: 0.017830 seconds in "Mask Estimator" stage.
    Time: 6.671448 seconds in "ISTFT" stage.
    --------------------------------
    Time: 0.147087 seconds in "Multiresolution Loss" stage.
    Time: 0.493661 seconds in "Loss Backward" stage.
    --------------------------------
    Time: 9.102034 seconds for the whole process.

    ---------- Evaluation ----------
    Time: 0.000107 seconds in "STFT" stage.
    Time: 0.003609 seconds in "Band Split" stage.
    Time: 0.015426 seconds in "Transformer" stage.
    Time: 0.009862 seconds in "Mask Estimator" stage.
    Time: 3.528062 seconds in "ISTFT" stage.
    --------------------------------
    Time: 3.565335 seconds for the whole process.

The model specifications follow the original paper:

    import torch
    from bs_roformer import BSRoformer

    device = 'cuda'

    model = BSRoformer(dim = 384,
                       depth = 12,
                       time_transformer_depth = 1,
                       freq_transformer_depth = 1,
                       mask_estimator_depth = 4).to(device)

The tensors for testing are initialized using:

    audio_size = 4 * 44100 # 4 seconds @ 44.1 kHz

    sample = torch.randn(2, audio_size, dtype=torch.float32).to(device)
    target = torch.randn(2, audio_size, dtype=torch.float32).to(device)

The benchmark does not time the einops operations or other tensor manipulations between stages; it only brackets the model stages, like this:

        self.bench.start('Band Split')
        x = self.band_split(x)
        self.bench.stop()

        # axial / hierarchical attention

        self.bench.start('Transformer')
        for time_transformer, freq_transformer in self.layers:

            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')

            x = time_transformer(x)

            x, = unpack(x, ps, '* t d')
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')

            x = freq_transformer(x)

            x, = unpack(x, ps, '* f d')

        x = self.final_norm(x)
        self.bench.stop()

        num_stems = len(self.mask_estimators)

        self.bench.start('Mask Estimator')
        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim = 1)
        self.bench.stop()

The Benchmark class is trivial:

    from time import perf_counter

    class Benchmark:
        def start(self, stage):
            # Record the stage name and the wall-clock start time.
            self.stage = stage
            self.time_start = perf_counter()

        def stop(self):
            # Elapsed wall-clock time since start(); no GPU synchronization is done here.
            self.time_duration = perf_counter() - self.time_start
            print(f'Time: {self.time_duration:.6f} seconds in "{self.stage}" stage.')

Conclusion:

Most of the model's time is spent in torch.istft (6.67 s of the 9.10 s training step, about 73%), while torch.stft is not slow at all.
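The percentage can be checked against the first training report. This is a small sketch that copies the logged timings by hand and computes each stage's share; the `shares` helper and the variable names are hypothetical, not part of BS-RoFormer.

```python
# Stage timings (seconds) copied from the first, unsynchronized training report.
training = {
    "STFT": 0.019369,
    "Band Split": 0.704290,
    "Transformer": 0.124404,
    "Mask Estimator": 0.017830,
    "ISTFT": 6.671448,
    "Multiresolution Loss": 0.147087,
    "Loss Backward": 0.493661,
}
whole_process = 9.102034  # "for the whole process" line, includes untimed work


def shares(timings, total):
    """Return each stage's share of `total` as a percentage."""
    return {stage: 100.0 * t / total for stage, t in timings.items()}


for stage, pct in shares(training, whole_process).items():
    print(f"{stage:>22}: {pct:5.1f}% of the whole step")

# Share of ISTFT among the forward-pass stages only (STFT through ISTFT).
forward_stages = ["STFT", "Band Split", "Transformer", "Mask Estimator", "ISTFT"]
forward_total = sum(training[s] for s in forward_stages)
print(f"ISTFT share of forward stages: {100.0 * training['ISTFT'] / forward_total:.1f}%")
```

Restricting the denominator to the forward stages (7.54 s in total) raises the ISTFT share to about 88.5%.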

Am I the only one to notice this?

Edit:

Wrong conclusion, see next message.

dorpxam commented 10 months ago

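Sorry about that; this came from a misunderstanding of how the GPU works internally. CUDA operations are launched asynchronously, so a host-side timer stopped right after a stage mostly measures the launch, and the actual work is charged to whichever later operation first makes the CPU wait for the GPU. By calling a GPU synchronization at each stage, I got a more accurate and coherent benchmark report! This issue can be closed.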
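A minimal sketch of that fix, assuming PyTorch: it keeps the start/stop interface of the Benchmark class from the first message but waits for queued GPU work with `torch.cuda.synchronize()` before each clock read. The `cuda_sync` helper, the `sync` parameter, the `durations` dict and the `SyncBenchmark` name are illustrative additions, not the author's actual code.

```python
from time import perf_counter


def cuda_sync():
    """Wait for all queued GPU work, assuming PyTorch; a no-op without torch or CUDA."""
    try:
        import torch
    except ImportError:
        return
    if torch.cuda.is_available():
        torch.cuda.synchronize()


class SyncBenchmark:
    """Same start/stop interface as Benchmark, but synchronizes before reading the clock."""

    def __init__(self, sync=cuda_sync):
        self.sync = sync
        self.durations = {}

    def start(self, stage):
        self.stage = stage
        self.sync()  # don't charge earlier, still-running work to this stage
        self.time_start = perf_counter()

    def stop(self):
        self.sync()  # wait until this stage's own GPU work has finished
        duration = perf_counter() - self.time_start
        self.durations[self.stage] = duration
        print(f'Time: {duration:.6f} seconds in "{self.stage}" stage.')
        return duration
```

Synchronizing in `start()` as well as `stop()` matters: otherwise a stage can still absorb work left over from the previous, untimed code.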

    -------------------- Training --------------------
    Time: 0.022357 seconds in "STFT" stage.
    Time: 0.624558 seconds in "Band Split" stage.
    Time: 12.844367 seconds in "Transformers" stage.
    Time: 0.244779 seconds in "Mask Estimators" stage.
    Time: 0.006015 seconds in "ISTFT" stage.
    --------------------------------------------------
    Time: 0.290139 seconds in "Loss Function" stage.
    Time: 50.656937 seconds in "Loss Backward" stage.
    --------------------------------------------------
    Time: 65.572889 seconds for the whole process.
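The misattribution can be reproduced without a GPU. The sketch below is a pure-Python analogy, not PyTorch code: a hypothetical `FakeStream` queues simulated "kernels" (sleeps) on a single background thread, so `launch()` returns at once, like an asynchronous CUDA launch. The names `FakeStream`, `launch` and `profile` are invented for illustration. The last stage always waits for its result, standing in for any operation that needs the output on the CPU.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class FakeStream:
    """Hypothetical stand-in for a CUDA stream: work runs in order on one background thread."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._last = None

    def launch(self, seconds):
        # Queue simulated kernel work and return immediately, like an async launch.
        self._last = self._pool.submit(time.sleep, seconds)

    def synchronize(self):
        # With a single worker, the last job finishing implies all earlier jobs finished.
        if self._last is not None:
            self._last.result()

    def close(self):
        self._pool.shutdown(wait=True)


def profile(stages, sync):
    """Time each (name, cost) stage; the last stage always waits, as reading the output would."""
    stream = FakeStream()
    timings = {}
    try:
        for i, (name, cost) in enumerate(stages):
            start = time.perf_counter()
            stream.launch(cost)
            if sync or i == len(stages) - 1:
                stream.synchronize()
            timings[name] = time.perf_counter() - start
    finally:
        stream.close()
    return timings


stages = [("expensive", 0.2), ("cheap", 0.02)]
print("no sync:  ", profile(stages, sync=False))
print("with sync:", profile(stages, sync=True))
```

Without synchronization the expensive stage looks almost free and the cheap stage absorbs its cost; with synchronization each stage is charged only its own work.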