sdatkinson / neural-amp-modeler

Neural network emulator for guitar amplifiers.
MIT License

Improve MRSTFT loss on MPS #453

Status: Open. sdatkinson opened this issue 1 month ago.

sdatkinson commented 1 month ago

Prerequisite for #436. Try to make the MRSTFT loss run at least partially on MPS to minimize the speed hit.

sdatkinson commented 1 month ago

It looks like the fix is to get the missing operator implemented in PyTorch itself. The exception raised is:

The operator 'aten::angle' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
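For reference, here is a minimal sketch of the temporary fix the message describes. The variable is set before `torch` is imported, which is the commonly recommended ordering. The device selection and the toy STFT (`n_fft=256`, `hop_length=64`) are illustrative assumptions, not code from this repository.

```python
import os

# Let ops that lack an MPS kernel (such as aten::angle) fall back to the CPU.
# Set this before importing torch so the backend sees it.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

# Use MPS when available, otherwise run everything on the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Toy signal and STFT (parameters are illustrative only).
x = torch.randn(2, 1024, device=device)
window = torch.hann_window(256, device=device)
spec = torch.stft(x, n_fft=256, hop_length=64, window=window, return_complex=True)

# On MPS this call goes through the CPU fallback, which is where the cost shows up.
phase = torch.angle(spec)
```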

I tried the fallback, and it works, but there was no time savings: running this one op on the CPU seems to account for the entire overhead of the MRSTFT loss. Unless it's addressed, no other work on this issue is worthwhile.
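One possible stopgap, sketched here as an assumption rather than as what this repository or its loss library does: compute the phase as `atan2(imag, real)` from real-valued ops instead of calling `torch.angle`. For complex inputs the two are mathematically equivalent. Whether `torch.atan2` and the `.real`/`.imag` views run natively on a given PyTorch's MPS backend, and whether this removes the slowdown reported above, are unverified. The helper name `phase_via_atan2` is hypothetical.

```python
import torch


def phase_via_atan2(spec: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: phase of a complex tensor as atan2(imag, real).

    Equivalent to torch.angle(spec) for complex input, but built only from
    real-valued ops, which may avoid the missing aten::angle kernel on MPS
    (assumption: those ops are supported by the installed MPS backend).
    """
    return torch.atan2(spec.imag, spec.real)


if __name__ == "__main__":
    # Random complex values stand in for an STFT here.
    z = torch.complex(torch.randn(4, 129), torch.randn(4, 129))
    assert torch.allclose(phase_via_atan2(z), torch.angle(z), atol=1e-6)
```

The result lies in (-pi, pi], matching `torch.angle`, so it could in principle be dropped in wherever the phase term is computed.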

Here are the comments to +1!