Closed roedoejet closed 2 months ago
This is silly, but I tried to just move everything manually — it doesn't seem to work: training doesn't resume at the right epoch. There's probably something else that is mismatched, but I can't figure out what.
import torch
from tqdm import tqdm
import regex as re
# Load the freshly-initialized EveryVoice checkpoint that acts as the "shell"
# to receive the pretrained weights, plus the jik876 UNIVERSAL_V1 HiFi-GAN
# discriminator ("do_") and generator ("g_") checkpoints.
# NOTE(review): all paths are machine-specific and hard-coded.
ev_ckpt = torch.load('/home/aip000/sgile/data/CRK/ev-sue-config/logs_and_checkpoints/VocoderExperiment/base/checkpoints/epoch=0-step=176.ckpt')
hfg_ckpt_d = torch.load('/home/aip000/tts/models/HiFi-GAN-TTS.pretrained/UNIVERSAL_V1/do_02500000')
hfg_ckpt_g = torch.load('/home/aip000/tts/models/HiFi-GAN-TTS.pretrained/UNIVERSAL_V1/g_02500000')
# Set optimizer states
# The jik876 "do_" checkpoint holds BOTH optimizer states (optim_g and
# optim_d). Assumes EV's optimizer_states[0] is the generator optimizer and
# [1] the discriminator optimizer -- TODO confirm against the EV module.
ev_ckpt['optimizer_states'][0]['state'] = hfg_ckpt_d['optim_g']['state']
ev_ckpt['optimizer_states'][0]['param_groups'] = hfg_ckpt_d['optim_g']['param_groups']
ev_ckpt['optimizer_states'][1]['state'] = hfg_ckpt_d['optim_d']['state']
ev_ckpt['optimizer_states'][1]['param_groups'] = hfg_ckpt_d['optim_d']['param_groups']
# map generator/discriminator states
hfg_g_state_map = {k: f'generator.{k}' for k in hfg_ckpt_g['generator'].keys()}  # key is old key, value is corresponding key in EV
hfg_msd_state_map = {k: f'msd.{k}' for k in hfg_ckpt_d['msd'].keys()}  # key is old key, value is corresponding key in EV
hfg_mpd_state_map = {k: f'mpd.{k}' for k in hfg_ckpt_d['mpd'].keys()}  # key is old key, value is corresponding key in EV
# Key patterns whose weight-norm entries need renaming to the newer torch
# parametrization layout. msd.discriminators.0 is excluded ([1-9]) --
# presumably the first MSD scale uses spectral_norm, not weight_norm; verify.
msd_pattern = re.compile(r'msd\.discriminators\.[1-9]')
mpd_pattern = re.compile(r'mpd\.discriminators')
resblock_pattern = re.compile(r'generator\.resblocks')
upsample_pattern = re.compile(r'generator\.ups')
convpre_pattern = re.compile(r'generator\.conv_p')
def add_parametrizations(pieces):
    """Rewrite a split state-dict key to torch's parametrization layout.

    Splices 'parametrizations' and 'weight' in front of the final key
    component. Mutates *pieces* in place and returns the same list.
    """
    # Single slice assignment inserts both components before the last piece.
    pieces[-1:-1] = ['parametrizations', 'weight']
    return pieces
def fix_weight_keys(map, patterns):
    """Rename weight-norm keys among *map*'s values to the torch
    parametrization layout.

    Any value that matches one of the compiled *patterns* and ends in
    'weight_g' / 'weight_v' is rewritten to end in
    'parametrizations.weight.original0' / 'original1'.
    Mutates and returns *map* (old jik876 key -> EV key).
    """
    renames = {'weight_g': 'original0', 'weight_v': 'original1'}
    for old_key, ev_key in tqdm(map.items()):
        # Only values under the weight-normed modules need rewriting.
        if not any(pattern.match(ev_key) for pattern in patterns):
            continue
        *prefix, suffix = ev_key.split('.')
        if suffix not in renames:
            continue
        map[old_key] = '.'.join(
            prefix + ['parametrizations', 'weight', renames[suffix]]
        )
    return map
