EveryVoiceTTS / EveryVoice

The EveryVoice TTS Toolkit - Text To Speech for your language
https://docs.everyvoice.ca

Giving a vocoder checkpoint in lieu of a model checkpoint yields a very user-unfriendly message #263

Closed by joanise 1 month ago

joanise commented 9 months ago

Rookie mistake: I gave my vocoder path as the model_path argument to synthesize. It took me a while to figure that out from this log:

everyvoice synthesize from-text \
  logs_and_checkpoints/VocoderExperiment/base/checkpoints/last.ckpt \
  -t "this is a test" \
  --vocoder-path ~/u/EveryVoice/vocoder_paths/generator_universal.pth.tar
Loading checkpoint from logs_and_checkpoints/VocoderExperiment/base/checkpoints/last.ckpt
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /gpfs/fs3c/nrc/ict/portage/u/joa125/EveryVoice/EveryVoice/everyvoice/model/feature_prediction/Fa │
│ stSpeech2_lightning/fs2/cli/synthesize.py:240 in synthesize                                      │
│                                                                                                  │
│   237 │   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")                │
│   238 │   # Load checkpoints                                                                     │
│   239 │   print(f"Loading checkpoint from {model_path}", file=sys.stderr)                        │
│ ❱ 240 │   model: FastSpeech2 = FastSpeech2.load_from_checkpoint(model_path).to(device)           │
│   241 │   model.eval()                                                                           │
│   242 │   # output to .wav will require a valid spec-to-wav model                                │
│   243 │   if SynthesizeOutputFormats.wav in output_type:                                         │
│                                                                                                  │
│ /home/joa125/u/miniconda3/envs/EV-test/lib/python3.10/site-packages/pytorch_lightning/utilities/ │
│ model_helpers.py:125 in wrapper                                                                  │
│                                                                                                  │
│   122 │   │   │   │   │   f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be    │
│   123 │   │   │   │   │   " Please call it on the class type and make sure the return value is   │
│   124 │   │   │   │   )                                                                          │
│ ❱ 125 │   │   │   return self.method(cls, *args, **kwargs)                                       │
│   126 │   │                                                                                      │
│   127 │   │   return wrapper                                                                     │
│   128                                                                                            │
│                                                                                                  │
│ /home/joa125/u/miniconda3/envs/EV-test/lib/python3.10/site-packages/pytorch_lightning/core/modul │
│ e.py:1581 in load_from_checkpoint                                                                │
│                                                                                                  │
│   1578 │   │   │   y_hat = pretrained_model(x)                                                   │
│   1579 │   │                                                                                     │
│   1580 │   │   """                                                                               │
│ ❱ 1581 │   │   loaded = _load_from_checkpoint(                                                   │
│   1582 │   │   │   cls,  # type: ignore[arg-type]                                                │
│   1583 │   │   │   checkpoint_path,                                                              │
│   1584 │   │   │   map_location,                                                                 │
│                                                                                                  │
│ /home/joa125/u/miniconda3/envs/EV-test/lib/python3.10/site-packages/pytorch_lightning/core/savin │
│ g.py:91 in _load_from_checkpoint                                                                 │
│                                                                                                  │
│    88 │   if issubclass(cls, pl.LightningDataModule):                                            │
│    89 │   │   return _load_state(cls, checkpoint, **kwargs)                                      │
│    90 │   if issubclass(cls, pl.LightningModule):                                                │
│ ❱  91 │   │   model = _load_state(cls, checkpoint, strict=strict, **kwargs)                      │
│    92 │   │   state_dict = checkpoint["state_dict"]                                              │
│    93 │   │   if not state_dict:                                                                 │
│    94 │   │   │   rank_zero_warn(f"The state dict in {checkpoint_path!r} contains no parameter   │
│                                                                                                  │
│ /home/joa125/u/miniconda3/envs/EV-test/lib/python3.10/site-packages/pytorch_lightning/core/savin │
│ g.py:158 in _load_state                                                                          │
│                                                                                                  │
│   155 │   │   # filter kwargs according to class init unless it allows any argument via kwargs   │
│   156 │   │   _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}    │
│   157 │                                                                                          │
│ ❱ 158 │   obj = cls(**_cls_kwargs)                                                               │
│   159 │                                                                                          │
│   160 │   if isinstance(obj, pl.LightningDataModule):                                            │
│   161 │   │   if obj.__class__.__qualname__ in checkpoint:                                       │
│                                                                                                  │
│ /gpfs/fs3c/nrc/ict/portage/u/joa125/EveryVoice/EveryVoice/everyvoice/model/feature_prediction/Fa │
│ stSpeech2_lightning/fs2/model.py:40 in __init__                                                  │
│                                                                                                  │
│    37 │   │   """ """                                                                            │
│    38 │   │   super().__init__()                                                                 │
│    39 │   │   if not isinstance(config, FastSpeech2Config):                                      │
│ ❱  40 │   │   │   config = FastSpeech2Config(**config)                                           │
│    41 │   │   if stats is not None and not isinstance(stats, Stats):                             │
│    42 │   │   │   stats = Stats(**stats)                                                         │
│    43 │   │   self.config = config                                                               │
│                                                                                                  │
│ /gpfs/fs3c/nrc/ict/portage/u/joa125/EveryVoice/EveryVoice/everyvoice/config/shared_types.py:112  │
│ in __init__                                                                                      │
│                                                                                                  │
│   109 │                                                                                          │
│   110 │   # [Using validation context with BaseModel initialization](https://docs.pydantic.dev   │
│   111 │   def __init__(__pydantic_self__, **data: Any) -> None:                                  │
│ ❱ 112 │   │   __pydantic_self__.__pydantic_validator__.validate_python(                          │
│   113 │   │   │   data,                                                                          │
│   114 │   │   │   self_instance=__pydantic_self__,                                               │
│   115 │   │   │   context=_init_context_var.get(),                                               │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValidationError: 14 validation errors for FastSpeech2Config
model.resblock
  Extra inputs are not permitted [type=extra_forbidden, input_value='1', input_type=str]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
model.upsample_rates
  Extra inputs are not permitted [type=extra_forbidden, input_value=[8, 8], input_type=list]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
model.upsample_kernel_sizes
  Extra inputs are not permitted [type=extra_forbidden, input_value=[16, 16], input_type=list]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
model.upsample_initial_channel
  Extra inputs are not permitted [type=extra_forbidden, input_value=512, input_type=int]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
model.resblock_kernel_sizes
  Extra inputs are not permitted [type=extra_forbidden, input_value=[3, 7, 11], input_type=list]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
model.resblock_dilation_sizes
  Extra inputs are not permitted [type=extra_forbidden, input_value=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], input_type=list]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
model.activation_function
  Extra inputs are not permitted [type=extra_forbidden, input_value='everyvoice.utils.original_hifigan_leaky_relu', input_type=str]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
model.istft_layer
  Extra inputs are not permitted [type=extra_forbidden, input_value=True, input_type=bool]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
model.msd_layers
  Extra inputs are not permitted [type=extra_forbidden, input_value=3, input_type=int]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
model.mpd_layers
  Extra inputs are not permitted [type=extra_forbidden, input_value=[2, 3, 5, 7, 11], input_type=list]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
training.generator_warmup_steps
  Extra inputs are not permitted [type=extra_forbidden, input_value=0, input_type=int]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
training.gan_type
  Extra inputs are not permitted [type=extra_forbidden, input_value='original', input_type=str]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
training.wgan_clip_value
  Extra inputs are not permitted [type=extra_forbidden, input_value=0.01, input_type=float]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden
training.finetune
  Extra inputs are not permitted [type=extra_forbidden, input_value=False, input_type=bool]
    For further information visit https://errors.pydantic.dev/2.6/v/extra_forbidden

Another example command that triggers the same failure, this time passing a feature prediction checkpoint as --vocoder-path:

everyvoice synthesize from-text \
  logs_and_checkpoints/FeaturePredictionExperiment/base/checkpoints/last.ckpt \
  -t "this is a test" \
  --vocoder-path logs_and_checkpoints/FeaturePredictionExperiment/base/checkpoints/last.ckpt

A little friendlier messaging could help the user a lot.
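
For reference, the intended invocation would presumably pair the feature prediction checkpoint with a vocoder, along these lines (paths taken from the examples above):

everyvoice synthesize from-text \
  logs_and_checkpoints/FeaturePredictionExperiment/base/checkpoints/last.ckpt \
  -t "this is a test" \
  --vocoder-path logs_and_checkpoints/VocoderExperiment/base/checkpoints/last.ckpt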

SamuelLarkin commented 8 months ago

Relates to #114.

joanise commented 8 months ago

Sure enough, it's the same problem with a different kind of file (checkpoint vs config).

We should have some kind of version number or magic number identifying each type of file we generate/support, and a quick check of that before the Pydantic checking even starts.
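
Such a pre-check could look roughly like this (just a sketch; the "model_info" key is hypothetical, something we would have to start writing into checkpoints at save time):

import torch

def check_model_type(model_path, expected="FastSpeech2"):
    # Inspect the raw checkpoint dict before Lightning/Pydantic touch it.
    ckpt = torch.load(model_path, map_location="cpu")
    info = ckpt.get("model_info")  # hypothetical tag written at save time
    if info is None or info.get("name") != expected:
        raise TypeError(
            f"{model_path} does not look like a {expected} checkpoint; "
            "did you pass a vocoder checkpoint as the model checkpoint?"
        )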

roedoejet commented 2 months ago

> Sure enough, it's the same problem with a different kind of file (checkpoint vs config).
>
> We should have some kind of version number or magic number identifying each type of file we generate/support, and a quick check of that before the Pydantic checking even starts.

Good idea - bumping this up so that we solve it before we release the checkpoints.

SamuelLarkin commented 1 month ago

NOTES

First attempt: save the type and a version number when calling FastSpeech2.on_save_checkpoint(), then in FastSpeech2.on_load_checkpoint(checkpoint) make sure the version and the model's type match what we are expecting. There's a problem with that approach: PyTorch Lightning actually instantiates a FastSpeech2 object, which tries to load the wrong config type during __init__(), and then Pydantic raises an exception, obviously because a lot of fields are wrong. Looking at PyTorch Lightning's code, we see that it creates the class instance first and only calls on_load_checkpoint() later. In _load_state() (https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/core/saving.py#L117), the instance is created at https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/core/saving.py#L165:

obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs)

https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/core/saving.py#L184

obj.on_load_checkpoint(checkpoint)
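
For reference, the attempted hooks looked roughly like this (a minimal sketch; the "model_info" key and the version string are assumptions, not committed code):

import pytorch_lightning as pl

class FastSpeech2(pl.LightningModule):
    def on_save_checkpoint(self, checkpoint: dict) -> None:
        # Tag the checkpoint with its model type and a format version.
        checkpoint["model_info"] = {"name": type(self).__name__, "version": "1.0"}

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        # Too late: Lightning has already run __init__() with the checkpoint's
        # hyperparameters before this hook fires, so Pydantic fails first.
        info = checkpoint.get("model_info", {})
        if info.get("name") != type(self).__name__:
            raise TypeError(
                f"expected a {type(self).__name__} checkpoint, got {info.get('name')}"
            )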

Stack Trace

Traceback (most recent call last):
  File "/fs/hestia_Hnrc/ict/sam037/git/EveryVoice/everyvoice/tests/test_model.py", line 212, in test_wrong_model_type
    FastSpeech2.load_from_checkpoint(ckpt_fn)
  File "/home/sam037/.conda/envs/EveryVoice.sl/lib/python3.10/site-packages/pytorch_lightning/utilities/model_helpers.py", line 125, in wrapper
    return self.method(cls, *args, **kwargs)
  File "/home/sam037/.conda/envs/EveryVoice.sl/lib/python3.10/site-packages/pytorch_lightning/core/module.py", line 1582, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/home/sam037/.conda/envs/EveryVoice.sl/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 91, in _load_from_checkpoint
    model = _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/home/sam037/.conda/envs/EveryVoice.sl/lib/python3.10/site-packages/pytorch_lightning/core/saving.py", line 165, in _load_state
    obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs)
  File "/fs/hestia_Hnrc/ict/sam037/git/EveryVoice/everyvoice/model/feature_prediction/FastSpeech2_lightning/fs2/model.py", line 48, in __init__
    config = FastSpeech2Config(**config)
  File "/fs/hestia_Hnrc/ict/sam037/git/EveryVoice/everyvoice/config/shared_types.py", line 128, in __init__
    __pydantic_self__.__pydantic_validator__.validate_python(