jaywalnut310 / glow-tts

A Generative Flow for Text-to-Speech via Monotonic Alignment Search
MIT License
667 stars 150 forks source link

GPU required or CPU-compatible? #57

Open lkurlandski opened 3 years ago

lkurlandski commented 3 years ago

I do not have a GPU. I would like to use the pretrained model only. Is there a way to do this on CPU?

I'm somewhat unfamiliar with deep learning packages. I see repeated calls and references to CUDA, so I assumed that this can only be used on a machine with a GPU.

JunjieLl commented 2 years ago

Of course. Here is the modified train.py code; you also need to set "fp16_run" to false in the config.

import os
import json
import argparse
import math
import torch
from torch import device, nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
# import torch.distributed as dist
# from apex.parallel import DistributedDataParallel as DDP
# from apex import amp

from data_utils import TextMelLoader, TextMelCollate
import models
import commons
import utils
from text.symbols import symbols

# Count of optimizer steps taken so far. train() increments it once per batch;
# train_and_eval() resets it to 0 or recomputes it when resuming from a checkpoint.
global_step = 0

# Device that the model and every batch are moved to. NOTE: this rebinds the name
# `device` imported from torch above, so from here on it is this device instance.
device = torch.device("cpu")

def main():
  """Load the hyperparameters and run training/evaluation in this process.

  No CUDA availability check is made: the module-level ``device`` is fixed to
  the CPU, so requiring a GPU here would abort the script on exactly the
  machines it is meant to run on.
  """
  # On a machine without GPUs this is simply 0; the value is only forwarded.
  n_gpus = torch.cuda.device_count()
  # Rendezvous settings for torch.distributed. Its initialisation is commented
  # out in this file, so these are kept only to make re-enabling it painless.
  os.environ['MASTER_ADDR'] = 'localhost'
  os.environ['MASTER_PORT'] = '50000'

  hps = utils.get_hparams()
  train_and_eval(n_gpus, hps)
  # mp.spawn(train_and_eval, nprocs=n_gpus, args=(n_gpus, hps,))