# Rewrite weight-norm (weight_g/weight_v) keys to the parametrization layout
# for the modules that need it, then sanity-check that the remapped jik876
# keys line up one-to-one with the EV checkpoint's keys.
hfg_g_state_map = fix_weight_keys(hfg_g_state_map, [resblock_pattern, upsample_pattern, convpre_pattern])
hfg_msd_state_map = fix_weight_keys(hfg_msd_state_map, [msd_pattern])
hfg_mpd_state_map = fix_weight_keys(hfg_mpd_state_map, [mpd_pattern])
gen_hfg_states = set(hfg_g_state_map.values())
msd_hfg_states = set(hfg_msd_state_map.values())
mpd_hfg_states = set(hfg_mpd_state_map.values())
gen_ev_states = set([x for x in ev_ckpt['state_dict'].keys() if x.startswith('generator')])
msd_ev_states = set([x for x in ev_ckpt['state_dict'].keys() if x.startswith('msd')])
mpd_ev_states = set([x for x in ev_ckpt['state_dict'].keys() if x.startswith('mpd')])
# Magic counts observed from these two specific checkpoints.
assert len(gen_ev_states) == len(gen_hfg_states) == 234
assert len(hfg_msd_state_map) == len(msd_ev_states) == 80
assert len(hfg_mpd_state_map) == len(mpd_ev_states) == 90
# Every EV key must have a remapped jik876 counterpart.
assert len(gen_ev_states - gen_hfg_states) == 0
assert len(msd_ev_states - msd_hfg_states) == 0
assert len(mpd_ev_states - mpd_hfg_states) == 0
# replace the weights
for k, v in hfg_g_state_map.items():
    ev_ckpt['state_dict'][v] = hfg_ckpt_g['generator'][k]
for k, v in hfg_msd_state_map.items():
    ev_ckpt['state_dict'][v] = hfg_ckpt_d['msd'][k]
for k, v in hfg_mpd_state_map.items():
    ev_ckpt['state_dict'][v] = hfg_ckpt_d['mpd'][k]
