Linaqruf / kohya-trainer

Adapted from https://note.com/kohya_ss/n/nbf7ce8d80f29 for easier cloning
Apache License 2.0

Resume training not working, and a simple solution #322

Open TsGrolken opened 9 months ago

TsGrolken commented 9 months ago

When resuming from a checkpoint, the trainer prints: load network weights from /content/drive/MyDrive/XXXXX.safetensors: None

I looked briefly into the code and found that it always passes 'FALSE' for the 'dtype' parameter in the 'load_weights' function, while the correct 'dtype' is always passed when saving the checkpoint. There is a simple fix for it:

In file 'kohya-trainer/train_network.py', starting at line 206, replace

    if args.network_weights is not None:
        info = network.load_weights(args.network_weights)
        print(f"load network weights from {args.network_weights}: {info}")

with:

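    def load_weights_2(network, file, dtype):
        # Read the checkpoint into a state dict, choosing the loader by file extension.
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        # Pass dtype as the second positional argument, in place of the original False.
        info = network.load_state_dict(weights_sd, dtype)
        return info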

    if args.network_weights is not None:
        info = load_weights_2(network, args.network_weights, save_dtype)
        print(f"load network weights from {args.network_weights}: {info}")

This should work for LoRA, LoCon and LoHa; I am not sure whether it works for XL. The change is easy to make in Google Colab's editing mode. I hope it gets fixed officially soon.
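
One thing worth checking before relying on this: in plain PyTorch, the second positional argument of `torch.nn.Module.load_state_dict` is `strict`, not a dtype, so a value passed there is read as a truthy or falsy flag. The sketch below is a variant that keeps the two concerns separate. It assumes `network` is an ordinary `nn.Module` that does not override `load_state_dict`; `load_weights_with_dtype` is a hypothetical helper name and is not part of kohya-trainer.

```python
import torch
from torch import nn


def load_weights_with_dtype(network, weights_sd, dtype=None):
    """Hypothetical helper: load a state dict non-strictly, then cast the module."""
    # strict=False reports missing/unexpected keys instead of raising an error.
    info = network.load_state_dict(weights_sd, strict=False)
    if dtype is not None:
        # Module.to(dtype) casts floating-point parameters and buffers in place.
        network.to(dtype)
    return info


# Full match: every key in the state dict exists in the module.
net = nn.Linear(4, 2)
full_sd = {k: v.clone() for k, v in nn.Linear(4, 2).state_dict().items()}
info = load_weights_with_dtype(net, full_sd, torch.float16)
print(info)              # key report returned by load_state_dict
print(net.weight.dtype)  # dtype after the cast

# Partial match: the bias is absent, so it shows up in missing_keys.
partial_net = nn.Linear(4, 2)
partial_info = load_weights_with_dtype(partial_net, {"weight": torch.zeros(2, 4)})
print(partial_info.missing_keys)
```

The returned object lists `missing_keys` and `unexpected_keys`, which is more informative to print after a resume than a bare `None`.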