facebookresearch / audioseal

Localized watermarking for AI-generated speech audios, with SOTA on robustness and very fast detector

Fine-tuning the pre-trained model #40

Open crisostomi opened 4 months ago

crisostomi commented 4 months ago

Hello!

I am trying to fine-tune the pre-trained model on a new dataset, but I am having a hard time finding the checkpoint to provide as the continue_from argument to dora. The provided checkpoints are already split into a generator and a detector, so they cannot be used in the training pipeline, which expects a single checkpoint for the whole model. Am I missing something?

Thanks a lot!

antoine-tran commented 4 months ago

If I understand correctly, you need the checkpoint used in our ICML'24 paper for the fine-tuning?

crisostomi commented 4 months ago

Yes, I managed to load the pretrained generator and detector for inference, but I need the single, unsplit checkpoint so that I can use it as the initialization when fine-tuning with dora run ... continue_from <PATH_TO_CKPT>.
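
For context, the inference-time loading I mean is roughly the following (model card names as in the AudioSeal README; adjust for other variants):

from audioseal import AudioSeal

# Load the released, already-split checkpoints for inference only.
generator = AudioSeal.load_generator("audioseal_wm_16bits")
detector = AudioSeal.load_detector("audioseal_detector_16bits")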

crisostomi commented 4 months ago

In case anyone has the same problem, I quickly wrote a helper function to combine the checkpoints back into a single one:

import torch


def combine(
    generator_checkpoint: str, detector_checkpoint: str, output_checkpoint: str
):
    """Combine generator and detector checkpoints into a single checkpoint."""

    gen_ckpt = torch.load(generator_checkpoint, map_location="cpu")
    det_ckpt = torch.load(detector_checkpoint, map_location="cpu")

    combined_ckpt = {
        "xp.cfg": gen_ckpt["xp.cfg"],  # assuming the configs are identical
        "model": {},
    }

    # Add generator layers under the "generator." prefix
    for layer, weight in gen_ckpt["model"].items():
        combined_ckpt["model"][f"generator.{layer}"] = weight

    # Add detector layers under the "detector." prefix
    for layer, weight in det_ckpt["model"].items():
        combined_ckpt["model"][f"detector.{layer}"] = weight

    # Special case: the generator checkpoint stores this weight under
    # 'msg_processor.msg_processor.weight'; also save it (unprefixed) under
    # 'msg_processor.msg_processor.0.weight'
    if "msg_processor.msg_processor.weight" in gen_ckpt["model"]:
        combined_ckpt["model"]["msg_processor.msg_processor.0.weight"] = gen_ckpt[
            "model"
        ]["msg_processor.msg_processor.weight"]

    torch.save(combined_ckpt, output_checkpoint)
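
A rough usage sketch (the paths below are placeholders for wherever you saved the released checkpoints):

# Placeholder paths -- point these at the downloaded generator / detector
# checkpoints and the combined file you want to produce.
combine(
    generator_checkpoint="ckpts/generator.pth",
    detector_checkpoint="ckpts/detector.pth",
    output_checkpoint="ckpts/combined.pth",
)
# The combined file is then what I pass to dora run ... continue_from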

However, using the resulting checkpoint with continue_from doesn't seem to work properly: the logged metrics are about as low as when starting from scratch.

antoine-tran commented 4 months ago

Thanks, do you want to open a PR for this?

antoine-tran commented 4 months ago

By the way, continue_from is not a real fine-tuning recipe but rather a way to continue a long-running / checkpointed training run. It performs full updates of the model parameters and can lead to catastrophic forgetting.

crisostomi commented 4 months ago

> Thanks, do you want to open a PR for this?

Sure, will do!

> By the way, continue_from is not a real fine-tuning recipe but rather a way to continue a long-running / checkpointed training run. It performs full updates of the model parameters and can lead to catastrophic forgetting.

Thanks, I see. Do you have any hints on how you would do fine-tuning instead?

hadyelsahar commented 4 months ago

@crisostomi I know that loading checkpoints to resume training works well and is tested. I would say you need to debug here to make sure it does the right thing and that your weights are loaded correctly.

https://github.com/facebookresearch/audiocraft/blob/adf0b04a4452f171970028fcf80f101dd5e26e19/audiocraft/solvers/base.py#L342
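
For example, a minimal sanity check along these lines (a sketch only; model stands in for however your training pipeline instantiates the full watermarking model) can show whether the combined checkpoint's keys match what the model expects:

import torch

# Hypothetical path to the combined checkpoint produced above.
ckpt = torch.load("ckpts/combined.pth", map_location="cpu")
ckpt_keys = set(ckpt["model"].keys())

# `model` is assumed to be the training-time watermarking model
# (generator + detector) built by your pipeline.
model_keys = set(model.state_dict().keys())

print("only in checkpoint:", sorted(ckpt_keys - model_keys))
print("only in model:", sorted(model_keys - ckpt_keys))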

> Thanks, I see. Do you have any hints on how you would do fine-tuning instead?

It depends on what you want to achieve: if you want to make the detector more robust, you could freeze the generator. Also make sure that dora run ... continue_from loads the correct parameters that you need for training.
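
As a rough illustration of the freezing idea (a sketch only; it assumes the training model exposes generator and detector submodules, matching the key prefixes above):

# Freeze the generator so that only the detector (and any other remaining
# trainable parts) is updated during fine-tuning.
for param in model.generator.parameters():
    param.requires_grad = False

# Quick check on what is left to train.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {n_trainable}")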

crisostomi commented 2 months ago

@hadyelsahar thanks for the answer!

I guess we are still missing the weights for the adversarial losses needed to resume training from the pre-trained checkpoint, right?

pierrefdz commented 1 month ago

Yes indeed. Have you tried fine-tuning without the pre-loaded weights for the adversarial networks?
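
One way to sketch that (an illustration only; it assumes model is the training-time model, and that anything not covered by the released weights, such as the adversarial networks, simply keeps its fresh initialization):

import torch

# Hypothetical path to the combined checkpoint from earlier in the thread.
ckpt = torch.load("ckpts/combined.pth", map_location="cpu")

# strict=False tolerates missing / unexpected keys, so any part of the model
# not present in the combined checkpoint keeps its fresh initialization.
missing, unexpected = model.load_state_dict(ckpt["model"], strict=False)
print("missing (kept at their initialization):", missing)
print("unexpected:", unexpected)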