# callbacks - meh? these don't seem like we need to move them
# Set steps
# 2,500,000 jik876 batches x 2 optimizers -- presumably how Lightning counts
# global_step when training with two optimizers; TODO confirm.
ev_ckpt['global_step'] = 5000000
# Set epoch
ev_ckpt['epoch'] = 12400
# loops (from EV universal)
# Loop-progress counters copied verbatim from an EveryVoice universal
# checkpoint so the trainer resumes as if that schedule had already run.
ev_ckpt['loops'] = {'fit_loop': {'state_dict': {}, 'epoch_loop.state_dict': {'_batches_that_stepped': 2500000}, 'epoch_loop.batch_progress': {'total': {'ready': 2500000, 'completed': 2500000, 'started': 2500000, 'processed': 2500000}, 'current': {'ready': 12400, 'completed': 12400, 'started': 12400, 'processed': 12400}, 'is_last_batch': False}, 'epoch_loop.scheduler_progress': {'total': {'ready': 240, 'completed': 240}, 'current': {'ready': 0, 'completed': 0}}, 'epoch_loop.batch_loop.state_dict': {}, 'epoch_loop.batch_loop.optimizer_loop.state_dict': {}, 'epoch_loop.batch_loop.optimizer_loop.optim_progress': {'optimizer': {'step': {'total': {'ready': 5000000, 'completed': 5000000}, 'current': {'ready': 24800, 'completed': 24800}}, 'zero_grad': {'total': {'ready': 5000000, 'completed': 5000000, 'started': 5000000}, 'current': {'ready': 24800, 'completed': 24800, 'started': 24800}}}, 'optimizer_position': 2}, 'epoch_loop.batch_loop.manual_loop.state_dict': {}, 'epoch_loop.batch_loop.manual_loop.optim_step_progress': {'total': {'ready': 0, 'completed': 0}, 'current': {'ready': 0, 'completed': 0}}, 'epoch_loop.val_loop.state_dict': {}, 'epoch_loop.val_loop.dataloader_progress': {'total': {'ready': 120, 'completed': 120}, 'current': {'ready': 1, 'completed': 1}}, 'epoch_loop.val_loop.epoch_loop.state_dict': {}, 'epoch_loop.val_loop.epoch_loop.batch_progress': {'total': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'current': {'ready': 36856, 'completed': 36856, 'started': 36856, 'processed': 36856}, 'is_last_batch': True}, 'epoch_progress': {'total': {'ready': 121, 'completed': 120, 'started': 121, 'processed': 121}, 'current': {'ready': 121, 'completed': 120, 'started': 121, 'processed': 121}}}, 'validate_loop': {'state_dict': {}, 'dataloader_progress': {'total': {'ready': 0, 'completed': 0}, 'current': {'ready': 0, 'completed': 0}}, 'epoch_loop.state_dict': {}, 'epoch_loop.batch_progress': {'total': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'current': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'is_last_batch': False}}, 'test_loop': {'state_dict': {}, 'dataloader_progress': {'total': {'ready': 0, 'completed': 0}, 'current': {'ready': 0, 'completed': 0}}, 'epoch_loop.state_dict': {}, 'epoch_loop.batch_progress': {'total': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'current': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'is_last_batch': False}}, 'predict_loop': {'state_dict': {}, 'dataloader_progress': {'total': {'ready': 0, 'completed': 0}, 'current': {'ready': 0, 'completed': 0}}, 'epoch_loop.state_dict': {}, 'epoch_loop.batch_progress': {'total': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'current': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}}}}
# LR Schedulers
# Match the pretrained run's learning rate, then fast-forward the scheduler
# so it does not restart its decay from step 0.
ev_ckpt['hyper_parameters']['config']['training']['optimizer']['learning_rate'] = 0.0002
ev_ckpt['lr_schedulers'][0]['last_epoch'] = ev_ckpt['epoch']
ev_ckpt['lr_schedulers'][0]['_step_count'] = ev_ckpt['global_step']
# Write the merged checkpoint next to the working directory.
torch.save(ev_ckpt, 'ev_hfg.ckpt')
To keep a record of what was done.
Get the reference code: git clone https://github.com/jik876/hifi-gan.
The Generator was outputting different values, so we reproduced the forward() pass step by step, gradually increasing the number of upsamples & kernels to find where the implementations diverged.
We then decided to extend the work to check whether the discriminators were also correctly reimplemented. Unlike the Generator, the discriminators produced the same outputs as the reference code.
Here's the code that was used to check the Models. Note that the code is highly hacky and is only intended to serve the investigation and not to be beautiful and production ready.
#!/usr/bin/env python3
import json
import regex as re
import torch
from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.config import HiFiGANConfig
from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.model import HiFiGAN
from tqdm import tqdm
from hifi_gan.env import AttrDict
from hifi_gan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator
# from torch.nn.utils.parametrizations import weight_norm
# from torch.nn.utils.parametrize import remove_parametrizations
# Negative slope used by jik876's leaky_relu activations in the partial
# forward reproduction below.
LRELU_SLOPE = 0.1
# Output path for the converted EveryVoice checkpoint.
EV_CKPT = "ev_hfg.ckpt"
# Freshly-initialized EveryVoice checkpoint used as the receiving "shell".
# NOTE(review): all paths below are machine-specific and hard-coded.
EV_TEMPLATE_FILENAME = "/home/aip000/sgile/data/CRK/ev-sue-config/logs_and_checkpoints/VocoderExperiment/base/checkpoints/epoch=0-step=176.ckpt"
# jik876 UNIVERSAL_V1 pretrained generator ("g_") and discriminator/optimizer
# ("do_") checkpoints.
JIK876_GENERATOR_FILENAME = (
    "/home/aip000/tts/models/HiFi-GAN-TTS.pretrained/UNIVERSAL_V1/g_02500000"
)
JIK876_DISCRIMINATOR_FILENAME = (
    "/home/aip000/tts/models/HiFi-GAN-TTS.pretrained/UNIVERSAL_V1/do_02500000"
)
def add_parametrizations(pieces):
    """Rewrite a split state-dict key to torch's parametrization layout.

    Splices "parametrizations" and "weight" in front of the final key
    component. Mutates *pieces* in place and returns the same list.
    """
    # One slice assignment inserts both components before the last piece.
    pieces[-1:-1] = ["parametrizations", "weight"]
    return pieces
def fix_weight_keys(map, patterns):
    """Rename weight-norm keys among *map*'s values to the torch
    parametrization layout.

    Any value that matches one of the compiled *patterns* and ends in
    ``weight_g`` / ``weight_v`` is rewritten to end in
    ``parametrizations.weight.original0`` / ``original1``.
    Mutates and returns *map* (old jik876 key -> EV key).
    """
    renames = {"weight_g": "original0", "weight_v": "original1"}
    for old_key, ev_key in tqdm(map.items()):
        # Only values under the weight-normed modules need rewriting.
        if not any(pattern.match(ev_key) for pattern in patterns):
            continue
        *prefix, suffix = ev_key.split(".")
        if suffix not in renames:
            continue
        map[old_key] = ".".join(
            prefix + ["parametrizations", "weight", renames[suffix]]
        )
    return map
