kornia / kornia

Geometric Computer Vision Library for Spatial AI
https://kornia.readthedocs.io
Apache License 2.0

Add model weights to the test suite cache #2987

Open: johnnv1 opened this issue 3 weeks ago

johnnv1 commented 3 weeks ago

We should add the URLs of the weights used internally by kornia to the script https://github.com/kornia/kornia/blob/93114bf3f499eaac7c5f0f25f3e53ec356b191e2/.github/download-models-weights.py#L6 so that these weights can be cached between test jobs on GitHub Actions.
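The cache only pays off if the test jobs look for the weights in the same directory the script filled. As a minimal sketch, assuming the files are stored the way `torch.hub.load_state_dict_from_url` names them when `model_dir` is passed explicitly (as the script does), i.e. the last path component of the URL placed directly inside `model_dir`, a test could check for a cache hit like this. `cached_weight_path` and `is_cached` are hypothetical helpers, not existing kornia code.

```python
import os
from urllib.parse import urlparse


def cached_weight_path(url: str, model_dir: str) -> str:
    """Hypothetical helper: the file path torch.hub uses for `url` inside an explicit `model_dir`.

    By default torch.hub names the file after the last component of the URL path
    (any query string is ignored).
    """
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(model_dir, filename)


def is_cached(url: str, model_dir: str) -> bool:
    """Hypothetical helper: True if the weights for `url` are already present in `model_dir`."""
    return os.path.isfile(cached_weight_path(url, model_dir))
```

For the SOLD2 URL used in the script below, `cached_weight_path(url, "target_directory")` points at `sold2_wireframe.pth` inside `target_directory`, so restoring that directory from the Actions cache is enough for `load_state_dict_from_url` to skip the download.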

shijianjian commented 3 weeks ago

Take StableDiffusion in diffusers, for example. I would like to download the model weights and all the other files it needs. Can I import diffusers directly here and let it handle the download? Something like:

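import argparse

import torch
from diffusers import StableDiffusionPipeline

# Model name -> (download source, URL or Hugging Face repo id).
models = {
    "sold2_wireframe": ("torchhub", "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"),
    "stablediffusion-v1.4": ("StableDiffusionPipeline", "CompVis/stable-diffusion-v1-4"),
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser("WeightsDownloader")
    parser.add_argument("--target_directory", "-t", required=False, default="target_directory")

    args = parser.parse_args()

    torch.hub.set_dir(args.target_directory)

    for name, (src, url) in models.items():
        print(f"Downloading weights of `{name}` from `{src}` ({url}). Caching to dir `{args.target_directory}`")
        if src == "torchhub":
            # Single checkpoint file, stored directly in the target directory.
            torch.hub.load_state_dict_from_url(url, model_dir=args.target_directory, map_location=torch.device("cpu"))
        elif src == "StableDiffusionPipeline":
            # diffusers fetches the weights plus the configs and other pipeline files,
            # and stores them under `cache_dir` instead of the default Hugging Face cache.
            StableDiffusionPipeline.from_pretrained(url, cache_dir=args.target_directory)
        else:
            # Only unknown sources are an error; known ones fall through to the next model.
            raise ValueError(f"Unknown source `{src}` for `{name}`")

    raise SystemExit(0)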
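One way to keep such a script easy to extend, sketched here under our own assumptions, is to map each source to a small downloader function and import torch or diffusers lazily inside it, so a job that only needs torch.hub weights does not need diffusers installed. `MODELS`, `DOWNLOADERS`, `download_torchhub`, `download_diffusers` and `download_all` are hypothetical names, not existing kornia code.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical registry: model name -> (download source, URL or Hugging Face repo id).
MODELS: Dict[str, Tuple[str, str]] = {
    "sold2_wireframe": ("torchhub", "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"),
    "stablediffusion-v1.4": ("StableDiffusionPipeline", "CompVis/stable-diffusion-v1-4"),
}


def download_torchhub(url: str, target_dir: str) -> None:
    # Imported lazily so jobs that never use this source do not need torch installed.
    import torch

    torch.hub.load_state_dict_from_url(url, model_dir=target_dir, map_location=torch.device("cpu"))


def download_diffusers(repo_id: str, target_dir: str) -> None:
    # Imported lazily: diffusers is only required when a diffusers model is listed.
    from diffusers import StableDiffusionPipeline

    StableDiffusionPipeline.from_pretrained(repo_id, cache_dir=target_dir)


# Hypothetical dispatch table: source name -> downloader function.
DOWNLOADERS: Dict[str, Callable[[str, str], None]] = {
    "torchhub": download_torchhub,
    "StableDiffusionPipeline": download_diffusers,
}


def download_all(
    models: Dict[str, Tuple[str, str]],
    downloaders: Dict[str, Callable[[str, str], None]],
    target_dir: str,
) -> None:
    # Validate every source up front so a typo fails before any large download starts.
    unknown: List[str] = [name for name, (src, _) in models.items() if src not in downloaders]
    if unknown:
        raise ValueError(f"Unknown download source for: {', '.join(unknown)}")
    for name, (src, url) in models.items():
        print(f"Downloading weights of `{name}` from `{src}` ({url}). Caching to dir `{target_dir}`")
        downloaders[src](url, target_dir)
```

In the script's `__main__` block the loop would then become `download_all(MODELS, DOWNLOADERS, args.target_directory)`. Adding a new kind of model means adding one function and one table entry. Note that tests loading the diffusers model later need to pass the same `cache_dir` (or point the Hugging Face cache at that directory) to hit the cached files.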