EveryVoiceTTS / EveryVoice

The EveryVoice TTS Toolkit - Text To Speech for your language
https://docs.everyvoice.ca

HiFiGAN UNIVERSAL checkpoint in EV #418

Closed - roedoejet closed this issue 2 months ago

roedoejet commented 2 months ago

The Universal HiFiGAN checkpoint is very good; we could possibly just move its weights over into a shell EV model.

roedoejet commented 2 months ago

This is silly, but I tried to just move everything over manually... it doesn't seem to work: training doesn't resume at the right epoch. There's probably something else that is mismatched, but I can't figure out what.

import torch
from tqdm import tqdm
import regex as re

ev_ckpt = torch.load('/home/aip000/sgile/data/CRK/ev-sue-config/logs_and_checkpoints/VocoderExperiment/base/checkpoints/epoch=0-step=176.ckpt')
hfg_ckpt_d = torch.load('/home/aip000/tts/models/HiFi-GAN-TTS.pretrained/UNIVERSAL_V1/do_02500000')
hfg_ckpt_g = torch.load('/home/aip000/tts/models/HiFi-GAN-TTS.pretrained/UNIVERSAL_V1/g_02500000')

# Set optimizer states
ev_ckpt['optimizer_states'][0]['state'] = hfg_ckpt_d['optim_g']['state']
ev_ckpt['optimizer_states'][0]['param_groups'] = hfg_ckpt_d['optim_g']['param_groups']
ev_ckpt['optimizer_states'][1]['state'] = hfg_ckpt_d['optim_d']['state']
ev_ckpt['optimizer_states'][1]['param_groups'] = hfg_ckpt_d['optim_d']['param_groups']

# map generator/discriminator states
hfg_g_state_map = {k: f'generator.{k}' for k in hfg_ckpt_g['generator'].keys()} # key is old key, value is corresponding key in EV
hfg_msd_state_map = {k: f'msd.{k}' for k in hfg_ckpt_d['msd'].keys()} # key is old key, value is corresponding key in EV
hfg_mpd_state_map = {k: f'mpd.{k}' for k in hfg_ckpt_d['mpd'].keys()} # key is old key, value is corresponding key in EV
msd_pattern = re.compile(r'msd\.discriminators\.[1-9]')
mpd_pattern = re.compile(r'mpd\.discriminators')
resblock_pattern = re.compile(r'generator\.resblocks')
upsample_pattern = re.compile(r'generator\.ups')
convpre_pattern = re.compile(r'generator\.conv_p')

def add_parametrizations(pieces):
    # add parametrizations
    pieces.insert(-1, 'parametrizations')
    # add weight
    pieces.insert(-1, 'weight')
    return pieces

def fix_weight_keys(map, patterns):
    for k, v in tqdm(map.items()):
        matches = [re.match(pattern, v) for pattern in patterns]
        if any(matches):
            pieces = v.split('.')
            # convert weight_g/weight_v to original0/original1
            if pieces[-1] == 'weight_g':
                pieces[-1] = 'original0'
                pieces = add_parametrizations(pieces)
            elif pieces[-1] == 'weight_v':
                pieces[-1] = 'original1'
                pieces = add_parametrizations(pieces)
            else:
                continue
            v = '.'.join(pieces)
        map[k] = v
    return map

hfg_g_state_map = fix_weight_keys(hfg_g_state_map, [resblock_pattern, upsample_pattern, convpre_pattern])
hfg_msd_state_map = fix_weight_keys(hfg_msd_state_map, [msd_pattern])
hfg_mpd_state_map = fix_weight_keys(hfg_mpd_state_map, [mpd_pattern])

gen_hfg_states = set(hfg_g_state_map.values())
msd_hfg_states = set(hfg_msd_state_map.values())
mpd_hfg_states = set(hfg_mpd_state_map.values())

gen_ev_states = set([x for x in ev_ckpt['state_dict'].keys() if x.startswith('generator')])
msd_ev_states = set([x for x in ev_ckpt['state_dict'].keys() if x.startswith('msd')])
mpd_ev_states = set([x for x in ev_ckpt['state_dict'].keys() if x.startswith('mpd')])

assert len(gen_ev_states) == len(gen_hfg_states) == 234
assert len(hfg_msd_state_map) == len(msd_ev_states) == 80
assert len(hfg_mpd_state_map) == len(mpd_ev_states) == 90

assert len(gen_ev_states - gen_hfg_states) == 0
assert len(msd_ev_states - msd_hfg_states) == 0
assert len(mpd_ev_states - mpd_hfg_states) == 0

# replace the weights
for k, v in hfg_g_state_map.items():
    ev_ckpt['state_dict'][v] = hfg_ckpt_g['generator'][k]

for k, v in hfg_msd_state_map.items():
    ev_ckpt['state_dict'][v] = hfg_ckpt_d['msd'][k]

for k, v in hfg_mpd_state_map.items():
    ev_ckpt['state_dict'][v] = hfg_ckpt_d['mpd'][k]

# callbacks - meh? these don't seem like they need to be moved

# Set steps
ev_ckpt['global_step'] = 5000000

# Set epoch
ev_ckpt['epoch'] = 12400

# loops (from EV universal)
ev_ckpt['loops'] = {'fit_loop': {'state_dict': {}, 'epoch_loop.state_dict': {'_batches_that_stepped': 2500000}, 'epoch_loop.batch_progress': {'total': {'ready': 2500000, 'completed': 2500000, 'started': 2500000, 'processed': 2500000}, 'current': {'ready': 12400, 'completed': 12400, 'started': 12400, 'processed': 12400}, 'is_last_batch': False}, 'epoch_loop.scheduler_progress': {'total': {'ready': 240, 'completed': 240}, 'current': {'ready': 0, 'completed': 0}}, 'epoch_loop.batch_loop.state_dict': {}, 'epoch_loop.batch_loop.optimizer_loop.state_dict': {}, 'epoch_loop.batch_loop.optimizer_loop.optim_progress': {'optimizer': {'step': {'total': {'ready': 5000000, 'completed': 5000000}, 'current': {'ready': 24800, 'completed': 24800}}, 'zero_grad': {'total': {'ready': 5000000, 'completed': 5000000, 'started': 5000000}, 'current': {'ready': 24800, 'completed': 24800, 'started': 24800}}}, 'optimizer_position': 2}, 'epoch_loop.batch_loop.manual_loop.state_dict': {}, 'epoch_loop.batch_loop.manual_loop.optim_step_progress': {'total': {'ready': 0, 'completed': 0}, 'current': {'ready': 0, 'completed': 0}}, 'epoch_loop.val_loop.state_dict': {}, 'epoch_loop.val_loop.dataloader_progress': {'total': {'ready': 120, 'completed': 120}, 'current': {'ready': 1, 'completed': 1}}, 'epoch_loop.val_loop.epoch_loop.state_dict': {}, 'epoch_loop.val_loop.epoch_loop.batch_progress': {'total': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'current': {'ready': 36856, 'completed': 36856, 'started': 36856, 'processed': 36856}, 'is_last_batch': True}, 'epoch_progress': {'total': {'ready': 121, 'completed': 120, 'started': 121, 'processed': 121}, 'current': {'ready': 121, 'completed': 120, 'started': 121, 'processed': 121}}}, 'validate_loop': {'state_dict': {}, 'dataloader_progress': {'total': {'ready': 0, 'completed': 0}, 'current': {'ready': 0, 'completed': 0}}, 'epoch_loop.state_dict': {}, 'epoch_loop.batch_progress': {'total': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'current': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'is_last_batch': False}}, 'test_loop': {'state_dict': {}, 'dataloader_progress': {'total': {'ready': 0, 'completed': 0}, 'current': {'ready': 0, 'completed': 0}}, 'epoch_loop.state_dict': {}, 'epoch_loop.batch_progress': {'total': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'current': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'is_last_batch': False}}, 'predict_loop': {'state_dict': {}, 'dataloader_progress': {'total': {'ready': 0, 'completed': 0}, 'current': {'ready': 0, 'completed': 0}}, 'epoch_loop.state_dict': {}, 'epoch_loop.batch_progress': {'total': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}, 'current': {'ready': 0, 'completed': 0, 'started': 0, 'processed': 0}}}}

# LR Schedulers
ev_ckpt['hyper_parameters']['config']['training']['optimizer']['learning_rate'] = 0.0002
ev_ckpt['lr_schedulers'][0]['last_epoch'] = ev_ckpt['epoch']
ev_ckpt['lr_schedulers'][0]['_step_count'] = ev_ckpt['global_step']

torch.save(ev_ckpt, 'ev_hfg.ckpt')
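
To narrow down what is mismatched when resuming, one option is to dump the resume-related fields from the converted checkpoint and diff them against a checkpoint that does resume correctly. A minimal sketch (assuming the ev_hfg.ckpt saved above; not part of the original attempt):

import torch

# Sketch: print the fields Lightning reads back on resume so they can be
# compared against a checkpoint that resumes at the expected epoch.
ckpt = torch.load("ev_hfg.ckpt", map_location="cpu")
print("epoch:", ckpt["epoch"])
print("global_step:", ckpt["global_step"])
print("fit_loop epoch_progress:", ckpt["loops"]["fit_loop"]["epoch_progress"])
print(
    "lr_scheduler last_epoch / _step_count:",
    ckpt["lr_schedulers"][0]["last_epoch"],
    ckpt["lr_schedulers"][0]["_step_count"],
)
for i, opt in enumerate(ckpt["optimizer_states"]):
    print(f"optimizer {i}: {len(opt['state'])} param states")
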
SamuelLarkin commented 2 months ago

https://github.com/jik876/hifi-gan

SamuelLarkin commented 2 months ago

To keep a record of what was done:

Get the reference code: git clone https://github.com/jik876/hifi-gan.

The Generators were outputting different values, so we reproduced the forward() passes and gradually increased the number of upsamples and kernels until we could see where the two implementations diverge.

We then extended the work to check whether the discriminators were also correctly reimplemented. Unlike the Generator, the discriminators produced the same outputs as the reference code.
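
One generic way to do this kind of layer-by-layer Generator comparison is with forward hooks; this is only a sketch of the idea, not the approach used in the script below, and it assumes the submodule names (conv_pre, ups, resblocks, ...) line up between the two implementations:

import torch

def capture_activations(model, x):
    """Run model on x and record every named submodule's output."""
    acts, hooks = {}, []
    for name, module in model.named_modules():
        if name:  # skip the root module itself
            hooks.append(
                module.register_forward_hook(
                    lambda mod, inp, out, name=name: acts.__setitem__(name, out)
                )
            )
    with torch.no_grad():
        model(x)
    for handle in hooks:
        handle.remove()
    return acts

def diff_activations(model_a, model_b, x):
    """Print the largest absolute difference for every shared submodule name."""
    a_acts = capture_activations(model_a, x)
    b_acts = capture_activations(model_b, x)
    for name in sorted(a_acts.keys() & b_acts.keys()):
        a, b = a_acts[name], b_acts[name]
        if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor) and a.shape == b.shape:
            print(name, (a - b).abs().max().item())

For example, diff_activations(ev_generator.generator, jik876, torch.rand((2, 80, 120))) with the two Generators loaded as in test_generator() below.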

Code

Here's the code that was used to check the models. Note that the code is highly hacky and only intended to serve the investigation, not to be beautiful or production-ready.

#!/usr/bin/env python3

import json

import regex as re
import torch
from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.config import HiFiGANConfig
from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.model import HiFiGAN
from tqdm import tqdm

from hifi_gan.env import AttrDict
from hifi_gan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator

# from torch.nn.utils.parametrizations import weight_norm
# from torch.nn.utils.parametrize import remove_parametrizations

LRELU_SLOPE = 0.1
EV_CKPT = "ev_hfg.ckpt"
EV_TEMPLATE_FILENAME = "/home/aip000/sgile/data/CRK/ev-sue-config/logs_and_checkpoints/VocoderExperiment/base/checkpoints/epoch=0-step=176.ckpt"
JIK876_GENERATOR_FILENAME = (
    "/home/aip000/tts/models/HiFi-GAN-TTS.pretrained/UNIVERSAL_V1/g_02500000"
)
JIK876_DISCRIMINATOR_FILENAME = (
    "/home/aip000/tts/models/HiFi-GAN-TTS.pretrained/UNIVERSAL_V1/do_02500000"
)

def add_parametrizations(pieces):
    # add parametrizations
    pieces.insert(-1, "parametrizations")
    # add weight
    pieces.insert(-1, "weight")

    return pieces

def fix_weight_keys(map, patterns):
    for k, v in tqdm(map.items()):
        matches = [re.match(pattern, v) for pattern in patterns]
        if any(matches):
            pieces = v.split(".")
            # convert weight_g/weight_v to original0/original1
            if pieces[-1] == "weight_g":
                pieces[-1] = "original0"
                pieces = add_parametrizations(pieces)
            elif pieces[-1] == "weight_v":
                pieces[-1] = "original1"
                pieces = add_parametrizations(pieces)
            else:
                continue
            v = ".".join(pieces)
        map[k] = v

    return map

def fix_loop(ev_ckpt):
    # loops (from EV universal)
    ev_ckpt["loops"] = {
        "fit_loop": {
            "state_dict": {},
            "epoch_loop.state_dict": {"_batches_that_stepped": 2500000},
            "epoch_loop.batch_progress": {
                "total": {
                    "ready": 2500000,
                    "completed": 2500000,
                    "started": 2500000,
                    "processed": 2500000,
                },
                "current": {
                    "ready": 12400,
                    "completed": 12400,
                    "started": 12400,
                    "processed": 12400,
                },
                "is_last_batch": False,
            },
            "epoch_loop.scheduler_progress": {
                "total": {"ready": 240, "completed": 240},
                "current": {"ready": 0, "completed": 0},
            },
            "epoch_loop.batch_loop.state_dict": {},
            "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
            "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
                "optimizer": {
                    "step": {
                        "total": {"ready": 5000000, "completed": 5000000},
                        "current": {"ready": 24800, "completed": 24800},
                    },
                    "zero_grad": {
                        "total": {
                            "ready": 5000000,
                            "completed": 5000000,
                            "started": 5000000,
                        },
                        "current": {
                            "ready": 24800,
                            "completed": 24800,
                            "started": 24800,
                        },
                    },
                },
                "optimizer_position": 2,
            },
            "epoch_loop.batch_loop.manual_loop.state_dict": {},
            "epoch_loop.batch_loop.manual_loop.optim_step_progress": {
                "total": {"ready": 0, "completed": 0},
                "current": {"ready": 0, "completed": 0},
            },
            "epoch_loop.val_loop.state_dict": {},
            "epoch_loop.val_loop.dataloader_progress": {
                "total": {"ready": 120, "completed": 120},
                "current": {"ready": 1, "completed": 1},
            },
            "epoch_loop.val_loop.epoch_loop.state_dict": {},
            "epoch_loop.val_loop.epoch_loop.batch_progress": {
                "total": {"ready": 0, "completed": 0, "started": 0, "processed": 0},
                "current": {
                    "ready": 36856,
                    "completed": 36856,
                    "started": 36856,
                    "processed": 36856,
                },
                "is_last_batch": True,
            },
            "epoch_progress": {
                "total": {
                    "ready": 121,
                    "completed": 120,
                    "started": 121,
                    "processed": 121,
                },
                "current": {
                    "ready": 121,
                    "completed": 120,
                    "started": 121,
                    "processed": 121,
                },
            },
        },
        "validate_loop": {
            "state_dict": {},
            "dataloader_progress": {
                "total": {"ready": 0, "completed": 0},
                "current": {"ready": 0, "completed": 0},
            },
            "epoch_loop.state_dict": {},
            "epoch_loop.batch_progress": {
                "total": {"ready": 0, "completed": 0, "started": 0, "processed": 0},
                "current": {"ready": 0, "completed": 0, "started": 0, "processed": 0},
                "is_last_batch": False,
            },
        },
        "test_loop": {
            "state_dict": {},
            "dataloader_progress": {
                "total": {"ready": 0, "completed": 0},
                "current": {"ready": 0, "completed": 0},
            },
            "epoch_loop.state_dict": {},
            "epoch_loop.batch_progress": {
                "total": {"ready": 0, "completed": 0, "started": 0, "processed": 0},
                "current": {"ready": 0, "completed": 0, "started": 0, "processed": 0},
                "is_last_batch": False,
            },
        },
        "predict_loop": {
            "state_dict": {},
            "dataloader_progress": {
                "total": {"ready": 0, "completed": 0},
                "current": {"ready": 0, "completed": 0},
            },
            "epoch_loop.state_dict": {},
            "epoch_loop.batch_progress": {
                "total": {"ready": 0, "completed": 0, "started": 0, "processed": 0},
                "current": {"ready": 0, "completed": 0, "started": 0, "processed": 0},
            },
        },
    }

def fix_non_state_dict(ev_ckpt):
    fix_loop(ev_ckpt)

    # Set steps
    ev_ckpt["global_step"] = 5000000

    # Set epoch
    ev_ckpt["epoch"] = 12400

    # LR Schedulers
    ev_ckpt["hyper_parameters"]["config"]["training"]["optimizer"][
        "learning_rate"
    ] = 0.0002
    ev_ckpt["lr_schedulers"][0]["last_epoch"] = ev_ckpt["epoch"]
    ev_ckpt["lr_schedulers"][0]["_step_count"] = ev_ckpt["global_step"]

def convert():
    """
    Convert a jik876 HiFi-GAN checkpoint into an EveryVoice checkpoint.
    """
    ev_ckpt = torch.load(
        EV_TEMPLATE_FILENAME,
        map_location=torch.device("cpu"),
    )
    hfg_ckpt_d = torch.load(
        JIK876_DISCRIMINATOR_FILENAME,
        map_location=torch.device("cpu"),
    )
    hfg_ckpt_g = torch.load(
        JIK876_GENERATOR_FILENAME,
        map_location=torch.device("cpu"),
    )

    state_dict_keys = set(ev_ckpt["state_dict"].keys())
    print(*sorted(state_dict_keys), sep="\n")

    # Set optimizer states
    ev_ckpt["optimizer_states"][0]["state"] = hfg_ckpt_d["optim_g"]["state"]
    ev_ckpt["optimizer_states"][0]["param_groups"] = hfg_ckpt_d["optim_g"][
        "param_groups"
    ]
    ev_ckpt["optimizer_states"][1]["state"] = hfg_ckpt_d["optim_d"]["state"]
    ev_ckpt["optimizer_states"][1]["param_groups"] = hfg_ckpt_d["optim_d"][
        "param_groups"
    ]

    # map generator/discriminator states
    hfg_g_state_map = {
        k: f"generator.{k}" for k in hfg_ckpt_g["generator"].keys()
    }  # key is old key, value is corresponding key in EV
    hfg_msd_state_map = {
        k: f"msd.{k}" for k in hfg_ckpt_d["msd"].keys()
    }  # key is old key, value is corresponding key in EV
    hfg_mpd_state_map = {
        k: f"mpd.{k}" for k in hfg_ckpt_d["mpd"].keys()
    }  # key is old key, value is corresponding key in EV
    msd_pattern = re.compile(r"msd\.discriminators\.[1-9]")
    mpd_pattern = re.compile(r"mpd\.discriminators")
    resblock_pattern = re.compile(r"generator\.resblocks")
    upsample_pattern = re.compile(r"generator\.ups")
    convpre_pattern = re.compile(r"generator\.conv_p")

    hfg_g_state_map = fix_weight_keys(
        hfg_g_state_map, [resblock_pattern, upsample_pattern, convpre_pattern]
    )
    hfg_msd_state_map = fix_weight_keys(hfg_msd_state_map, [msd_pattern])
    hfg_mpd_state_map = fix_weight_keys(hfg_mpd_state_map, [mpd_pattern])

    gen_hfg_states = set(hfg_g_state_map.values())
    msd_hfg_states = set(hfg_msd_state_map.values())
    mpd_hfg_states = set(hfg_mpd_state_map.values())

    gen_ev_states = set(
        [x for x in ev_ckpt["state_dict"].keys() if x.startswith("generator")]
    )
    msd_ev_states = set(
        [x for x in ev_ckpt["state_dict"].keys() if x.startswith("msd")]
    )
    mpd_ev_states = set(
        [x for x in ev_ckpt["state_dict"].keys() if x.startswith("mpd")]
    )

    assert len(gen_ev_states) == len(gen_hfg_states) == 234
    assert len(hfg_msd_state_map) == len(msd_ev_states) == 80
    assert len(hfg_mpd_state_map) == len(mpd_ev_states) == 90

    assert len(gen_ev_states - gen_hfg_states) == 0
    assert len(msd_ev_states - msd_hfg_states) == 0
    assert len(mpd_ev_states - mpd_hfg_states) == 0

    # replace the weights
    for k, v in hfg_g_state_map.items():
        ev_ckpt["state_dict"][v] = hfg_ckpt_g["generator"][k]
        # print(ev_ckpt["state_dict"][v] is hfg_ckpt_g["generator"][k])

    for k, v in hfg_msd_state_map.items():
        ev_ckpt["state_dict"][v] = hfg_ckpt_d["msd"][k]

    for k, v in hfg_mpd_state_map.items():
        ev_ckpt["state_dict"][v] = hfg_ckpt_d["mpd"][k]

    state_dict_keys -= gen_hfg_states | msd_hfg_states | mpd_hfg_states
    # print(state_dict_keys)

    # callbacks - meh? these don't seem like they need to be moved
    fix_non_state_dict(ev_ckpt)

    # print(ev_ckpt.keys())
    # print(json.dumps(ev_ckpt["hyper_parameters"], indent=2, ensure_ascii=False))
    torch.save(ev_ckpt, EV_CKPT)

def test_generator():
    """
    Let's figure out where our Generator code differs from jik876.
    """
    ev_ckpt = torch.load(EV_CKPT)

    config: dict | HiFiGANConfig = ev_ckpt["hyper_parameters"]["config"]
    if isinstance(config, dict):
        config = HiFiGANConfig(**config)
    ev_generator = HiFiGAN(config)
    ev_generator.load_state_dict(ev_ckpt["state_dict"])
    ev_generator.generator.remove_weight_norm()
    # model.mpd.remove_weight_norm()
    # model.msd.remove_weight_norm()
    # remove_parametrizations(sub_layer, "weight")

    # print(ev_generator)

    with open("hifi_gan/config_v1.json", mode="r") as cin:
        hyperparameters = AttrDict(json.load(cin))
    jik876 = Generator(hyperparameters)
    jik876.load_state_dict(
        torch.load(
            JIK876_GENERATOR_FILENAME,
            map_location=torch.device("cpu"),
        )["generator"]
    )
    # print(jik876)

    input_ = torch.rand((2, 80, 120))
    first = slice(None, 10)

    if False:
        print("num_kernels, num_upsamples")
        print(ev_generator.generator.num_kernels, ev_generator.generator.num_upsamples)
        print(jik876.num_kernels, jik876.num_upsamples)

    if False:
        import torch.nn.functional as F

        print("activation_function, leaky_relu")
        print(config.model.activation_function)

        a = ev_generator.config.model.activation_function(input_)
        print(a)
        b = F.leaky_relu(input_)
        print(b)
        assert torch.all(torch.eq(a, b))

    if False:
        print("conv_pre.bias")
        print(ev_generator.generator.conv_pre.bias[first])
        print(jik876.conv_pre.bias[first])

    if False:
        print("ups[0].bias")
        print(ev_generator.generator.ups[0].bias[first])
        print(jik876.ups[0].bias[first])
        assert torch.all(
            torch.eq(ev_generator.generator.ups[0].bias, jik876.ups[0].bias)
        )

    if False:
        print("conv_pre()")
        ev_generator.generator.conv_pre = jik876.conv_pre
        a = ev_generator.generator.conv_pre(input_)
        print(a, a.shape)
        a = jik876.conv_pre(input_)
        print(a, a.shape)

    if False:
        import torch.nn.functional as F

        num_upsamples = ev_generator.generator.num_upsamples
        num_kernels = ev_generator.generator.num_kernels
        print("Partial")

        a = ev_generator.generator.conv_pre(input_)
        for i in range(num_upsamples):
            a = ev_generator.generator.config.model.activation_function(a)
            a = ev_generator.generator.ups[i](a)
            xs = None
            for j in range(num_kernels):
                if xs is None:
                    xs = ev_generator.generator.resblocks[i * num_kernels + j](a)
                else:
                    xs += ev_generator.generator.resblocks[i * num_kernels + j](a)
                # print(j, xs[:, :2, :4])
            a = xs / num_kernels
            print(i, a[:, :2, :4])

        b = jik876.conv_pre(input_)
        for i in range(num_upsamples):
            b = F.leaky_relu(b, LRELU_SLOPE)
            b = jik876.ups[i](b)
            xs = None
            for j in range(num_kernels):
                if xs is None:
                    xs = jik876.resblocks[i * num_kernels + j](b)
                else:
                    xs += jik876.resblocks[i * num_kernels + j](b)
                # print(j, xs[:, :2, :4])
            b = xs / num_kernels
            print(i, b[:, :2, :4])

        # print(a)
        # print(b)

    if True:
        print("Full run")

        a = ev_generator.generator(input_)
        print(a[:, :2, :4], a.shape)

        b = jik876(input_)
        print(b[:, :2, :4], b.shape)

        print(torch.eq(a, b)[:, :2, :4])
        print((a - b)[:, :2, :4])
        assert torch.allclose(a, b)
        assert torch.all(torch.eq(a, b))

def test_discriminator():
    y = torch.rand((2, 1, 8192))
    y_hat = torch.rand((2, 1, 8192))

    ev_ckpt = torch.load(EV_CKPT)
    config: dict | HiFiGANConfig = ev_ckpt["hyper_parameters"]["config"]
    if isinstance(config, dict):
        config = HiFiGANConfig(**config)
    ev = HiFiGAN(config)
    ev.load_state_dict(ev_ckpt["state_dict"])
    ev.generator.remove_weight_norm()

    hfg_ckpt_d = torch.load(
        JIK876_DISCRIMINATOR_FILENAME,
        map_location=torch.device("cpu"),
    )

    mpd = MultiPeriodDiscriminator()
    mpd.load_state_dict(hfg_ckpt_d["mpd"])
    msd = MultiScaleDiscriminator()
    msd.load_state_dict(hfg_ckpt_d["msd"])

    def equal_helper(a, b):
        """Recursively walk list then check tensor equality."""
        if isinstance(a, (list, tuple)):
            for i, (c, d) in enumerate(zip(a, b)):
                equal_helper(c, d)
        else:
            assert torch.all(torch.eq(a, b))

    equal_helper(ev.mpd(y, y_hat), mpd(y, y_hat))
    equal_helper(ev.msd(y, y_hat), msd(y, y_hat))

if __name__ == "__main__":
    convert()
    test_generator()
    test_discriminator()
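
Usage note: the script imports the reference code as a package (from hifi_gan.models import ..., plus hifi_gan/config_v1.json), so the jik876/hifi-gan clone presumably has to be reachable under the name hifi_gan. A hypothetical way to set that up (an assumption, not something stated in the thread):

import pathlib

# Hypothetical: expose the `hifi-gan` clone under the importable name `hifi_gan`.
clone = pathlib.Path("hifi-gan")    # result of `git clone https://github.com/jik876/hifi-gan`
package = pathlib.Path("hifi_gan")  # name used by the imports in the script
if clone.exists() and not package.exists():
    package.symlink_to(clone.resolve())  # or simply rename the directory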