def fix_loop(ev_ckpt):
    """Overwrite the checkpoint's Lightning loop-progress state so a resumed
    run believes the universal HiFi-GAN schedule has already been trained.

    Counter values were copied from an EveryVoice universal checkpoint
    (2.5M batches stepped, 12400-epoch current counters). Mutates
    ``ev_ckpt['loops']`` in place.
    """

    def counts(n):
        # {ready, completed} progress pair with both counters at n.
        return {"ready": n, "completed": n}

    def full_counts(n):
        # {ready, completed, started, processed} progress record at n.
        return {"ready": n, "completed": n, "started": n, "processed": n}

    def idle_loop(track_last_batch):
        # A never-run loop: every counter is zero. predict_loop's batch
        # progress carries no is_last_batch flag, hence the switch.
        batch_progress = {"total": full_counts(0), "current": full_counts(0)}
        if track_last_batch:
            batch_progress["is_last_batch"] = False
        return {
            "state_dict": {},
            "dataloader_progress": {"total": counts(0), "current": counts(0)},
            "epoch_loop.state_dict": {},
            "epoch_loop.batch_progress": batch_progress,
        }

    # 121 epochs ready/started/processed, 120 completed (mid-epoch snapshot
    # layout copied from the source checkpoint).
    epoch_progress_counts = {
        "ready": 121,
        "completed": 120,
        "started": 121,
        "processed": 121,
    }
    fit_loop = {
        "state_dict": {},
        "epoch_loop.state_dict": {"_batches_that_stepped": 2500000},
        "epoch_loop.batch_progress": {
            "total": full_counts(2500000),
            "current": full_counts(12400),
            "is_last_batch": False,
        },
        "epoch_loop.scheduler_progress": {
            "total": counts(240),
            "current": counts(0),
        },
        "epoch_loop.batch_loop.state_dict": {},
        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
        "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
            "optimizer": {
                "step": {
                    "total": counts(5000000),
                    "current": counts(24800),
                },
                # zero_grad also tracks "started", unlike "step".
                "zero_grad": {
                    "total": {
                        "ready": 5000000,
                        "completed": 5000000,
                        "started": 5000000,
                    },
                    "current": {
                        "ready": 24800,
                        "completed": 24800,
                        "started": 24800,
                    },
                },
            },
            "optimizer_position": 2,
        },
        "epoch_loop.batch_loop.manual_loop.state_dict": {},
        "epoch_loop.batch_loop.manual_loop.optim_step_progress": {
            "total": counts(0),
            "current": counts(0),
        },
        "epoch_loop.val_loop.state_dict": {},
        "epoch_loop.val_loop.dataloader_progress": {
            "total": counts(120),
            "current": counts(1),
        },
        "epoch_loop.val_loop.epoch_loop.state_dict": {},
        "epoch_loop.val_loop.epoch_loop.batch_progress": {
            "total": full_counts(0),
            "current": full_counts(36856),
            "is_last_batch": True,
        },
        "epoch_progress": {
            "total": dict(epoch_progress_counts),
            "current": dict(epoch_progress_counts),
        },
    }
    ev_ckpt["loops"] = {
        "fit_loop": fit_loop,
        "validate_loop": idle_loop(track_last_batch=True),
        "test_loop": idle_loop(track_last_batch=True),
        "predict_loop": idle_loop(track_last_batch=False),
    }
def fix_non_state_dict(ev_ckpt):
    """Patch everything outside 'state_dict' so Lightning resumes as if the
    universal HiFi-GAN schedule had already been trained.

    Mutates *ev_ckpt* in place: loop progress, global step/epoch counters,
    the configured learning rate, and the LR scheduler position.
    """
    fix_loop(ev_ckpt)
    # Pretend 5M optimizer steps across 12400 epochs have already happened.
    ev_ckpt["global_step"] = 5000000
    ev_ckpt["epoch"] = 12400
    # Match the pretrained run's learning rate, then fast-forward the
    # scheduler so it does not restart its decay from step 0.
    optimizer_cfg = ev_ckpt["hyper_parameters"]["config"]["training"]["optimizer"]
    optimizer_cfg["learning_rate"] = 0.0002
    scheduler_state = ev_ckpt["lr_schedulers"][0]
    scheduler_state["last_epoch"] = ev_ckpt["epoch"]
    scheduler_state["_step_count"] = ev_ckpt["global_step"]