def train_and_eval(n_gpus, hps):
  """Build loaders, model and optimizer, then alternate training and evaluation.

  Args:
    n_gpus: GPU count passed by the caller. Unused in this single-process code
      path (it was only needed by the commented-out distributed setup).
    hps: Hyperparameter object; reads ``hps.model_dir``, ``hps.data``,
      ``hps.model`` and ``hps.train``.

  Raises:
    RuntimeError: if ``hps.train.fp16_run`` is enabled while ``amp`` is not
      importable in this module (its import is commented out at the top).
  """
  global global_step

  # Fail early with an actionable message instead of a NameError deep inside
  # the setup: the fp16 path needs apex's `amp`, whose import is commented out.
  if hps.train.fp16_run and "amp" not in globals():
    raise RuntimeError(
        'fp16_run requires apex amp, which is not imported; set "fp16_run" to false in the config.')

  # if rank == 0:
  logger = utils.get_logger(hps.model_dir)
  logger.info(hps)
  utils.check_git_hash(hps.model_dir)
  writer = SummaryWriter(log_dir=hps.model_dir)
  writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))

  # dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank)
  torch.manual_seed(hps.train.seed)
  # torch.cuda.set_device(rank)

  # Pinned host memory only helps host-to-GPU copies; skip it on other devices.
  pin_memory = device.type == "cuda"

  train_dataset = TextMelLoader(hps.data.training_files, hps.data)
  # train_sampler = torch.utils.data.distributed.DistributedSampler(
  #     train_dataset,
  #     num_replicas=n_gpus,
  #     rank=rank,
  #     shuffle=True)
  collate_fn = TextMelCollate(1)
  # The distributed sampler above used to do the shuffling. With it disabled,
  # the loader has to shuffle itself or every epoch sees the same batch order.
  train_loader = DataLoader(train_dataset, num_workers=8, shuffle=True,
      batch_size=hps.train.batch_size, pin_memory=pin_memory,
      drop_last=True, collate_fn=collate_fn)
  # if rank == 0:
  val_dataset = TextMelLoader(hps.data.validation_files, hps.data)
  val_loader = DataLoader(val_dataset, num_workers=8, shuffle=False,
        batch_size=hps.train.batch_size, pin_memory=pin_memory,
        drop_last=True, collate_fn=collate_fn)

  generator = models.FlowGenerator(
      n_vocab=len(symbols) + getattr(hps.data, "add_blank", False),
      out_channels=hps.data.n_mel_channels,
      **hps.model).to(device)
  optimizer_g = commons.Adam(generator.parameters(), scheduler=hps.train.scheduler, dim_model=hps.model.hidden_channels, warmup_steps=hps.train.warmup_steps, lr=hps.train.learning_rate, betas=hps.train.betas, eps=hps.train.eps)
  if hps.train.fp16_run:
    generator, optimizer_g._optim = amp.initialize(generator, optimizer_g._optim, opt_level="O1")
  # generator = DDP(generator)
  epoch_str = 1
  global_step = 0
  try:
    # Resume from the newest generator checkpoint, if there is one.
    _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), generator, optimizer_g)
    epoch_str += 1
    optimizer_g.step_num = (epoch_str - 1) * len(train_loader)
    optimizer_g._update_learning_rate()
    global_step = (epoch_str - 1) * len(train_loader)
  except Exception as err:
    # Resuming stays best-effort, but catching Exception rather than using a
    # bare except lets KeyboardInterrupt and SystemExit propagate.
    logger.info("could not resume from a checkpoint ({}: {})".format(type(err).__name__, err))
    if hps.train.ddi and os.path.isfile(os.path.join(hps.model_dir, "ddi_G.pth")):
      logger.info("load ddi")
      _ = utils.load_checkpoint(os.path.join(hps.model_dir, "ddi_G.pth"), generator, optimizer_g)

  for epoch in range(epoch_str, hps.train.epochs + 1):
    # if rank==0:
    train(epoch, hps, generator, optimizer_g, train_loader, logger, writer)
    evaluate(epoch, hps, generator, optimizer_g, val_loader, logger, writer_eval)
    # A checkpoint is written only on every 1000th epoch.
    if epoch%1000==0:
      utils.save_checkpoint(generator, optimizer_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(epoch)))
    # else:
    #   train(rank, epoch, hps, generator, optimizer_g, train_loader, None, None)

def train(epoch, hps, generator, optimizer_g, train_loader, logger, writer):
  """Run one training epoch over ``train_loader``.

  For every batch: forward pass, MLE + duration loss, backward pass, gradient
  value clipping at 5 and an optimizer step. Every ``hps.train.log_interval``
  batches the losses are logged and scalars/images are sent to ``writer``.
  The module-level ``global_step`` is incremented once per batch.

  Args:
    epoch: Epoch number, used for log messages only.
    hps: Hyperparameters; reads ``hps.train.fp16_run`` and ``hps.train.log_interval``.
    generator: Model being trained.
    optimizer_g: Optimizer wrapper exposing ``zero_grad``, ``step``, ``get_lr`` and ``_optim``.
    train_loader: Iterable of ``(x, x_lengths, y, y_lengths)`` batches.
    logger: Logger used for progress messages.
    writer: Summary writer passed to ``utils.summarize``.
  """
  # train_loader.sampler.set_epoch(epoch)
  global global_step

  generator.train()
  for batch_idx, (x, x_lengths, y, y_lengths) in enumerate(train_loader):
    x, x_lengths = x.to(device), x_lengths.to(device)
    y, y_lengths = y.to(device), y_lengths.to(device)

    # Train Generator
    optimizer_g.zero_grad()

    (z, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_) = generator(x, x_lengths, y, y_lengths, gen=False)
    l_mle = commons.mle_loss(z, z_m, z_logs, logdet, z_mask)
    l_length = commons.duration_loss(logw, logw_, x_lengths)

    loss_gs = [l_mle, l_length]
    loss_g = sum(loss_gs)

    if hps.train.fp16_run:
      # Requires `amp` to be in scope; its import is commented out at the top of this file.
      with amp.scale_loss(loss_g, optimizer_g._optim) as scaled_loss:
        scaled_loss.backward()
      grad_norm = commons.clip_grad_value_(amp.master_params(optimizer_g._optim), 5)
    else:
      loss_g.backward()
      grad_norm = commons.clip_grad_value_(generator.parameters(), 5)
    optimizer_g.step()

    # if rank==0:
    if batch_idx % hps.train.log_interval == 0:
      # The preview sample is only plotted, never back-propagated, so skip
      # building an autograd graph for it.
      with torch.no_grad():
        (y_gen, *_), *_ = generator(x[:1], x_lengths[:1], gen=True)
      logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(x), len(train_loader.dataset),
        100. * batch_idx / len(train_loader),
        loss_g.item()))
      logger.info([loss.item() for loss in loss_gs] + [global_step, optimizer_g.get_lr()])

      scalar_dict = {"loss/g/total": loss_g, "learning_rate": optimizer_g.get_lr(), "grad_norm": grad_norm}
      scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(loss_gs)})
      utils.summarize(
        writer=writer,
        global_step=global_step,
        images={"y_org": utils.plot_spectrogram_to_numpy(y[0].data.cpu().numpy()),
          "y_gen": utils.plot_spectrogram_to_numpy(y_gen[0].data.cpu().numpy()),
          "attn": utils.plot_alignment_to_numpy(attn[0,0].data.cpu().numpy()),
          },
        scalars=scalar_dict)
    global_step += 1

  # if rank == 0:
  logger.info('====> Epoch: {}'.format(epoch))

