sdatkinson / neural-amp-modeler

Neural network emulator for guitar amplifiers.
MIT License

[BUG] False positive error using Conv1d on MPS `Output channels > 65536 not supported at the MPS device.` #505

Closed. sdatkinson closed this issue 3 days ago.

sdatkinson commented 3 days ago

Describe the bug
Training locally on my MBP with macOS 15.1, I see the following error:

  | Name | Type    | Params | Mode 
-----------------------------------------
0 | _net | WaveNet | 13.8 K | train
-----------------------------------------
13.8 K    Trainable params
0         Non-trainable params
13.8 K    Total params
0.055     Total estimated model params size (MB)
111       Modules in train mode
0         Modules in eval mode
Sanity Checking DataLoader 0:   0%|                       | 0/1 [00:00<?, ?it/s]
Exception in Tkinter callback
Traceback (most recent call last):
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/tkinter/__init__.py", line 1921, in __call__
    return self.func(*args)
  File "/Users/steve/src/neural-amp-modeler/nam/train/gui/__init__.py", line 684, in _train
    self._train2()
  File "/Users/steve/src/neural-amp-modeler/nam/train/gui/__init__.py", line 704, in _train2
    train_output = core.train(
  File "/Users/steve/src/neural-amp-modeler/nam/train/core.py", line 1447, in train
    trainer.fit(model, train_dataloader, val_dataloader)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1023, in _run_stage
    self._run_sanity_check()
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1052, in _run_sanity_check
    val_loop.run()
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 178, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 411, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "/Users/steve/src/neural-amp-modeler/nam/models/base.py", line 311, in validation_step
    preds, targets, loss_dict = self._shared_step(batch)
  File "/Users/steve/src/neural-amp-modeler/nam/models/base.py", line 254, in _shared_step
    preds = self(*args, pad_start=False)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/steve/src/neural-amp-modeler/nam/models/base.py", line 234, in forward
    return self.net(*args, **kwargs)  # TODO deprecate--use self.net() instead.
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/steve/src/neural-amp-modeler/nam/models/_base.py", line 182, in forward
    y = self._forward(x, **kwargs)
  File "/Users/steve/src/neural-amp-modeler/nam/models/wavenet.py", line 434, in _forward
    y = self._net(x)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/steve/src/neural-amp-modeler/nam/models/wavenet.py", line 336, in forward
    head_input, y = layer(y, x, head_input=head_input)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/steve/src/neural-amp-modeler/nam/models/wavenet.py", line 220, in forward
    x = self._rechannel(x)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 375, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/Users/steve/opt/anaconda3/envs/nam/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 370, in _conv_forward
    return F.conv1d(
NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

Tracing it down, I'm not using more than 65,536 output channels. So I know there's an upstream bug (https://github.com/pytorch/pytorch/issues/129207), but this NotImplementedError seems to be reaching too far.
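For anyone hitting this in the meantime, the fallback named in the error message has one gotcha: the variable is read when torch initializes the MPS backend, so it has to be set before torch is imported (or exported in the shell before launching). A minimal sketch:

```python
# Sketch of the workaround from the error message. Set the variable before
# importing torch -- e.g. at the very top of the entry script.
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # fall back to CPU for unsupported MPS ops

# import torch  # only import torch after the variable is set
```

Note this only papers over the op in question; as the error warns, the fallback path runs on CPU and will be slower.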

To Reproduce
Steps to reproduce the behavior:

  1. Install via environment_cpu.yml.
  2. Verify that PyTorch v2.5.0 or v2.5.1 is installed.
  3. Run `nam`, pick files, and start training.
  4. Observe the error above.
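The GUI isn't needed to reproduce this; a minimal sketch (sizes are illustrative -- on an MPS device with torch 2.5.0/2.5.1 the forward raises, while on CPU it runs normally):

```python
import torch
import torch.nn as nn

# Minimal repro sketch: only 16 output channels, but a sequence longer than
# 65536 samples trips the "Output channels > 65536" check on MPS (torch 2.5.x).
device = "mps" if torch.backends.mps.is_available() else "cpu"
conv = nn.Conv1d(1, 16, kernel_size=3).to(device)
x = torch.zeros(1, 1, 70_000, device=device)
y = conv(x)  # NotImplementedError on MPS; fine on CPU
```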

Screenshots
N/A


Additional context
Rolling PyTorch back to v2.4.1 (i.e. <2.5.0) appears to resolve the issue. Also reported in the Facebook group.

sdatkinson commented 3 days ago

It seems the error message is phrased incorrectly--the limit actually applies to the sequence length, not the number of output channels.

So another workaround would be to process the output in chunks, though that sounds dreadful; I haven't checked the speed impact yet. Possibly a try/except?... I can do that.
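The chunking idea could look something like the sketch below; `conv1d_chunked` is a hypothetical helper, not part of the codebase. Because an unpadded conv needs (kernel_size - 1) extra input samples per chunk of output, overlapping the input slices by that amount makes the concatenated result match the single full convolution:

```python
import torch
import torch.nn.functional as F

def conv1d_chunked(x, weight, bias=None, chunk=65536):
    """Hypothetical helper: run F.conv1d over a long signal in chunks,
    overlapping each input slice by (kernel_size - 1) samples so the
    concatenated output equals one full (unpadded) convolution."""
    k = weight.shape[-1]
    n_out = x.shape[-1] - k + 1  # output length of an unpadded conv
    outs = []
    for start in range(0, n_out, chunk):
        stop = min(start + chunk, n_out)
        seg = x[..., start : stop + k - 1]  # inputs needed for outputs [start, stop)
        outs.append(F.conv1d(seg, weight, bias))
    return torch.cat(outs, dim=-1)
```

Whether the per-chunk launch overhead on MPS eats the gains is exactly the open speed question above.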