def convert():
    """
    Convert a jik876 Generator into a EveryVoice Generator.

    Loads the EveryVoice template checkpoint plus the jik876 UNIVERSAL_V1
    generator and discriminator checkpoints, remaps every jik876 state-dict
    key to its EveryVoice name (including weight-norm -> parametrization
    renames), transplants optimizer and trainer bookkeeping, and writes the
    merged checkpoint to EV_CKPT.
    """
    ev_ckpt = torch.load(
        EV_TEMPLATE_FILENAME,
        map_location=torch.device("cpu"),
    )
    hfg_ckpt_d = torch.load(
        JIK876_DISCRIMINATOR_FILENAME,
        map_location=torch.device("cpu"),
    )
    hfg_ckpt_g = torch.load(
        JIK876_GENERATOR_FILENAME,
        map_location=torch.device("cpu"),
    )
    # Debugging aid: list every key the EV template expects.
    state_dict_keys = set(ev_ckpt["state_dict"].keys())
    print(*sorted(state_dict_keys), sep="\n")
    # Set optimizer states
    # The jik876 "do_" checkpoint holds BOTH optimizer states. Assumes EV's
    # optimizer_states[0] is the generator optimizer and [1] the
    # discriminator optimizer -- TODO confirm against the EV module.
    ev_ckpt["optimizer_states"][0]["state"] = hfg_ckpt_d["optim_g"]["state"]
    ev_ckpt["optimizer_states"][0]["param_groups"] = hfg_ckpt_d["optim_g"][
        "param_groups"
    ]
    ev_ckpt["optimizer_states"][1]["state"] = hfg_ckpt_d["optim_d"]["state"]
    ev_ckpt["optimizer_states"][1]["param_groups"] = hfg_ckpt_d["optim_d"][
        "param_groups"
    ]
    # map generator/discriminator states
    hfg_g_state_map = {
        k: f"generator.{k}" for k in hfg_ckpt_g["generator"].keys()
    }  # key is old key, value is corresponding key in EV
    hfg_msd_state_map = {
        k: f"msd.{k}" for k in hfg_ckpt_d["msd"].keys()
    }  # key is old key, value is corresponding key in EV
    hfg_mpd_state_map = {
        k: f"mpd.{k}" for k in hfg_ckpt_d["mpd"].keys()
    }  # key is old key, value is corresponding key in EV
    # msd.discriminators.0 is deliberately excluded ([1-9]) -- presumably the
    # first MSD scale uses spectral_norm rather than weight_norm; verify.
    msd_pattern = re.compile(r"msd\.discriminators\.[1-9]")
    mpd_pattern = re.compile(r"mpd\.discriminators")
    resblock_pattern = re.compile(r"generator\.resblocks")
    upsample_pattern = re.compile(r"generator\.ups")
    convpre_pattern = re.compile(r"generator\.conv_p")
    hfg_g_state_map = fix_weight_keys(
        hfg_g_state_map, [resblock_pattern, upsample_pattern, convpre_pattern]
    )
    hfg_msd_state_map = fix_weight_keys(hfg_msd_state_map, [msd_pattern])
    hfg_mpd_state_map = fix_weight_keys(hfg_mpd_state_map, [mpd_pattern])
    gen_hfg_states = set(hfg_g_state_map.values())
    msd_hfg_states = set(hfg_msd_state_map.values())
    mpd_hfg_states = set(hfg_mpd_state_map.values())
    gen_ev_states = set(
        [x for x in ev_ckpt["state_dict"].keys() if x.startswith("generator")]
    )
    msd_ev_states = set(
        [x for x in ev_ckpt["state_dict"].keys() if x.startswith("msd")]
    )
    mpd_ev_states = set(
        [x for x in ev_ckpt["state_dict"].keys() if x.startswith("mpd")]
    )
    # Magic counts observed from these two specific checkpoints; the remapped
    # jik876 key sets must also cover the EV key sets exactly.
    assert len(gen_ev_states) == len(gen_hfg_states) == 234
    assert len(hfg_msd_state_map) == len(msd_ev_states) == 80
    assert len(hfg_mpd_state_map) == len(mpd_ev_states) == 90
    assert len(gen_ev_states - gen_hfg_states) == 0
    assert len(msd_ev_states - msd_hfg_states) == 0
    assert len(mpd_ev_states - mpd_hfg_states) == 0
    # replace the weights
    for k, v in hfg_g_state_map.items():
        ev_ckpt["state_dict"][v] = hfg_ckpt_g["generator"][k]
        # print(ev_ckpt["state_dict"][v] is hfg_ckpt_g["generator"][k])
    for k, v in hfg_msd_state_map.items():
        ev_ckpt["state_dict"][v] = hfg_ckpt_d["msd"][k]
    for k, v in hfg_mpd_state_map.items():
        ev_ckpt["state_dict"][v] = hfg_ckpt_d["mpd"][k]
    # EV keys that did NOT receive a pretrained weight (debugging aid).
    state_dict_keys -= gen_hfg_states | msd_hfg_states | mpd_hfg_states
    # print(state_dict_keys)
    # callbacks - meh? these don't seem like we need to move them
    fix_non_state_dict(ev_ckpt)
    # print(ev_ckpt.keys())
    # print(json.dumps(ev_ckpt["hyper_parameters"], indent=2, ensure_ascii=False))
    torch.save(ev_ckpt, EV_CKPT)
