asteroid-team / asteroid

The PyTorch-based audio source separation toolkit for researchers
https://asteroid-team.github.io/
MIT License

RuntimeError: istft requires a complex-valued input tensor matching the output from stft with return_complex=True. #662

Open lucyanddarlin opened 1 year ago

lucyanddarlin commented 1 year ago

I ran /asteroid-master/egs/musdb18/X-UMX/run.sh, but got the error: RuntimeError: istft requires a complex-valued input tensor matching the output from stft with return_complex=True.

I tried setting return_complex=True in x_umx.py:

    stft_f = torch.stft(
        x,
        n_fft=self.n_fft,
        hop_length=self.n_hop,
        window=self.window,
        center=self.center,
        normalized=False,
        onesided=True,
        pad_mode="reflect",
        return_complex=True,
    )

But it didn't work. Could someone tell me how to solve it? Thank you so much!

lucyanddarlin commented 1 year ago

Here are the details:

❯ /bin/zsh /Volumes/noEntry/study/asteroid-master/egs/musdb18/X-UMX/run.sh
Results from the following experiment will be stored in exp/train_xumx_d727eb8a
Stage 1: Training
101it [00:00, 705.80it/s]
0it [00:00, ?it/s]train_dataset <asteroid.data.musdb18_dataset.MUSDB18Dataset object at 0x14ea54820>
101it [00:00, 27558.20it/s]
valid_dataset <asteroid.data.musdb18_dataset.MUSDB18Dataset object at 0x14ea54a30>
Compute dataset statistics:   0%|          | 0/86 [00:00<?, ?it/s]
/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/functional.py:641: UserWarning: stft with return_complex=False is deprecated. In a future pytorch release, stft will return complex tensors for all inputs, and return_complex=False will raise an error.
Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/SpectralOps.cpp:867.)
  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
Compute dataset statistics: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 86/86 [01:34<00:00,  1.09s/it]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/setup.py:201: UserWarning: MPS available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='mps', devices=1)`.
  rank_zero_warn(
/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

  | Name      | Type            | Params
----------------------------------------------
0 | model     | XUMX            | 35.6 M
1 | loss_func | MultiDomainLoss | 4.1 K 
----------------------------------------------
35.6 M    Trainable params
8.2 K     Non-trainable params
35.6 M    Total params
142.326   Total estimated model params size (MB)
Combination Loss: True
Multi Domain Loss: True, scaling parameter for time-domain loss=10.0
Sanity Checking: 0it [00:00, ?it/s]/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]Traceback (most recent call last):
  File "train.py", line 498, in <module>
    main(arg_dic, plain_args)
  File "train.py", line 465, in main
    trainer.fit(system)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 42, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 92, in launch
    return function(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 935, in _run
    results = self._run_stage()
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 976, in _run_stage
    self._run_sanity_check()
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1005, in _run_sanity_check
    val_loop.run()
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py", line 177, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 115, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 375, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 288, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 337, in validation_step
    return self.model(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1113, in _run_ddp_forward
    return module_to_run(*inputs, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 102, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "train.py", line 356, in validation_step
    loss_tmp += self.common_step(batch_tmp, batch_nb, train=False)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/asteroid/engine/system.py", line 101, in common_step
    est_targets = self(inputs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/asteroid/engine/system.py", line 73, in forward
    return self.model(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/asteroid/models/x_umx.py", line 169, in forward
    time_signals = self.decoder(spec, ang)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/asteroid/lib/python3.8/site-packages/asteroid/models/x_umx.py", line 401, in forward
    wav = torch.istft(
RuntimeError: istft requires a complex-valued input tensor matching the output from stft with return_complex=True.

DavidDiazGuerra commented 1 year ago

It seems that newer versions of PyTorch have changed the behavior of torch.stft and torch.istft. I just ran into the same issue, and I think I was able to fix it by adding 'x = torch.view_as_complex(x)' just before the torch.istft call on the line that raises the error.
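A minimal sketch of that change, assuming the decoder builds a real tensor whose last dimension holds the (real, imaginary) pair, i.e. the old return_complex=False layout. The helper name istft_from_real_pairs, the use of center=True, and the n_fft/hop values are hypothetical illustrations, not copied from x_umx.py; only torch.view_as_complex, torch.view_as_real, torch.stft and torch.istft are real PyTorch calls.

```python
import torch


def istft_from_real_pairs(x, n_fft, hop_length, window, length=None):
    """Invert an STFT stored in the old real layout (..., freq, frames, 2).

    Hypothetical helper for illustration; the decoder in x_umx.py may differ.
    """
    # view_as_complex needs a floating-point tensor whose last dim has size 2
    # and stride 1; .contiguous() guarantees the stride requirement.
    x = torch.view_as_complex(x.contiguous())
    # Newer PyTorch versions only accept complex input to istft.
    return torch.istft(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=True,  # assumption: the real model uses self.center
        normalized=False,
        onesided=True,
        length=length,
    )


# Usage: build an old-style real spectrogram from a random signal, then invert it.
n_fft, hop = 512, 128
window = torch.hann_window(n_fft)
signal = torch.randn(2, 4096)
spec = torch.stft(signal, n_fft=n_fft, hop_length=hop, window=window, return_complex=True)
spec_real = torch.view_as_real(spec)  # real tensor with a trailing dim of size 2
wav = istft_from_real_pairs(spec_real, n_fft, hop, window, length=signal.shape[-1])
print(torch.allclose(wav, signal, atol=1e-4))
```

Converting only at the istft boundary keeps the rest of the model, which works on (real, imaginary) pairs, unchanged.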

By the way, you can also get rid of the deprecation warning by setting return_complex=True in the call to torch.stft and then adding stft_f = torch.view_as_real(stft_f) right after it.
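A sketch of the encoder side under the same assumptions: request a complex STFT (no deprecation warning), then convert it back to the old real layout with torch.view_as_real so code that expects a trailing (real, imaginary) dimension keeps working. This likely also explains why setting return_complex=True alone did not help: the complex output lacks that trailing dimension of size 2, and the later istft call still receives a real tensor. The helper name stft_real_pairs and center=True are hypothetical choices for illustration.

```python
import torch


def stft_real_pairs(x, n_fft, hop_length, window):
    """Compute an STFT and return it in the old real layout (..., freq, frames, 2).

    Hypothetical helper; mirrors the stft call in x_umx.py but is not copied from it.
    """
    stft_f = torch.stft(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=True,  # assumption: the real model uses self.center
        normalized=False,
        onesided=True,
        pad_mode="reflect",
        return_complex=True,  # avoids the deprecation warning
    )
    # Complex (..., freq, frames) -> real (..., freq, frames, 2)
    return torch.view_as_real(stft_f)


n_fft, hop = 512, 128
window = torch.hann_window(n_fft)
x = torch.randn(3, 8000)
stft_f = stft_real_pairs(x, n_fft, hop, window)
print(stft_f.shape)  # torch.Size([3, 257, 63, 2])

# Magnitude from the (real, imaginary) pairs matches the complex STFT magnitude.
mag = stft_f.pow(2).sum(-1).sqrt()
reference = torch.stft(x, n_fft=n_fft, hop_length=hop, window=window, return_complex=True)
print(torch.allclose(mag, reference.abs(), atol=1e-4))
```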

r-sawata commented 8 months ago

Thank you so much, @DavidDiazGuerra!

As he said, this was caused by a mismatch between old and new PyTorch versions. If PR #684 is accepted, this problem should be resolved.

jrzzzz commented 4 months ago

Thank you so much!!! @DavidDiazGuerra