daswer123 / xtts-api-server

A simple FastAPI Server to run XTTSv2
MIT License

AttributeError: module 'torch.amp' has no attribute 'GradScaler' #81


VermiNew commented 1 week ago

I encountered an AttributeError when running XTTS-API-SERVER: the torch.amp module has no attribute GradScaler. The full traceback is below:

Traceback (most recent call last):
  File "my_script.py", line 1, in <module>
    from xtts_api_server.server import app
  File "/opt/conda/lib/python3.10/site-packages/xtts_api_server/server.py", line 1, in <module>
    from TTS.api import TTS
  File "/opt/conda/lib/python3.10/site-packages/TTS/api.py", line 8, in <module>
    from TTS.config import load_config
  File "/opt/conda/lib/python3.10/site-packages/TTS/config/__init__.py", line 10, in <module>
    from TTS.config.shared_configs import *
  File "/opt/conda/lib/python3.10/site-packages/TTS/config/shared_configs.py", line 5, in <module>
    from trainer import TrainerConfig
  File "/opt/conda/lib/python3.10/site-packages/trainer/__init__.py", line 4, in <module>
    from trainer.model import *
  File "/opt/conda/lib/python3.10/site-packages/trainer/model.py", line 7, in <module>
    from trainer.trainer import Trainer
  File "/opt/conda/lib/python3.10/site-packages/trainer/trainer.py", line 63, in <module>
    class Trainer:
  File "/opt/conda/lib/python3.10/site-packages/trainer/trainer.py", line 947, in Trainer
    def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: torch.amp.GradScaler):
AttributeError: module 'torch.amp' has no attribute 'GradScaler'
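The traceback shows the error is raised while the body of class Trainer is executing, not when _grad_clipping is called. On Python 3.10 (the version in the traceback), parameter annotations are evaluated when the def statement runs, so a missing attribute in an annotation breaks the import itself. Python 3.14 defers annotation evaluation, so the same code would not fail there. The following torch-free sketch illustrates this; old_amp is a hypothetical stand-in for a torch.amp module that lacks GradScaler.

```python
import sys
from types import SimpleNamespace

# Hypothetical stand-in for an older torch.amp that has no GradScaler attribute.
old_amp = SimpleNamespace()


def define_trainer():
    # On Python < 3.14 (without "from __future__ import annotations"),
    # the annotation old_amp.GradScaler is evaluated right here,
    # while the class body executes, just like trainer.py line 947.
    class Trainer:
        def _grad_clipping(self, grad_clip: float, scaler: old_amp.GradScaler):
            pass

    return Trainer


try:
    define_trainer()
    error = None
except AttributeError as exc:
    error = exc

print(sys.version_info[:2], repr(error))
```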

Steps to Reproduce:

  1. Install the XTTS-API-SERVER package and its dependencies.
  2. Import app (the FastAPI application object) from xtts_api_server.server.
  3. Run the script.

Expected Behavior: The server should start without any errors.

Actual Behavior: An AttributeError is raised, indicating that torch.amp does not have the GradScaler attribute.

Environment:
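The versions that matter for this error (Python, PyTorch, trainer, TTS) can be gathered with a short script. This is a minimal sketch that assumes the PyPI distribution names torch, TTS, trainer and xtts-api-server; forks of these packages may be published under other names, in which case they show up here as not installed.

```python
import platform
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Assumed distribution names; adjust if a fork is installed instead.
PACKAGES = ("torch", "TTS", "trainer", "xtts-api-server")


def collect_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version (None if absent)."""
    info: Dict[str, Optional[str]] = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            info[name] = None  # not installed under this distribution name
    return info


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```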

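Additional Information: This error is most likely caused by a mismatch between the installed PyTorch version and the trainer package. torch.amp.GradScaler is the newer, device-generic location of the class; older PyTorch releases only provide it as torch.cuda.amp.GradScaler. Because trainer.py uses torch.amp.GradScaler as a type annotation inside the Trainer class body, an older PyTorch makes the import fail before any training code runs. (Newer PyTorch releases go the other way and deprecate torch.cuda.amp.GradScaler in favor of torch.amp.GradScaler("cuda", ...).)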

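Potential solution: Upgrade PyTorch to a release that provides torch.amp.GradScaler, or install a trainer version that targets the installed PyTorch. As a local workaround, the annotation in trainer.py (line 947) can point to the older location instead; other references to torch.amp.GradScaler in the package may need the same change:

  def _grad_clipping(self, grad_clip: float, optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler):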
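Code that has to run across several PyTorch versions can look the class up at runtime instead of hard-coding either location. A minimal sketch; resolve_grad_scaler is a hypothetical helper, not part of trainer or PyTorch. It prefers torch.amp.GradScaler and falls back to torch.cuda.amp.GradScaler.

```python
from typing import Any, Optional


def resolve_grad_scaler(torch_module: Any) -> Optional[Any]:
    """Return GradScaler from torch.amp if present, else from torch.cuda.amp."""
    amp = getattr(torch_module, "amp", None)
    if amp is not None and hasattr(amp, "GradScaler"):
        return amp.GradScaler  # newer, device-generic location
    cuda = getattr(torch_module, "cuda", None)
    cuda_amp = getattr(cuda, "amp", None)
    if cuda_amp is not None and hasattr(cuda_amp, "GradScaler"):
        return cuda_amp.GradScaler  # older, CUDA-only location
    return None  # neither location exists


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        print(torch.__version__, resolve_grad_scaler(torch))
```

Because the lookup happens in a function body rather than in an annotation, a missing attribute yields None instead of breaking the import.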

Thank you for looking into this issue!