def test_generator():
    """
    Let figure out where our Generator code differs from jik876.

    Loads the converted EV checkpoint and the jik876 reference Generator,
    feeds both the same random mel input, and asserts their outputs are
    identical. The `if False:` sections are debug toggles kept from the
    investigation; flip them on to compare intermediate stages.
    """
    ev_ckpt = torch.load(EV_CKPT)
    # The stored config may be a plain dict or an already-built model config.
    config: dict | HiFiGANConfig = ev_ckpt["hyper_parameters"]["config"]
    if isinstance(config, dict):
        config = HiFiGANConfig(**config)
    ev_generator = HiFiGAN(config)
    ev_generator.load_state_dict(ev_ckpt["state_dict"])
    # Fold weight norm into plain weights for inference, as jik876 does.
    ev_generator.generator.remove_weight_norm()
    # model.mpd.remove_weight_norm()
    # model.msd.remove_weight_norm()
    # remove_parametrizations(sub_layer, "weight")
    # print(ev_generator)
    # Reference hyperparameters from the cloned jik876 repo.
    with open("hifi_gan/config_v1.json", mode="r") as cin:
        hyperparameters = AttrDict(json.load(cin))
    jik876 = Generator(hyperparameters)
    jik876.load_state_dict(
        torch.load(
            JIK876_GENERATOR_FILENAME,
            map_location=torch.device("cpu"),
        )["generator"]
    )
    # print(jik876)
    # Random mel-style input; presumably (batch, n_mels, frames) -- confirm.
    input_ = torch.rand((2, 80, 120))
    # Slice used to print only the first few values of long tensors.
    first = slice(None, 10)
    if False:
        # Debug toggle: compare structural hyperparameters.
        print("num_kernels, num_upsamples")
        print(ev_generator.generator.num_kernels, ev_generator.generator.num_upsamples)
        print(jik876.num_kernels, jik876.num_upsamples)
    if False:
        # Debug toggle: confirm EV's configured activation equals leaky_relu.
        import torch.nn.functional as F

        print("activation_function, leaky_relu")
        print(config.model.activation_function)
        a = ev_generator.config.model.activation_function(input_)
        print(a)
        b = F.leaky_relu(input_)
        print(b)
        assert torch.all(torch.eq(a, b))
    if False:
        # Debug toggle: compare the pre-convolution biases.
        print("conv_pre.bias")
        print(ev_generator.generator.conv_pre.bias[first])
        print(jik876.conv_pre.bias[first])
    if False:
        # Debug toggle: compare the first upsampling layer's biases.
        print("ups[0].bias")
        print(ev_generator.generator.ups[0].bias[first])
        print(jik876.ups[0].bias[first])
        assert torch.all(
            torch.eq(ev_generator.generator.ups[0].bias, jik876.ups[0].bias)
        )
    if False:
        # Debug toggle: compare conv_pre outputs (with layers swapped in).
        print("conv_pre()")
        ev_generator.generator.conv_pre = jik876.conv_pre
        a = ev_generator.generator.conv_pre(input_)
        print(a, a.shape)
        a = jik876.conv_pre(input_)
        print(a, a.shape)
    if False:
        # Debug toggle: re-implement both forward() passes stage by stage to
        # localize the first diverging upsample/resblock stage.
        import torch.nn.functional as F

        num_upsamples = ev_generator.generator.num_upsamples
        num_kernels = ev_generator.generator.num_kernels
        print("Partial")
        a = ev_generator.generator.conv_pre(input_)
        for i in range(num_upsamples):
            a = ev_generator.generator.config.model.activation_function(a)
            a = ev_generator.generator.ups[i](a)
            xs = None
            for j in range(num_kernels):
                if xs is None:
                    xs = ev_generator.generator.resblocks[i * num_kernels + j](a)
                else:
                    xs += ev_generator.generator.resblocks[i * num_kernels + j](a)
                # print(j, xs[:, :2, :4])
            a = xs / num_kernels
            print(i, a[:, :2, :4])
        b = jik876.conv_pre(input_)
        for i in range(num_upsamples):
            b = F.leaky_relu(b, LRELU_SLOPE)
            b = jik876.ups[i](b)
            xs = None
            for j in range(num_kernels):
                if xs is None:
                    xs = jik876.resblocks[i * num_kernels + j](b)
                else:
                    xs += jik876.resblocks[i * num_kernels + j](b)
                # print(j, xs[:, :2, :4])
            b = xs / num_kernels
            print(i, b[:, :2, :4])
        # print(a)
        # print(b)
    if True:
        # Full end-to-end comparison: the two generators must agree exactly.
        print("Full run")
        a = ev_generator.generator(input_)
        print(a[:, :2, :4], a.shape)
        b = jik876(input_)
        print(b[:, :2, :4], b.shape)
        print(torch.eq(a, b)[:, :2, :4])
        print((a - b)[:, :2, :4])
        assert torch.allclose(a, b)
        assert torch.all(torch.eq(a, b))
