ben-hayes opened 8 months ago
Hi,

Yeah, in the end always decorating `pesto.predict` with `torch.inference_mode` may be a bit restrictive. I'll consider adding the possibility to choose between `torch.no_grad` and `torch.inference_mode` when running `predict`; that should prevent such issues.
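For what it's worth, a rough sketch of what such an option could look like (illustrative only, not the actual PESTO API; the `inference_mode` keyword argument is an assumption):

```python
import torch

def predict(x, sr, inference_mode: bool = True):
    # Callers embedding PESTO in a training loop (e.g. under DDP) could pass
    # inference_mode=False to fall back to torch.no_grad(), so no inference
    # tensors are created that would later break in-place buffer updates.
    ctx = torch.inference_mode() if inference_mode else torch.no_grad()
    with ctx:
        ...  # existing prediction logic
```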
Also, I'm not sure why it only fails when using DDP. When training on a single GPU, does it work as expected?
Training without the DDP strategy is fine, as there are no ops that modify the buffers. The bug occurs when DDP tries to sync buffers: it appears to be the call to `torch.distributed._broadcast_coalesced` that triggers the in-place modification.

Also, just to say: this issue is a side effect of the lazy CQT init discussed in #19.
### Context

In some use cases (e.g. DDSP audio synthesis) we want to perform F0 estimation on the GPU, so it is helpful to store PESTO as a submodule of our `pytorch_lightning.LightningModule`.

### Bug description

When training with the `DistributedDataParallel` strategy, the `_sync_buffers` method causes the following exception to be thrown on the second training iteration when using `pesto.predict`:

RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed. You can make a clone to get a normal tensor before doing inplace update. See https://github.com/pytorch/rfcs/pull/17 for more details.
Note that this persists whether the output is cloned or not — i.e. the problematic InferenceMode tensor is not the output.
### Expected behavior

PESTO should be usable as a submodule of a `LightningModule` trained with the DDP strategy.
### Minimal example
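A sketch of the kind of setup that triggers the error (illustrative only, not the exact reproduction; the PESTO loading call, the `pesto.predict` signature, and the trainer arguments below are assumptions):

```python
import torch
import pytorch_lightning as pl
import pesto


class SynthModel(pl.LightningModule):
    """Toy module that keeps PESTO on-device for F0 estimation during training."""

    def __init__(self, sample_rate: int = 16000):
        super().__init__()
        self.sample_rate = sample_rate
        # Hypothetical attachment: however the PESTO model / data processor is
        # constructed, registering it as a submodule means DDP will try to
        # broadcast its buffers, including the lazily-created CQT buffers.
        self.pesto_model = pesto.load_model("mir-1k")  # assumed loader name
        self.proj = torch.nn.Linear(1, 1)  # dummy trainable layer

    def training_step(self, batch, batch_idx):
        (audio,) = batch  # mono audio, shape (batch, samples)
        # pesto.predict is decorated with torch.inference_mode(), so any buffers
        # it creates on the first call become inference tensors.
        _, f0, _, _ = pesto.predict(audio, self.sample_rate)  # assumed signature
        return self.proj(f0.float().mean().view(1, 1)).pow(2).mean()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    dataset = torch.utils.data.TensorDataset(torch.randn(8, 16000))
    loader = torch.utils.data.DataLoader(dataset, batch_size=2)
    trainer = pl.Trainer(strategy="ddp", devices=2, max_steps=4)
    trainer.fit(SynthModel(), loader)
    # Fails on the second training iteration with:
    # RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed.
```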
### Diagnostics
As far as I can tell, the issue arises because `data_processor.sampling_rate` is set inside `pesto.predict`, which is decorated by `torch.inference_mode()`:

https://github.com/SonyCSLParis/pesto/blob/afa44099640a2a9c41ef916a313ffae0e0890c85/pesto/core.py#L53

This means that if the sample rate has changed, or is being set for the first time (as it is likely to be on the first call to `pesto.predict`), the CQT buffers (or parameters) are created as inference-mode tensors.
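The mechanism can be reproduced in plain PyTorch, independent of PESTO (a self-contained illustration; the module and function names below are made up):

```python
import torch


class LazyCQT(torch.nn.Module):
    """Stand-in for a module whose buffers are (re)built lazily."""

    def __init__(self):
        super().__init__()
        self.register_buffer("kernel", torch.zeros(1))


@torch.inference_mode()
def lazy_setup(module: LazyCQT, size: int) -> None:
    # Re-assigning a registered buffer inside inference_mode makes it an
    # inference tensor, mirroring what happens when pesto.predict sets the
    # sampling rate (and hence rebuilds the CQT) under its decorator.
    module.kernel = torch.zeros(size)


m = LazyCQT()
lazy_setup(m, 3)

# Outside inference mode, any in-place update to that buffer (which is what
# DDP's buffer broadcast effectively does) raises:
# RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed.
m.kernel.add_(1.0)
```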
### Workaround

A temporary workaround is to set `DataProcessor.sampling_rate` before calling `pesto.predict`, so that the CQT buffers are created outside of inference mode.
### Possible solution

Use the `torch.inference_mode()` context manager around only the inference section of `pesto.predict`, rather than decorating the whole function.
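A rough sketch of what that could look like (illustrative, not the actual PESTO internals; `data_processor` and `model` stand in for PESTO's preprocessing and network objects):

```python
import torch

def predict(x, sr, data_processor, model):
    # Lazy (re)initialisation runs outside inference mode, so any buffers it
    # creates are ordinary tensors that DDP can later update in place.
    if data_processor.sampling_rate != sr:
        data_processor.sampling_rate = sr

    # Only the pure inference section is wrapped in inference_mode.
    with torch.inference_mode():
        frames = data_processor(x)
        return model(frames)
```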