crisostomi opened 4 months ago
If I understand correctly, you need the checkpoint used in our ICML'24 paper for the fine-tuning?
Yes, I managed to load the pretrained generator and detector for inference, but I need the single, non-preprocessed checkpoint so I can use it as the initialization when fine-tuning with dora run ... continue_from <PATH_TO_CKPT>.
In case anyone runs into the same problem, I quickly wrote a helper function to combine the checkpoints back into a single one:
import torch


def combine(
    generator_checkpoint: str, detector_checkpoint: str, output_checkpoint: str
):
    """Combine generator and detector checkpoints into a single checkpoint."""
    gen_ckpt = torch.load(generator_checkpoint)
    det_ckpt = torch.load(detector_checkpoint)

    combined_ckpt = {
        "xp.cfg": gen_ckpt["xp.cfg"],  # assuming the configs are identical
        "model": {},
    }

    # Add generator layers with the appropriate prefix
    for layer in gen_ckpt["model"].keys():
        new_layer = f"generator.{layer}"
        combined_ckpt["model"][new_layer] = gen_ckpt["model"][layer]

    # Add detector layers with the appropriate prefix
    for layer in det_ckpt["model"].keys():
        new_layer = f"detector.{layer}"
        combined_ckpt["model"][new_layer] = det_ckpt["model"][layer]

    # Special case for 'msg_processor.msg_processor.weight'
    if "msg_processor.msg_processor.weight" in gen_ckpt["model"]:
        combined_ckpt["model"]["msg_processor.msg_processor.0.weight"] = gen_ckpt[
            "model"
        ]["msg_processor.msg_processor.weight"]

    torch.save(combined_ckpt, output_checkpoint)
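A hypothetical call would look like this (the file names below are just placeholders for wherever your checkpoints and output live):

combine(
    generator_checkpoint="generator.pth",   # placeholder path to the generator checkpoint
    detector_checkpoint="detector.pth",     # placeholder path to the detector checkpoint
    output_checkpoint="combined.pth",       # placeholder path for the combined checkpoint
)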
However, using the resulting checkpoint with continue_from doesn't seem to work properly: the logged metrics are as low as when starting from scratch.
Thanks, do you want to open a PR for this?
By the way, continue_from is not a real fine-tuning recipe but rather a way to continue a long-running / checkpointed training run. It performs full updates of the model parameters and can lead to catastrophic forgetting.
Sure, will do!
Thanks, I see. Do you have any hints on how you would do fine-tuning instead?
@crisostomi I know that loading checkpoints to resume training works well and is tested. I would say you need to debug here to make sure it does the right thing and your weights are loaded correctly.
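For what it's worth, a minimal sanity check along those lines (assuming the key mapping produced by the combine helper above, with placeholder file names) could compare a few tensors of the combined checkpoint against the original split ones before launching training:

import torch

# Placeholder paths; substitute the actual checkpoint locations.
gen_ckpt = torch.load("generator.pth", map_location="cpu")
combined = torch.load("combined.pth", map_location="cpu")

# Every generator tensor should reappear under the "generator." prefix with
# identical values; anything missing or different points to a key-mapping bug.
mismatched = []
for name, tensor in gen_ckpt["model"].items():
    combined_name = f"generator.{name}"
    if combined_name not in combined["model"]:
        mismatched.append(combined_name)
    elif not torch.equal(tensor, combined["model"][combined_name]):
        mismatched.append(combined_name)

print(f"{len(mismatched)} generator keys missing or mismatched:", mismatched[:10])

The same loop, repeated with the "detector." prefix, would cover the other half of the checkpoint.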
It depends on what you want to achieve: if you want to make the detector more robust, then you could freeze the generator (see the sketch below). Also, make sure that dora run ... continue_from loads the correct parameters for the training you need.
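Not an official recipe, but as a rough sketch (assuming the full model exposes the watermark generator as a generator submodule, as the "generator." checkpoint prefix above suggests), freezing it could look like:

import torch

def freeze_generator(model: torch.nn.Module) -> None:
    """Disable gradient updates for the generator so only the detector is trained."""
    # Assumes `model.generator` exists, mirroring the "generator." checkpoint prefix.
    for param in model.generator.parameters():
        param.requires_grad = False
    # Keep BatchNorm statistics fixed and disable Dropout in the frozen part.
    model.generator.eval()

The optimizer would then be built only over the parameters that still have requires_grad=True, so the frozen generator weights are never updated.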
@hadyelsahar thanks for the answer!
I guess we are still missing the weights for the adversarial losses needed to resume training from the pre-trained checkpoint, right?
Yes indeed. Have you tried fine-tuning without the pre-loaded weights for the adversarial networks?
Hello!
I am trying to fine-tune the pre-trained model on a new dataset, but I am having a hard time finding the checkpoint to provide as a continue_from argument to dora. The provided checkpoints are in fact already split into generator and detector, and cannot be used in the training pipeline, which expects a single checkpoint for the whole model. Am I missing something? Thanks a lot!