def test_discriminator():
    """Check that EV's MPD/MSD discriminators match the jik876 reference.

    Loads the converted EV checkpoint and the jik876 reference
    discriminators, feeds both the same random waveform pair, and asserts
    every returned tensor (scores and feature maps) is identical.
    """
    # Random waveform-shaped inputs; presumably (batch, 1, samples).
    y = torch.rand((2, 1, 8192))
    y_hat = torch.rand((2, 1, 8192))
    ev_ckpt = torch.load(EV_CKPT)
    # The stored config may be a plain dict or an already-built model config.
    config: dict | HiFiGANConfig = ev_ckpt["hyper_parameters"]["config"]
    if isinstance(config, dict):
        config = HiFiGANConfig(**config)
    ev = HiFiGAN(config)
    ev.load_state_dict(ev_ckpt["state_dict"])
    ev.generator.remove_weight_norm()
    hfg_ckpt_d = torch.load(
        JIK876_DISCRIMINATOR_FILENAME,
        map_location=torch.device("cpu"),
    )
    mpd = MultiPeriodDiscriminator()
    mpd.load_state_dict(hfg_ckpt_d["mpd"])
    msd = MultiScaleDiscriminator()
    msd.load_state_dict(hfg_ckpt_d["msd"])

    def equal_helper(a, b):
        """Recursively walk list then check tensor equality."""
        if isinstance(a, (list, tuple)):
            # i is unused; enumerate kept from the original debugging code.
            for i, (c, d) in enumerate(zip(a, b)):
                equal_helper(c, d)
        else:
            assert torch.all(torch.eq(a, b))

    # Discriminator outputs are nested lists of tensors; compare leaf-wise.
    equal_helper(ev.mpd(y, y_hat), mpd(y, y_hat))
    equal_helper(ev.msd(y, y_hat), msd(y, y_hat))
# Investigation entry point: rebuild the converted checkpoint, then verify
# generator and discriminator parity against the reference implementation.
if __name__ == "__main__":
    convert()
    test_generator()
    test_discriminator()
The Universal HiFiGAN checkpoint is very good, we could possibly just move the weights from it over to a shell EV model.