SonyCSLParis / pesto

Self-supervised learning for fast pitch estimation
GNU Lesser General Public License v3.0

InferenceMode causes RuntimeError when storing PESTO model and DataProcessor on LightningModule using DDP strategy #18

Open ben-hayes opened 8 months ago

ben-hayes commented 8 months ago

Context

In some use cases (e.g. DDSP audio synthesis) we want to perform F0 estimation on the GPU, so it's helpful to store PESTO as a submodule of our pytorch_lightning.LightningModule.

Bug description

When training with the DistributedDataParallel strategy and calling pesto.predict, DDP's _sync_buffers method causes the following exception to be thrown on the second training iteration:

RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update.See https://github.com/pytorch/rfcs/pull/17 for more details.

Note that this persists whether the output is cloned or not — i.e. the problematic InferenceMode tensor is not the output.

Expected behavior

PESTO should be usable as a submodule of a LightningModule, including when training with the DDP strategy.

Minimal example

import pesto
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import RandomDataset

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.f0_extractor = pesto.utils.load_model("mir-1k")
        self.preprocessor = pesto.utils.load_dataprocessor(1e-2, device="cuda")
        self.preprocessor.sampling_rate = 44100
        self.net = nn.Linear(201, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, f0_hz, _, _ = pesto.predict(
            x,
            44100,
            data_preprocessor=self.preprocessor,
            model=self.f0_extractor,
            convert_to_freq=True,
        )
        f0_hz = f0_hz.clone() # avoid in-place operation on InferenceMode output

        return self.net(f0_hz)

    def training_step(self, batch, batch_idx):
        x = batch
        y = self(x)
        loss = y.mean()
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

model = MyModel()
dataset = RandomDataset(88200, 64)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

trainer = pl.Trainer(accelerator="gpu", max_epochs=100, strategy="ddp_find_unused_parameters_true")
trainer.fit(model, dataloader)

Diagnostics

As far as I can tell, the issue arises because data_processor.sampling_rate is set inside pesto.predict, which is decorated by torch.inference_mode(): https://github.com/SonyCSLParis/pesto/blob/afa44099640a2a9c41ef916a313ffae0e0890c85/pesto/core.py#L53

This means that if the sample rate has changed, or is being set for the first time (as it is likely to be on the first call to pesto.predict), the CQT buffers (or parameters) are created as inference tensors.
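To make the mechanism concrete, here is a small pesto-independent illustration: any tensor allocated while torch.inference_mode() is active becomes an inference tensor, which is what happens to the lazily created CQT buffers inside the decorated pesto.predict.

```python
import torch

# Tensors allocated while torch.inference_mode() is active are inference
# tensors; this is what happens to the CQT buffers when the sampling rate is
# first set inside the decorated pesto.predict.
with torch.inference_mode():
    lazily_created = torch.randn(8)

print(lazily_created.is_inference())   # True

eagerly_created = torch.randn(8)
print(eagerly_created.is_inference())  # False
```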

Workaround

A temporary workaround is to set DataProcessor.sampling_rate before calling pesto.predict.
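Concretely, the workaround amounts to something like the sketch below (based on the minimal example above, not an official recipe): trigger the CQT initialisation eagerly, outside any inference_mode context, before the first call to pesto.predict.

```python
import pesto

# Workaround sketch: set the sampling rate eagerly, outside any
# torch.inference_mode() context, so the CQT buffers are created as ordinary
# tensors before the first (inference_mode-decorated) pesto.predict call.
preprocessor = pesto.utils.load_dataprocessor(1e-2, device="cuda")
preprocessor.sampling_rate = 44100  # CQT initialisation happens here
```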

Possible solution

Use the torch.inference_mode() context manager around only the inference section of pesto.predict, instead of decorating the whole function.
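A simplified sketch of that structure (condensed, not pesto's actual implementation; data_preprocessor and model stand for the objects returned by the loaders in the example above):

```python
import torch

def predict_sketch(x, sampling_rate, data_preprocessor, model):
    # Setup that may allocate buffers (the lazy CQT init triggered by the
    # sampling_rate setter) stays outside inference mode...
    data_preprocessor.sampling_rate = sampling_rate

    # ...and only the actual forward passes run under inference mode.
    with torch.inference_mode():
        frames = data_preprocessor(x)
        return model(frames)
```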

aRI0U commented 8 months ago

Hi,

Yeah, in the end always decorating pesto.predict with torch.inference_mode is maybe a bit restrictive. I'll consider adding the possibility to choose between torch.no_grad and torch.inference_mode when running predict; that should prevent such issues.
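A rough sketch of what such an option could look like (the inference_mode flag and the predict_impl argument are hypothetical, not part of the current pesto API):

```python
import torch

def run_predict(predict_impl, *args, inference_mode: bool = True, **kwargs):
    """Hypothetical wrapper: predict_impl stands for an undecorated version of
    pesto.predict. With inference_mode=False, torch.no_grad() is used instead;
    it still disables autograd but does not create inference tensors, so DDP's
    in-place buffer synchronisation keeps working."""
    ctx = torch.inference_mode() if inference_mode else torch.no_grad()
    with ctx:
        return predict_impl(*args, **kwargs)
```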

Also, I'm not sure why it only fails when using DDP. When training on a single GPU does it work as expected?

ben-hayes commented 8 months ago

Training without the DDP strategy is fine, as there are no ops that modify the buffers. The bug occurs when DDP tries to sync buffers: it appears to be the call to torch.distributed._broadcast_coalesced that triggers an in-place modification:

https://github.com/pytorch/pytorch/blob/b6a30bbfb6c1bcb9c785e7a853c2622c8bc17093/torch/nn/parallel/distributed.py#L1978-L1983
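The failure mode can be simulated in a single process by doing what the buffer broadcast does, namely an in-place write into a buffer that was created under inference mode (the buffer name below is illustrative):

```python
import torch
import torch.nn as nn

mod = nn.Module()

# Buffer created under inference_mode, as happens with the lazy CQT init:
with torch.inference_mode():
    mod.register_buffer("kernel", torch.randn(4))

# Single-process training only reads this buffer, so it works. DDP's buffer
# sync writes into it in place (broadcast/copy), which copy_ stands in for here:
try:
    mod.kernel.copy_(torch.zeros(4))
except RuntimeError as err:
    print(err)  # Inplace update to inference tensor outside InferenceMode is not allowed. ...
```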

ben-hayes commented 8 months ago

Also just to say... this issue is a side effect of the lazy CQT init discussed in #19.