Plachtaa / FAcodec

Training code for FAcodec presented in NaturalSpeech3

pytorch_model.bin key error and WN gated activation function #21

Open AlbertHsu0509 opened 2 weeks ago

AlbertHsu0509 commented 2 weeks ago

Hi, thank you for sharing the training code of FACodec! I've come across a couple of points:

  1. Fine-tuning the redecoder: I'm interested in fine-tuning the redecoder using the provided encoder and redecoder bin files. However, I noticed that there is no 'net' key in the bin file, which seems to cause an issue when loading the checkpoint. Could you provide some guidance on how to properly load these files for fine-tuning?
  2. Additional activation function: I noticed that there is an additional WN gated activation function applied after the timbre layer norm, which differs from the original code and the description in the paper. I'm curious about the reasoning behind this architectural change. Could you share some insight into why this modification was made and how it impacts the model's performance?

Plachtaa commented 1 week ago

Thanks for your interest and positive comments. Regarding your questions:

  1. The uploaded bin files are the contents of the 'net' key; the other keys, which correspond to optimizer states, have been discarded. If you would like to fine-tune, just load those parameters directly (a minimal sketch follows below).
  2. I don't remember there being an additional activation function after the timbre layer norm of the FAquantizer class. Could you point me to the part you mentioned?
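For illustration, here is a minimal sketch of what "load those parameters directly" could look like, assuming the bin file is itself the former 'net' entry, i.e. a dict mapping component names to their state dicts (the exact keys are confirmed later in this thread; everything else here is an assumption):

import torch

# Hypothetical sketch: the published bin file holds per-component state dicts,
# not a full training checkpoint with 'net'/'optimizer'/'epoch' entries.
state = torch.load('pytorch_model.bin', map_location='cpu')
for key, params in state.items():
    model[key].load_state_dict(params)  # model: dict of plain nn.Module instances
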
AlbertHsu0509 commented 1 week ago

Thank you for your response.

It appears that the quantized output 'z', which has undergone timbre layer normalization, isn't actually used in the subsequent steps. Instead, the wavenet encoder uses code[0], code[1], and timbre (in train_redecoder.py). [screenshot] This encoder first converts the codes to embeddings, then applies the fused_add_tanh_sigmoid_multiply function for timbre conditioning, ultimately producing the encoder_out used in the decoding phase (in wavenet.py). [screenshot]

In the original code, timbre layer normalization is applied after the vq2emb step. I'm wondering whether the wavenet process replaces this layer normalization; however, I'm not entirely certain about this interpretation. Could you please confirm or correct my understanding? Thank you in advance for your clarification.

Plachtaa commented 1 week ago

OK, I understand that you are referring to the redecoder instead of the codec itself. Yes, fused_add_tanh_sigmoid_multiply has similar behavior to conditional layer norm (which is used as the timbre norm).
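
For reference, this WaveNet-style gated activation is commonly implemented as below (a generic version of the function as it appears in many open-source codebases, not necessarily byte-for-byte the one in wavenet.py; n_channels is typically passed as a one-element IntTensor):

import torch

def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # Add the conditioning signal (e.g. timbre) to the hidden activations,
    # then split the channels into a tanh half and a sigmoid gate half
    # and multiply them elementwise.
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    return t_act * s_act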

rgxb2807 commented 1 week ago

Hi @Plachtaa,

First off, also want to thank you for this incredible research and this repo.

I'm also hitting the same 'net' key error when I try to finetune the encoder here. If I check the state keys:

dict_keys(['encoder', 'quantizer', 'decoder', 'discriminator', 'fa_predictors'])

When you say "load the parameters", do you mean modifying load_checkpoint() to something like the version below, which gets rid of the key error? I suspect this is incorrect, because when I swap this logic in I get many KeyErrors in the state_dict. Any chance the checkpoint changed? I'm using the one from Hugging Face.

def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[], is_distributed=False):
    state = torch.load(path, map_location='cpu')
    for key in model:
        if key in state and key not in ignore_modules:  # check if the key exists directly in the state
            params = state[key]

            if not is_distributed:
                # strip the DDP prefix ('module.') from parameter names in place
                for k in list(params.keys()):
                    if k.startswith('module.'):
                        params[k[len("module."):]] = params[k]
                        del params[k]

            print(f'{key} loaded')
            model[key].load_state_dict(params, strict=True)

    # set model to evaluation mode
    _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])
    else:
        epoch = state["epoch"] + 1
        iters = state["iters"]

    return model, optimizer, epoch, iters
Plachtaa commented 1 week ago

Please kindly share the error message. If you cannot find the keys that mismatch, use strict=False and check the returned value for which keys are missing or redundant
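
For example (standard PyTorch behavior: torch.nn.Module.load_state_dict returns the mismatched keys as a named tuple when strict=False; 'model[key]' and 'params' here refer to the snippet shared earlier):

result = model[key].load_state_dict(params, strict=False)
print('missing:', result.missing_keys)        # expected by the module, absent from the checkpoint
print('unexpected:', result.unexpected_keys)  # present in the checkpoint, not expected by the module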

rgxb2807 commented 1 week ago

Please kindly share the error message. If you cannot find the keys that mismatch, use strict=False and check the returned value for which keys are missing or redundant

train.py unmodified

[rank0]:   File "/audio/models/FAcodec/train.py", line 494, in <module>
[rank0]:     main(args)
[rank0]:   File "/audio/models/FAcodec/train.py", line 147, in main
[rank0]:     model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, latest_checkpoint,
[rank0]:   File "/audio/models/FAcodec/modules/commons.py", line 448, in load_checkpoint
[rank0]:     params = state['net']
[rank0]: KeyError: 'net'

train.py - using load_checkpoint from above

[rank0]:   File "/audio/models/FAcodec/train.py", line 494, in <module>
[rank0]:     main(args)
[rank0]:   File "/audio/models/FAcodec/train.py", line 147, in main
[rank0]:     model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, latest_checkpoint,
[rank0]:   File "/audio/models/FAcodec/modules/commons.py", line 488, in load_checkpoint
[rank0]:     model[key].load_state_dict(params, strict=True)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
[rank0]:     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[rank0]: RuntimeError: Error(s) in loading state_dict for DistributedDataParallel:
[rank0]:    Missing key(s) in state_dict: "module.block.0.conv.conv.bias", "module.block.0.conv.conv.weight_g", "module.block.0.conv.conv.weight_v", "module.block.1.block.0.block.0.alpha", "module.block.1.block.0.block.1.conv.conv.bias", "module.block.1.block.0.block.1.conv.conv.weight_g", "module.block.1.block.0.block.1.conv.conv.weight_v", "module.block.1.block.0.block.2.alpha", "module.block.1.block.0.block.3.conv.conv.bias", "module.block.1.block.0.block.3.conv.conv.weight_g", "module.block.1.block.0.block.3.conv.conv.weight_v", "module.block.1.block.1.block.0.alpha", "module.block.1.block.1.block.1.conv.conv.bias", "module.block.1.block.1.block.1.conv.conv.weight_g", "module.block.1.block.1.block.1.conv.conv.weight_v", "module.block.1.block.1.block.2.alpha", "module.block.1.block.1.block.3.conv.conv.bias", "module.block.1.block.1.block.3.conv.conv.weight_g", "module.block.1.block.1.block.3.conv.conv.weight_v", "module.block.1.block.2.block.0.alpha", "module.block.1.block.2.block.1.conv.conv.bias", "module.block.1.block.2.block.1.conv.conv.weight_g", "module.block.1.block.2.block.1.conv.conv.weight_v", "module.block.1.block.2.block.2.alpha", "module.block.1.block.2.block.3.conv.conv.bias", "module.block.1.block.2.block.3.conv.conv.weight_g", "module.block.1.block.2.block.3.conv.conv.weight_v", "module.block.1.block.3.alpha", "module.block.1.block.4.conv.conv.bias", "module.block.1.block.4.conv.conv.weight_g", "module.block.1.block.4.conv.conv.weight_v", "module.block.2.block.0.block.0.alpha", "module.block.2.block.0.block.1.conv.conv.bias", "module.block.2.block.0.block.1.conv.conv.weight_g", "module.block.2.block.0.block.1.conv.conv.weight_v", "module.block.2.block.0.block.2.alpha", "module.block.2.block.0.block.3.conv.conv.bias", "module.block.2.block.0.block.3.conv.conv.weight_g", "module.block.2.block.0.block.3.conv.conv.weight_v", "module.block.2.block.1.block.0.alpha", "module.block.2.block.1.block.1.conv.conv.bias", "module.block.2.block.1.block.1.conv.conv.weight_g", "module.block.2.block.1.block.1.conv.conv.weight_v", "module.block.2.block.1.block.2.alpha", "module.block.2.block.1.block.3.conv.conv.bias", "module.block.2.block.1.block.3.conv.conv.weight_g", "module.block.2.block.1.block.3.conv.conv.weight_v", "module.block.2.block.2.block.0.alpha", "module.block.2.block.2.block.1.conv.conv.bias", "module.block.2.block.2.block.1.conv.conv.weight_g", "module.block.2.block.2.block.1.conv.conv.weight_v", "module.block.2.block.2.block.2.alpha", "module.block.2.block.2.block.3.conv.conv.bias", "module.block.2.block.2.block.3.conv.conv.weight_g", "module.block.2.block.2.block.3.conv.conv.weight_v", "module.block.2.block.3.alpha", "module.block.2.block.4.conv.conv.bias", "module.block.2.block.4.conv.conv.weight_g", "module.block.2.block.4.conv.conv.weight_v", "module.block.3.block.0.block.0.alpha", "module.block.3.block.0.block.1.conv.conv.bias", "module.block.3.block.0.block.1.conv.conv.weight_g", "module.block.3.block.0.block.1.conv.conv.weight_v", "module.block.3.block.0.block.2.alpha", "module.block.3.block.0.block.3.conv.conv.bias", "module.block.3.block.0.block.3.conv.conv.weight_g", "module.block.3.block.0.block.3.conv.conv.weight_v", "module.block.3.block.1.block.0.alpha", "module.block.3.block.1.block.1.conv.conv.bias", "module.block.3.block.1.block.1.conv.conv.weight_g", "module.block.3.block.1.block.1.conv.conv.weight_v", "module.block.3.block.1.block.2.alpha", "module.block.3.block.1.block.3.conv.conv.bias", "module.block.3.block.1.block.3.conv.conv.weight_g", 
"module.block.3.block.1.block.3.conv.conv.weight_v", "module.block.3.block.2.block.0.alpha", "module.block.3.block.2.block.1.conv.conv.bias", "module.block.3.block.2.block.1.conv.conv.weight_g", "module.block.3.block.2.block.1.conv.conv.weight_v", "module.block.3.block.2.block.2.alpha", "module.block.3.block.2.block.3.conv.conv.bias", "module.block.3.block.2.block.3.conv.conv.weight_g", "module.block.3.block.2.block.3.conv.conv.weight_v", "module.block.3.block.3.alpha", "module.block.3.block.4.conv.conv.bias", "module.block.3.block.4.conv.conv.weight_g", "module.block.3.block.4.conv.conv.weight_v", "module.block.4.block.0.block.0.alpha", "module.block.4.block.0.block.1.conv.conv.bias", "module.block.4.block.0.block.1.conv.conv.weight_g", "module.block.4.block.0.block.1.conv.conv.weight_v", "module.block.4.block.0.block.2.alpha", "module.block.4.block.0.block.3.conv.conv.bias", "module.block.4.block.0.block.3.conv.conv.weight_g", "module.block.4.block.0.block.3.conv.conv.weight_v", "module.block.4.block.1.block.0.alpha", "module.block.4.block.1.block.1.conv.conv.bias", "module.block.4.block.1.block.1.conv.conv.weight_g", "module.block.4.block.1.block.1.conv.conv.weight_v", "module.block.4.block.1.block.2.alpha", "module.block.4.block.1.block.3.conv.conv.bias", "module.block.4.block.1.block.3.conv.conv.weight_g", "module.block.4.block.1.block.3.conv.conv.weight_v", "module.block.4.block.2.block.0.alpha", "module.block.4.block.2.block.1.conv.conv.bias", "module.block.4.block.2.block.1.conv.conv.weight_g", "module.block.4.block.2.block.1.conv.conv.weight_v", "module.block.4.block.2.block.2.alpha", "module.block.4.block.2.block.3.conv.conv.bias", "module.block.4.block.2.block.3.conv.conv.weight_g", "module.block.4.block.2.block.3.conv.conv.weight_v", "module.block.4.block.3.alpha", "module.block.4.block.4.conv.conv.bias", "module.block.4.block.4.conv.conv.weight_g", "module.block.4.block.4.conv.conv.weight_v", "module.block.5.lstm.weight_ih_l0", "module.block.5.lstm.weight_hh_l0", "module.block.5.lstm.bias_ih_l0", "module.block.5.lstm.bias_hh_l0", "module.block.5.lstm.weight_ih_l1", "module.block.5.lstm.weight_hh_l1", "module.block.5.lstm.bias_ih_l1", "module.block.5.lstm.bias_hh_l1", "module.block.6.alpha", "module.block.7.conv.conv.bias", "module.block.7.conv.conv.weight_g", "module.block.7.conv.conv.weight_v". 
[rank0]:    Unexpected key(s) in state_dict: "block.0.conv.conv.bias", "block.0.conv.conv.weight_g", "block.0.conv.conv.weight_v", "block.1.block.0.block.0.alpha", "block.1.block.0.block.1.conv.conv.bias", "block.1.block.0.block.1.conv.conv.weight_g", "block.1.block.0.block.1.conv.conv.weight_v", "block.1.block.0.block.2.alpha", "block.1.block.0.block.3.conv.conv.bias", "block.1.block.0.block.3.conv.conv.weight_g", "block.1.block.0.block.3.conv.conv.weight_v", "block.1.block.1.block.0.alpha", "block.1.block.1.block.1.conv.conv.bias", "block.1.block.1.block.1.conv.conv.weight_g", "block.1.block.1.block.1.conv.conv.weight_v", "block.1.block.1.block.2.alpha", "block.1.block.1.block.3.conv.conv.bias", "block.1.block.1.block.3.conv.conv.weight_g", "block.1.block.1.block.3.conv.conv.weight_v", "block.1.block.2.block.0.alpha", "block.1.block.2.block.1.conv.conv.bias", "block.1.block.2.block.1.conv.conv.weight_g", "block.1.block.2.block.1.conv.conv.weight_v", "block.1.block.2.block.2.alpha", "block.1.block.2.block.3.conv.conv.bias", "block.1.block.2.block.3.conv.conv.weight_g", "block.1.block.2.block.3.conv.conv.weight_v", "block.1.block.3.alpha", "block.1.block.4.conv.conv.bias", "block.1.block.4.conv.conv.weight_g", "block.1.block.4.conv.conv.weight_v", "block.2.block.0.block.0.alpha", "block.2.block.0.block.1.conv.conv.bias", "block.2.block.0.block.1.conv.conv.weight_g", "block.2.block.0.block.1.conv.conv.weight_v", "block.2.block.0.block.2.alpha", "block.2.block.0.block.3.conv.conv.bias", "block.2.block.0.block.3.conv.conv.weight_g", "block.2.block.0.block.3.conv.conv.weight_v", "block.2.block.1.block.0.alpha", "block.2.block.1.block.1.conv.conv.bias", "block.2.block.1.block.1.conv.conv.weight_g", "block.2.block.1.block.1.conv.conv.weight_v", "block.2.block.1.block.2.alpha", "block.2.block.1.block.3.conv.conv.bias", "block.2.block.1.block.3.conv.conv.weight_g", "block.2.block.1.block.3.conv.conv.weight_v", "block.2.block.2.block.0.alpha", "block.2.block.2.block.1.conv.conv.bias", "block.2.block.2.block.1.conv.conv.weight_g", "block.2.block.2.block.1.conv.conv.weight_v", "block.2.block.2.block.2.alpha", "block.2.block.2.block.3.conv.conv.bias", "block.2.block.2.block.3.conv.conv.weight_g", "block.2.block.2.block.3.conv.conv.weight_v", "block.2.block.3.alpha", "block.2.block.4.conv.conv.bias", "block.2.block.4.conv.conv.weight_g", "block.2.block.4.conv.conv.weight_v", "block.3.block.0.block.0.alpha", "block.3.block.0.block.1.conv.conv.bias", "block.3.block.0.block.1.conv.conv.weight_g", "block.3.block.0.block.1.conv.conv.weight_v", "block.3.block.0.block.2.alpha", "block.3.block.0.block.3.conv.conv.bias", "block.3.block.0.block.3.conv.conv.weight_g", "block.3.block.0.block.3.conv.conv.weight_v", "block.3.block.1.block.0.alpha", "block.3.block.1.block.1.conv.conv.bias", "block.3.block.1.block.1.conv.conv.weight_g", "block.3.block.1.block.1.conv.conv.weight_v", "block.3.block.1.block.2.alpha", "block.3.block.1.block.3.conv.conv.bias", "block.3.block.1.block.3.conv.conv.weight_g", "block.3.block.1.block.3.conv.conv.weight_v", "block.3.block.2.block.0.alpha", "block.3.block.2.block.1.conv.conv.bias", "block.3.block.2.block.1.conv.conv.weight_g", "block.3.block.2.block.1.conv.conv.weight_v", "block.3.block.2.block.2.alpha", "block.3.block.2.block.3.conv.conv.bias", "block.3.block.2.block.3.conv.conv.weight_g", "block.3.block.2.block.3.conv.conv.weight_v", "block.3.block.3.alpha", "block.3.block.4.conv.conv.bias", "block.3.block.4.conv.conv.weight_g", "block.3.block.4.conv.conv.weight_v", 
"block.4.block.0.block.0.alpha", "block.4.block.0.block.1.conv.conv.bias", "block.4.block.0.block.1.conv.conv.weight_g", "block.4.block.0.block.1.conv.conv.weight_v", "block.4.block.0.block.2.alpha", "block.4.block.0.block.3.conv.conv.bias", "block.4.block.0.block.3.conv.conv.weight_g", "block.4.block.0.block.3.conv.conv.weight_v", "block.4.block.1.block.0.alpha", "block.4.block.1.block.1.conv.conv.bias", "block.4.block.1.block.1.conv.conv.weight_g", "block.4.block.1.block.1.conv.conv.weight_v", "block.4.block.1.block.2.alpha", "block.4.block.1.block.3.conv.conv.bias", "block.4.block.1.block.3.conv.conv.weight_g", "block.4.block.1.block.3.conv.conv.weight_v", "block.4.block.2.block.0.alpha", "block.4.block.2.block.1.conv.conv.bias", "block.4.block.2.block.1.conv.conv.weight_g", "block.4.block.2.block.1.conv.conv.weight_v", "block.4.block.2.block.2.alpha", "block.4.block.2.block.3.conv.conv.bias", "block.4.block.2.block.3.conv.conv.weight_g", "block.4.block.2.block.3.conv.conv.weight_v", "block.4.block.3.alpha", "block.4.block.4.conv.conv.bias", "block.4.block.4.conv.conv.weight_g", "block.4.block.4.conv.conv.weight_v", "block.5.lstm.weight_ih_l0", "block.5.lstm.weight_hh_l0", "block.5.lstm.bias_ih_l0", "block.5.lstm.bias_hh_l0", "block.5.lstm.weight_ih_l1", "block.5.lstm.weight_hh_l1", "block.5.lstm.bias_ih_l1", "block.5.lstm.bias_hh_l1", "block.6.alpha", "block.7.conv.conv.bias", "block.7.conv.conv.weight_g", "block.7.conv.conv.weight_v". 

train_redecoder.py - unmodified

[rank0]: Traceback (most recent call last):
[rank0]:   File "/audio/models/FAcodec/train.py", line 494, in <module>
[rank0]:     main(args)
[rank0]:   File "/audio/models/FAcodec/train.py", line 147, in main
[rank0]:     model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, latest_checkpoint,
[rank0]:   File "/audio/models/FAcodec/modules/commons.py", line 448, in load_checkpoint
[rank0]:     params = state['net']
[rank0]: KeyError: 'net'

train_redecoder.py - using load_checkpoint from above

[rank0]: Traceback (most recent call last):
[rank0]:   File "/audio/models/FAcodec/train_redecoder.py", line 456, in <module>
[rank0]:     main(args)
[rank0]:   File "/audio/models/FAcodec/train_redecoder.py", line 94, in main
[rank0]:     codec_encoder, _, _, _ = load_checkpoint(codec_encoder, None, config['pretrained_encoder'],
[rank0]:   File "/audio/models/FAcodec/modules/commons.py", line 500, in load_checkpoint
[rank0]:     epoch = state["epoch"] + 1
[rank0]: KeyError: 'epoch'
Plachtaa commented 1 week ago

This loading script only works for checkpoints saved during training; it does not work for the bin file I published on HF. To load parameters from that pretrained model into your own DDP training, you should (see the sketch after this list):

  1. load the parameters before accelerator.prepare() so that there is no 'module.' prefix in the parameter names
  2. set load_only_params to True, because there is no epoch, iters, or optimizer state info in the released checkpoint
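
A minimal sketch of this loading order (the component keys follow the state dict printed earlier in the thread; module construction and the bin file name are assumptions):

import torch

# 1. Load the published bin file into plain, not-yet-wrapped nn.Modules so the
#    parameter names carry no DDP 'module.' prefix.
state = torch.load('pytorch_model.bin', map_location='cpu')
for key in ['encoder', 'quantizer', 'decoder', 'discriminator', 'fa_predictors']:
    model[key].load_state_dict(state[key])

# 2. Only afterwards hand the modules to accelerate, which wraps them in DDP:
# model = {k: accelerator.prepare(v) for k, v in model.items()}

# 3. load_only_params=True: the released bin has no 'epoch', 'iters', or
#    'optimizer' entries, so start the counters fresh instead of reading them.
epoch, iters = 0, 0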