jaywalnut310 / glow-tts

A Generative Flow for Text-to-Speech via Monotonic Alignment Search
MIT License

Size of glowtts models #41

Closed Zarbuvit closed 4 years ago

Zarbuvit commented 4 years ago

Hi all,
I am wondering about the sizes of other people's saved models and about ways to reduce mine.
My saved models are .pth files of around 328MB. I have looked at the models from MozillaTTS (based on this repo, as I understand it), which are .pth.tar files of 288MB. @echelon has also shown that he saves his models as .torchjit and they are 110MB.

I was wondering whether I am doing something wrong that leads to such large model sizes, or whether this is normal. Ultimately I would like to make my models smaller and wanted to see if anyone had any ideas.

I'm looking into PyTorch dynamic quantization (https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html), but I assume there is a reason why no one here has mentioned it and why @jaywalnut310 doesn't use it in the first place.
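
For reference, this is roughly what I had in mind (a minimal sketch; it only quantizes nn.Linear layers, and I don't know whether it plays well with this model):

import torch

# Minimal sketch: dynamically quantize the Linear layers of an already-built model.
# 'model' is assumed to be a constructed glow-tts generator with trained weights loaded.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8)
torch.save(quantized_model.state_dict(), 'quantized_model.pth')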

Thanks for your time and help

JosefJoubert commented 4 years ago

Hi

The reason for the size difference is much simpler. The model checkpoints contain additional data that is only needed for training (the optimizer state, iteration count, and learning rate). If you're saving the model to be used for inference, you don't need this data anymore. You can extract just the part you need from the checkpoint like this:

import torch

large_file = torch.load('330_MB_file.pth')
large_file.keys()
# dict_keys(['model', 'iteration', 'optimizer', 'learning_rate'])
smaller_file = {'model': large_file['model']}
torch.save(smaller_file, '110_MB_file.pth')

Or simply save the state_dict of the model to a separate file.
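
That is, something along these lines (a minimal sketch; the file name is a placeholder):

# Save only the weights.
torch.save(model.state_dict(), 'model_state.pth')

# Later, for inference: construct the model, then restore the weights.
model.load_state_dict(torch.load('model_state.pth', map_location='cpu'))
model.eval()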

Zarbuvit commented 4 years ago

Thank you so much @JosefJoubert! I ended up adding a function to utils for saving the final model:

def save_final_model(model, final_model_path):
    # Save only the model weights (no optimizer state), unwrapping DataParallel if needed.
    if hasattr(model, 'module'):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save({'model': state_dict}, final_model_path)
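
As a rough illustration, the call at the end of training could look something like this (a sketch; the epoch and hps names are my assumptions, not necessarily the exact variables in train.py):

# Hypothetical call site in train.py: 'epoch' and 'hps' are assumed names.
if epoch == hps.train.epochs:
    save_final_model(model, os.path.join(hps.model_dir, 'final_model.pth'))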

I call it from train.py once the current epoch reaches the number of epochs set in the config file. For inference to work I added the function:

def load_final_model(final_model_path, model):
  assert os.path.isfile(final_model_path)
  final_model_dict = torch.load(final_model_path, map_location='cpu')
  saved_state_dict = final_model_dict['model']
  # Unwrap DataParallel if needed so the keys match the saved state_dict.
  if hasattr(model, 'module'):
    state_dict = model.module.state_dict()
  else:
    state_dict = model.state_dict()
  # Take each parameter from the checkpoint, falling back to the model's
  # current value if it is missing.
  new_state_dict = {}
  for k, v in state_dict.items():
    try:
      new_state_dict[k] = saved_state_dict[k]
    except KeyError:
      logger.info("%s is not in the checkpoint" % k)
      new_state_dict[k] = v
  if hasattr(model, 'module'):
    model.module.load_state_dict(new_state_dict)
  else:
    model.load_state_dict(new_state_dict)
  logger.info("Loaded final model '{}'".format(final_model_path))
  return model

and called it when loading the model for inference. I realize there may have been a better way to do this by changing load_checkpoint, but this was easier for me to write and understand.
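
Roughly, the inference side looks like this (a sketch; build_model is a placeholder for however the generator is constructed from the hparams, and the file name is made up):

# Hypothetical inference-time usage: 'build_model' stands in for however you
# construct the glow-tts generator from the hparams.
model = build_model(hps)
model = load_final_model('final_model.pth', model)
model.eval()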

This reduced my model sizes from ~330MB to ~110MB

Thank you again so much for your help!