AlbertHsu0509 opened this issue 2 weeks ago
Thanks for your interest and positive comments. Regarding your questions: the FAquantizer class. Could you point me to the part you mentioned?
Thank you for your response.
It appears that the quantized output 'z', which has undergone timbre layer normalization, isn't actually used in the subsequent steps. Instead, the wavenet encoder uses code[0], code[1], and timbre (in train_redecoder.py). This encoder first converts code to emb, then applies the fused_add_tanh_sigmoid_multiply function for timbre conditioning, ultimately producing the encoder_out for the decoding phase (in wavenet.py).
In the original code, timbre layer normalization is applied after the vq2emb step. I'm wondering whether the wavenet process replaces this layer normalization, but I'm not entirely certain about this interpretation. Could you please confirm or correct my understanding? Thank you in advance for the clarification.
OK, I understand that you are referring to the redecoder instead of the codec itself. Yes, the fused_add_tanh_sigmoid_multiply has similar behavior to conditional layer norm (which is used as timbre norm).
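For readers unfamiliar with the function being discussed: fused_add_tanh_sigmoid_multiply is WaveNet's gated activation, where the input and conditioning signal are summed and then split into a tanh "filter" half and a sigmoid "gate" half whose product is the output. A minimal scalar sketch of that shape (the real implementation operates on channel-split tensors, not Python lists):

```python
import math

def fused_add_tanh_sigmoid_multiply(x, cond, n_channels):
    # Sum input and conditioning, then gate: tanh(filter half) * sigmoid(gate half).
    acts = [xi + ci for xi, ci in zip(x, cond)]
    t_act = [math.tanh(a) for a in acts[:n_channels]]               # filter half
    s_act = [1.0 / (1.0 + math.exp(-a)) for a in acts[n_channels:]]  # gate half
    return [t * s for t, s in zip(t_act, s_act)]
```

Because the conditioning signal (here, timbre) scales and shifts the activations multiplicatively, the effect resembles a conditional normalization layer, which is why it can stand in for the timbre norm.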
Hi @Plachtaa,
First off, I also want to thank you for this incredible research and this repo.
I'm also hitting the same 'net' key error when I try to finetune the encoder here. If I check the state keys:
dict_keys(['encoder', 'quantizer', 'decoder', 'discriminator', 'fa_predictors'])
When you say load the parameters, do you mean modifying load_checkpoint() to something like the code below, which would get rid of the key error? I suspect this is incorrect, because when I swap this logic in I get many KeyErrors in the state_dict. Any chance the checkpoint changed? I'm using the one from huggingface.
def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[], is_distributed=False):
    state = torch.load(path, map_location='cpu')
    for key in model:
        if key in state and key not in ignore_modules:  # Check if the key exists directly in the state
            params = state[key]
            if not is_distributed:
                # Strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
                for k in list(params.keys()):
                    if k.startswith('module.'):
                        params[k[len("module."):]] = params[k]
                        del params[k]
            print(f'{key} loaded')
            model[key].load_state_dict(params, strict=True)
    # Set model to evaluation mode
    _ = [model[key].eval() for key in model]
    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])
    else:
        epoch = state["epoch"] + 1
        iters = state["iters"]
    return model, optimizer, epoch, iters
Please kindly share the error message. If you cannot find the keys that mismatch, use strict=False and check the returned value to see which keys are missing or redundant.
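A self-contained sketch of the strict=False diagnostic being suggested (the toy model and the DDP-style prefix here are illustrative, mirroring the mismatch discussed later in this thread):

```python
import torch
import torch.nn as nn

# Toy model, plus a checkpoint whose keys carry a DDP-style "module." prefix.
model = nn.Linear(2, 2)
ckpt = {"module." + k: v for k, v in model.state_dict().items()}

# With strict=False, load_state_dict does not raise; it returns a named tuple
# listing exactly which keys failed to match on either side.
result = model.load_state_dict(ckpt, strict=False)
print("missing:", result.missing_keys)        # keys the model expects but ckpt lacks
print("unexpected:", result.unexpected_keys)  # keys in ckpt the model does not have
```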
train.py - unmodified
[rank0]: File "/audio/models/FAcodec/train.py", line 494, in <module>
[rank0]: main(args)
[rank0]: File "/audio/models/FAcodec/train.py", line 147, in main
[rank0]: model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, latest_checkpoint,
[rank0]: File "/audio/models/FAcodec/modules/commons.py", line 448, in load_checkpoint
[rank0]: params = state['net']
[rank0]: KeyError: 'net'
train.py - using load_checkpoint from above
[rank0]: File "/audio/models/FAcodec/train.py", line 494, in <module>
[rank0]: main(args)
[rank0]: File "/audio/models/FAcodec/train.py", line 147, in main
[rank0]: model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, latest_checkpoint,
[rank0]: File "/audio/models/FAcodec/modules/commons.py", line 488, in load_checkpoint
[rank0]: model[key].load_state_dict(params, strict=True)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
[rank0]: raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[rank0]: RuntimeError: Error(s) in loading state_dict for DistributedDataParallel:
[rank0]: Missing key(s) in state_dict: "module.block.0.conv.conv.bias", "module.block.0.conv.conv.weight_g", "module.block.0.conv.conv.weight_v", "module.block.1.block.0.block.0.alpha", "module.block.1.block.0.block.1.conv.conv.bias", "module.block.1.block.0.block.1.conv.conv.weight_g", "module.block.1.block.0.block.1.conv.conv.weight_v", "module.block.1.block.0.block.2.alpha", "module.block.1.block.0.block.3.conv.conv.bias", "module.block.1.block.0.block.3.conv.conv.weight_g", "module.block.1.block.0.block.3.conv.conv.weight_v", "module.block.1.block.1.block.0.alpha", "module.block.1.block.1.block.1.conv.conv.bias", "module.block.1.block.1.block.1.conv.conv.weight_g", "module.block.1.block.1.block.1.conv.conv.weight_v", "module.block.1.block.1.block.2.alpha", "module.block.1.block.1.block.3.conv.conv.bias", "module.block.1.block.1.block.3.conv.conv.weight_g", "module.block.1.block.1.block.3.conv.conv.weight_v", "module.block.1.block.2.block.0.alpha", "module.block.1.block.2.block.1.conv.conv.bias", "module.block.1.block.2.block.1.conv.conv.weight_g", "module.block.1.block.2.block.1.conv.conv.weight_v", "module.block.1.block.2.block.2.alpha", "module.block.1.block.2.block.3.conv.conv.bias", "module.block.1.block.2.block.3.conv.conv.weight_g", "module.block.1.block.2.block.3.conv.conv.weight_v", "module.block.1.block.3.alpha", "module.block.1.block.4.conv.conv.bias", "module.block.1.block.4.conv.conv.weight_g", "module.block.1.block.4.conv.conv.weight_v", "module.block.2.block.0.block.0.alpha", "module.block.2.block.0.block.1.conv.conv.bias", "module.block.2.block.0.block.1.conv.conv.weight_g", "module.block.2.block.0.block.1.conv.conv.weight_v", "module.block.2.block.0.block.2.alpha", "module.block.2.block.0.block.3.conv.conv.bias", "module.block.2.block.0.block.3.conv.conv.weight_g", "module.block.2.block.0.block.3.conv.conv.weight_v", "module.block.2.block.1.block.0.alpha", "module.block.2.block.1.block.1.conv.conv.bias", 
"module.block.2.block.1.block.1.conv.conv.weight_g", "module.block.2.block.1.block.1.conv.conv.weight_v", "module.block.2.block.1.block.2.alpha", "module.block.2.block.1.block.3.conv.conv.bias", "module.block.2.block.1.block.3.conv.conv.weight_g", "module.block.2.block.1.block.3.conv.conv.weight_v", "module.block.2.block.2.block.0.alpha", "module.block.2.block.2.block.1.conv.conv.bias", "module.block.2.block.2.block.1.conv.conv.weight_g", "module.block.2.block.2.block.1.conv.conv.weight_v", "module.block.2.block.2.block.2.alpha", "module.block.2.block.2.block.3.conv.conv.bias", "module.block.2.block.2.block.3.conv.conv.weight_g", "module.block.2.block.2.block.3.conv.conv.weight_v", "module.block.2.block.3.alpha", "module.block.2.block.4.conv.conv.bias", "module.block.2.block.4.conv.conv.weight_g", "module.block.2.block.4.conv.conv.weight_v", "module.block.3.block.0.block.0.alpha", "module.block.3.block.0.block.1.conv.conv.bias", "module.block.3.block.0.block.1.conv.conv.weight_g", "module.block.3.block.0.block.1.conv.conv.weight_v", "module.block.3.block.0.block.2.alpha", "module.block.3.block.0.block.3.conv.conv.bias", "module.block.3.block.0.block.3.conv.conv.weight_g", "module.block.3.block.0.block.3.conv.conv.weight_v", "module.block.3.block.1.block.0.alpha", "module.block.3.block.1.block.1.conv.conv.bias", "module.block.3.block.1.block.1.conv.conv.weight_g", "module.block.3.block.1.block.1.conv.conv.weight_v", "module.block.3.block.1.block.2.alpha", "module.block.3.block.1.block.3.conv.conv.bias", "module.block.3.block.1.block.3.conv.conv.weight_g", "module.block.3.block.1.block.3.conv.conv.weight_v", "module.block.3.block.2.block.0.alpha", "module.block.3.block.2.block.1.conv.conv.bias", "module.block.3.block.2.block.1.conv.conv.weight_g", "module.block.3.block.2.block.1.conv.conv.weight_v", "module.block.3.block.2.block.2.alpha", "module.block.3.block.2.block.3.conv.conv.bias", "module.block.3.block.2.block.3.conv.conv.weight_g", 
"module.block.3.block.2.block.3.conv.conv.weight_v", "module.block.3.block.3.alpha", "module.block.3.block.4.conv.conv.bias", "module.block.3.block.4.conv.conv.weight_g", "module.block.3.block.4.conv.conv.weight_v", "module.block.4.block.0.block.0.alpha", "module.block.4.block.0.block.1.conv.conv.bias", "module.block.4.block.0.block.1.conv.conv.weight_g", "module.block.4.block.0.block.1.conv.conv.weight_v", "module.block.4.block.0.block.2.alpha", "module.block.4.block.0.block.3.conv.conv.bias", "module.block.4.block.0.block.3.conv.conv.weight_g", "module.block.4.block.0.block.3.conv.conv.weight_v", "module.block.4.block.1.block.0.alpha", "module.block.4.block.1.block.1.conv.conv.bias", "module.block.4.block.1.block.1.conv.conv.weight_g", "module.block.4.block.1.block.1.conv.conv.weight_v", "module.block.4.block.1.block.2.alpha", "module.block.4.block.1.block.3.conv.conv.bias", "module.block.4.block.1.block.3.conv.conv.weight_g", "module.block.4.block.1.block.3.conv.conv.weight_v", "module.block.4.block.2.block.0.alpha", "module.block.4.block.2.block.1.conv.conv.bias", "module.block.4.block.2.block.1.conv.conv.weight_g", "module.block.4.block.2.block.1.conv.conv.weight_v", "module.block.4.block.2.block.2.alpha", "module.block.4.block.2.block.3.conv.conv.bias", "module.block.4.block.2.block.3.conv.conv.weight_g", "module.block.4.block.2.block.3.conv.conv.weight_v", "module.block.4.block.3.alpha", "module.block.4.block.4.conv.conv.bias", "module.block.4.block.4.conv.conv.weight_g", "module.block.4.block.4.conv.conv.weight_v", "module.block.5.lstm.weight_ih_l0", "module.block.5.lstm.weight_hh_l0", "module.block.5.lstm.bias_ih_l0", "module.block.5.lstm.bias_hh_l0", "module.block.5.lstm.weight_ih_l1", "module.block.5.lstm.weight_hh_l1", "module.block.5.lstm.bias_ih_l1", "module.block.5.lstm.bias_hh_l1", "module.block.6.alpha", "module.block.7.conv.conv.bias", "module.block.7.conv.conv.weight_g", "module.block.7.conv.conv.weight_v".
[rank0]: Unexpected key(s) in state_dict: "block.0.conv.conv.bias", "block.0.conv.conv.weight_g", "block.0.conv.conv.weight_v", "block.1.block.0.block.0.alpha", "block.1.block.0.block.1.conv.conv.bias", "block.1.block.0.block.1.conv.conv.weight_g", "block.1.block.0.block.1.conv.conv.weight_v", "block.1.block.0.block.2.alpha", "block.1.block.0.block.3.conv.conv.bias", "block.1.block.0.block.3.conv.conv.weight_g", "block.1.block.0.block.3.conv.conv.weight_v", "block.1.block.1.block.0.alpha", "block.1.block.1.block.1.conv.conv.bias", "block.1.block.1.block.1.conv.conv.weight_g", "block.1.block.1.block.1.conv.conv.weight_v", "block.1.block.1.block.2.alpha", "block.1.block.1.block.3.conv.conv.bias", "block.1.block.1.block.3.conv.conv.weight_g", "block.1.block.1.block.3.conv.conv.weight_v", "block.1.block.2.block.0.alpha", "block.1.block.2.block.1.conv.conv.bias", "block.1.block.2.block.1.conv.conv.weight_g", "block.1.block.2.block.1.conv.conv.weight_v", "block.1.block.2.block.2.alpha", "block.1.block.2.block.3.conv.conv.bias", "block.1.block.2.block.3.conv.conv.weight_g", "block.1.block.2.block.3.conv.conv.weight_v", "block.1.block.3.alpha", "block.1.block.4.conv.conv.bias", "block.1.block.4.conv.conv.weight_g", "block.1.block.4.conv.conv.weight_v", "block.2.block.0.block.0.alpha", "block.2.block.0.block.1.conv.conv.bias", "block.2.block.0.block.1.conv.conv.weight_g", "block.2.block.0.block.1.conv.conv.weight_v", "block.2.block.0.block.2.alpha", "block.2.block.0.block.3.conv.conv.bias", "block.2.block.0.block.3.conv.conv.weight_g", "block.2.block.0.block.3.conv.conv.weight_v", "block.2.block.1.block.0.alpha", "block.2.block.1.block.1.conv.conv.bias", "block.2.block.1.block.1.conv.conv.weight_g", "block.2.block.1.block.1.conv.conv.weight_v", "block.2.block.1.block.2.alpha", "block.2.block.1.block.3.conv.conv.bias", "block.2.block.1.block.3.conv.conv.weight_g", "block.2.block.1.block.3.conv.conv.weight_v", "block.2.block.2.block.0.alpha", 
"block.2.block.2.block.1.conv.conv.bias", "block.2.block.2.block.1.conv.conv.weight_g", "block.2.block.2.block.1.conv.conv.weight_v", "block.2.block.2.block.2.alpha", "block.2.block.2.block.3.conv.conv.bias", "block.2.block.2.block.3.conv.conv.weight_g", "block.2.block.2.block.3.conv.conv.weight_v", "block.2.block.3.alpha", "block.2.block.4.conv.conv.bias", "block.2.block.4.conv.conv.weight_g", "block.2.block.4.conv.conv.weight_v", "block.3.block.0.block.0.alpha", "block.3.block.0.block.1.conv.conv.bias", "block.3.block.0.block.1.conv.conv.weight_g", "block.3.block.0.block.1.conv.conv.weight_v", "block.3.block.0.block.2.alpha", "block.3.block.0.block.3.conv.conv.bias", "block.3.block.0.block.3.conv.conv.weight_g", "block.3.block.0.block.3.conv.conv.weight_v", "block.3.block.1.block.0.alpha", "block.3.block.1.block.1.conv.conv.bias", "block.3.block.1.block.1.conv.conv.weight_g", "block.3.block.1.block.1.conv.conv.weight_v", "block.3.block.1.block.2.alpha", "block.3.block.1.block.3.conv.conv.bias", "block.3.block.1.block.3.conv.conv.weight_g", "block.3.block.1.block.3.conv.conv.weight_v", "block.3.block.2.block.0.alpha", "block.3.block.2.block.1.conv.conv.bias", "block.3.block.2.block.1.conv.conv.weight_g", "block.3.block.2.block.1.conv.conv.weight_v", "block.3.block.2.block.2.alpha", "block.3.block.2.block.3.conv.conv.bias", "block.3.block.2.block.3.conv.conv.weight_g", "block.3.block.2.block.3.conv.conv.weight_v", "block.3.block.3.alpha", "block.3.block.4.conv.conv.bias", "block.3.block.4.conv.conv.weight_g", "block.3.block.4.conv.conv.weight_v", "block.4.block.0.block.0.alpha", "block.4.block.0.block.1.conv.conv.bias", "block.4.block.0.block.1.conv.conv.weight_g", "block.4.block.0.block.1.conv.conv.weight_v", "block.4.block.0.block.2.alpha", "block.4.block.0.block.3.conv.conv.bias", "block.4.block.0.block.3.conv.conv.weight_g", "block.4.block.0.block.3.conv.conv.weight_v", "block.4.block.1.block.0.alpha", "block.4.block.1.block.1.conv.conv.bias", 
"block.4.block.1.block.1.conv.conv.weight_g", "block.4.block.1.block.1.conv.conv.weight_v", "block.4.block.1.block.2.alpha", "block.4.block.1.block.3.conv.conv.bias", "block.4.block.1.block.3.conv.conv.weight_g", "block.4.block.1.block.3.conv.conv.weight_v", "block.4.block.2.block.0.alpha", "block.4.block.2.block.1.conv.conv.bias", "block.4.block.2.block.1.conv.conv.weight_g", "block.4.block.2.block.1.conv.conv.weight_v", "block.4.block.2.block.2.alpha", "block.4.block.2.block.3.conv.conv.bias", "block.4.block.2.block.3.conv.conv.weight_g", "block.4.block.2.block.3.conv.conv.weight_v", "block.4.block.3.alpha", "block.4.block.4.conv.conv.bias", "block.4.block.4.conv.conv.weight_g", "block.4.block.4.conv.conv.weight_v", "block.5.lstm.weight_ih_l0", "block.5.lstm.weight_hh_l0", "block.5.lstm.bias_ih_l0", "block.5.lstm.bias_hh_l0", "block.5.lstm.weight_ih_l1", "block.5.lstm.weight_hh_l1", "block.5.lstm.bias_ih_l1", "block.5.lstm.bias_hh_l1", "block.6.alpha", "block.7.conv.conv.bias", "block.7.conv.conv.weight_g", "block.7.conv.conv.weight_v".
train_redecoder.py - unmodified
[rank0]: Traceback (most recent call last):
[rank0]: File "/audio/models/FAcodec/train.py", line 494, in <module>
[rank0]: main(args)
[rank0]: File "/audio/models/FAcodec/train.py", line 147, in main
[rank0]: model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, latest_checkpoint,
[rank0]: File "/audio/models/FAcodec/modules/commons.py", line 448, in load_checkpoint
[rank0]: params = state['net']
[rank0]: KeyError: 'net'
train_redecoder.py - using load_checkpoint from above
[rank0]: Traceback (most recent call last):
[rank0]: File "/audio/models/FAcodec/train_redecoder.py", line 456, in <module>
[rank0]: main(args)
[rank0]: File "/audio/models/FAcodec/train_redecoder.py", line 94, in main
[rank0]: codec_encoder, _, _, _ = load_checkpoint(codec_encoder, None, config['pretrained_encoder'],
[rank0]: File "/audio/models/FAcodec/modules/commons.py", line 500, in load_checkpoint
[rank0]: epoch = state["epoch"] + 1
[rank0]: KeyError: 'epoch'
This loading script only works for checkpoints saved during training; it does not work for the bin file I published on HF.
To load parameters from that pretrained model into your own DDP training, you should:
1. load the parameters before accelerator.prepare(), so that there won't be a module. prefix in the parameter names;
2. set load_only_params to True, because there is no epoch, iters, or optimizer state info in the released checkpoint.
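A minimal sketch of the suggested ordering, using a toy module (the names and the stand-in checkpoint dict are illustrative; the accelerate call is shown only as a comment):

```python
import torch
import torch.nn as nn

# Load the released weights while the model is still unwrapped, so the plain
# parameter names in the .bin file match the model's own names (no "module."
# prefix on either side yet).
encoder = nn.Linear(4, 4)
released = {k: torch.zeros_like(v) for k, v in encoder.state_dict().items()}  # stands in for the HF .bin

encoder.load_state_dict(released, strict=True)  # keys match cleanly

# Only afterwards hand the model to accelerator.prepare() (or DDP), and keep
# load_only_params=True, since the released file has no epoch/iters/optimizer state.
```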
Hi, thank you for sharing the training code of FACodec! I've come across a couple of points:
1. Fine-tuning the redecoder: I'm interested in fine-tuning the redecoder using the provided encoder and redecoder bin files. However, I noticed that there's no 'net' key in the bin file, which seems to cause an issue when loading the checkpoint. Could you provide some guidance on how to properly load these files for fine-tuning?
2. Additional activation function: I noticed that there's an additional WN gated activation function applied after the timbre layer norm, which differs from the original code and the description in the paper. I'm curious about the reasoning behind this architectural change. Could you share some insights into why this modification was made and how it impacts the model's performance?