def evaluate(epoch, hps, generator, optimizer_g, val_loader, logger, writer_eval):
  """Compute the mean validation losses and write them to ``writer_eval``.

  Runs the generator in eval mode without gradients over ``val_loader``,
  averages the MLE and duration losses over the number of batches and reports
  them at the current module-level ``global_step``.

  Args:
    epoch: Epoch number, used for log messages only.
    hps: Hyperparameters; reads ``hps.train.log_interval``.
    generator: Model being evaluated (left in eval mode on return).
    optimizer_g: Unused; kept for signature compatibility with callers.
    val_loader: Iterable of ``(x, x_lengths, y, y_lengths)`` batches.
    logger: Logger used for progress messages.
    writer_eval: Summary writer passed to ``utils.summarize``.
  """
  # if rank == 0:
  global global_step
  generator.eval()
  losses_tot = []
  with torch.no_grad():
    for batch_idx, (x, x_lengths, y, y_lengths) in enumerate(val_loader):
      x, x_lengths = x.to(device), x_lengths.to(device)
      y, y_lengths = y.to(device), y_lengths.to(device)

      (z, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_) = generator(x, x_lengths, y, y_lengths, gen=False)
      l_mle = commons.mle_loss(z, z_m, z_logs, logdet, z_mask)
      l_length = commons.duration_loss(logw, logw_, x_lengths)

      loss_gs = [l_mle, l_length]
      loss_g = sum(loss_gs)

      # Keep a running per-loss sum across batches.
      if batch_idx == 0:
        losses_tot = loss_gs
      else:
        losses_tot = [total + current for (total, current) in zip(losses_tot, loss_gs)]

      if batch_idx % hps.train.log_interval == 0:
        logger.info('Eval Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
          epoch, batch_idx * len(x), len(val_loader.dataset),
          100. * batch_idx / len(val_loader),
          loss_g.item()))
        logger.info([loss.item() for loss in loss_gs])

    if not losses_tot:
      # The loader produced no batches (e.g. fewer validation items than one
      # batch while incomplete batches are dropped). Reporting a total loss of
      # 0 would be misleading, so skip the summary for this epoch.
      logger.info('Eval Epoch: {} produced no validation batches; skipping summary'.format(epoch))
    else:
      losses_tot = [total / len(val_loader) for total in losses_tot]
      loss_tot = sum(losses_tot)
      scalar_dict = {"loss/g/total": loss_tot}
      scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_tot)})
      utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        scalars=scalar_dict)
    logger.info('====> Epoch: {}'.format(epoch))

# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
  main()