learn2phoenix / CSD

MIT License

Add HF integration #7

Closed NielsRogge closed 3 weeks ago

NielsRogge commented 1 month ago

Hi @learn2phoenix,

Thanks for this nice work! I wrote a quick PoC to show how easily this repo can be integrated with the 🤗 Hub. The integration lets you automatically load the model using from_pretrained (and push it using push_to_hub), track download numbers for your models (similar to models in the Transformers library), and have a nice model card for each model. Perhaps most importantly, it lets you store the weights with safetensors instead of pickle.

Also, this greatly improves the discoverability of your model (it's currently hosted on Google Drive, which makes it hard to find).

It leverages the PyTorchModelHubMixin class, from which the model class inherits these methods.

Usage is as follows:

from CSD.model import CSD_CLIP
from huggingface_hub import hf_hub_download
import torch

# instantiate model
model = CSD_CLIP("resnet50", "default")

# equip model with weights
filepath = hf_hub_download(repo_id="tomg-group-umd/CSD-ViT-L", filename="checkpoint.pth", repo_type="model")
state_dict = torch.load(filepath, map_location="cpu")
model.load_state_dict(state_dict)

# push to hub
model.push_to_hub("your-hf-org-or-username/csd-clip")

# reload
model = CSD_CLIP.from_pretrained("your-hf-org-or-username/csd-clip")

This means people don't need to manually download a checkpoint to their local environment first; the model loads automatically from the hub.

Would you be interested in this integration?

Kind regards,

Niels

Note

Please don't merge this PR before pushing the model to the hub :)

learn2phoenix commented 1 month ago

@NielsRogge Thanks for suggesting the integration and I agree it is valuable.

The model is available at https://huggingface.co/tomg-group-umd/CSD-ViT-L. Could you update the PR to work with it?

NielsRogge commented 1 month ago

Hi, thanks for linking the model.

Sure, although would it be possible to try out the script above using my branch? I've updated it to load the weights before pushing to the hub.

learn2phoenix commented 1 month ago

@NielsRogge

This seems to work with the following changes to the usage you suggested:

Include the following function in CSD/model.py:

import torch


def load_model_state(model, state_dict_path):
    # Load the full checkpoint dictionary onto the CPU
    full_state_dict = torch.load(state_dict_path, map_location="cpu")

    # Extract only the model state dictionary
    if "model_state_dict" in full_state_dict:
        model_state_dict = full_state_dict["model_state_dict"]
    else:
        model_state_dict = full_state_dict  # Assume it's already just the model state

    # Load the state dictionary, ignoring missing and unexpected keys
    model.load_state_dict(model_state_dict, strict=False)

    return model

Then, to load the model:

from CSD.model import CSD_CLIP, load_model_state
from huggingface_hub import hf_hub_download

# instantiate the ViT-L model
model = CSD_CLIP(name="vit_large", content_proj_head="default")
# download the checkpoint from the hub and equip the model with its weights
filepath = hf_hub_download(repo_id="tomg-group-umd/CSD-ViT-L", filename="checkpoint.pth", repo_type="model")
model = load_model_state(model, filepath)
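Because `strict=False` silently skips any key that does not match, it can hide a checkpoint that was only partially loaded. Below is a minimal sketch of a variant that reports what was skipped, assuming the same `model_state_dict` layout handled above. The helper name `load_model_state_verbose` and the optional stripping of a `module.` prefix (which appears when a model is saved while wrapped in `DistributedDataParallel`) are illustrative assumptions, not part of the CSD code.

```python
import torch
from torch import nn


def load_model_state_verbose(model: nn.Module, full_state_dict: dict, prefix: str = "module."):
    """Hypothetical helper: load a checkpoint dict non-strictly and report mismatches."""
    # Unwrap the {"model_state_dict": ...} layout if present.
    state_dict = full_state_dict.get("model_state_dict", full_state_dict)
    # Strip a DDP-style prefix so keys line up with the unwrapped model (assumption).
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    # load_state_dict returns a named tuple listing missing and unexpected keys.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print("missing keys:", result.missing_keys)
    if result.unexpected_keys:
        print("unexpected keys:", result.unexpected_keys)
    return model, result


if __name__ == "__main__":
    toy = nn.Linear(4, 2)
    checkpoint = {
        "model_state_dict": {"module.weight": torch.ones(2, 4), "module.extra": torch.zeros(1)},
        "epoch": 3,
    }
    # Prints that "bias" is missing and "extra" is unexpected.
    toy, report = load_model_state_verbose(toy, checkpoint)
```

With the real checkpoint, the second argument would be `torch.load(filepath, map_location="cpu")`.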

Could you include the usage instructions in README.md? Also, let's remove the push_to_hub usage. That is something users won't be able to do on our repo.

Could you make these changes? I can then merge the PR.

NielsRogge commented 1 month ago

Cool! OK, yes, I'll remove the push_to_hub usage; it was just there to show how you could push the weights to https://huggingface.co/tomg-group-umd/CSD-ViT-L.

Would you be able to overwrite the weights in that repo? A model.safetensors file will be pushed if you call push_to_hub.

learn2phoenix commented 1 month ago

I don't think we need to overwrite. That repo has limited write access, so all's safe there.

NielsRogge commented 1 month ago

Oh ok, to which repository could you push the weights in that case? Do you have a personal HF profile?

learn2phoenix commented 1 month ago

Why do we want to push the weights? I have already pushed the CSD weights to this repo so that users can load them from Hugging Face. Am I missing something here? Sorry, my experience with the latest Hugging Face features is limited.

NielsRogge commented 1 month ago

The reason I would overwrite the model repo (which is the purpose of this PR) is shown in this screenshot:

[Screenshot 2024-09-02 at 17.55.10]

learn2phoenix commented 1 month ago

Fair enough. I'm beginning to understand how HF pages work; thanks for pointing me in this direction. Here are the things I've done:

Users still don't need to overwrite the model repo themselves. Let me know if you think otherwise.

NielsRogge commented 1 month ago

Ok great :) some remarks:

The purpose of this PR is to let users load the model automatically from the hub, without having to download a checkpoint separately and put it in a folder. Could you clarify where you use a from_pretrained call?

learn2phoenix commented 1 month ago

Thanks for the review :)

I did not use from_pretrained in the current code because we have released weights only for the ViT-L model, and I am not sure it makes much sense to have a special case for just this one model.
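One way to avoid a special case while only ViT-L weights are public is a small lookup table from architecture name to Hub repo, which can grow as more checkpoints are released. A minimal sketch; the `HUB_REPOS` table and the `hub_repo_for` helper are hypothetical, and only the `vit_large` name and the `tomg-group-umd/CSD-ViT-L` repo come from this thread.

```python
# Hypothetical registry: CSD_CLIP architecture name -> Hub repo id.
# Only ViT-L weights are released so far; add entries as more appear.
HUB_REPOS = {
    "vit_large": "tomg-group-umd/CSD-ViT-L",
}


def hub_repo_for(arch: str) -> str:
    """Return the Hub repo id for an architecture, or fail with a helpful message."""
    try:
        return HUB_REPOS[arch]
    except KeyError:
        available = ", ".join(sorted(HUB_REPOS))
        raise ValueError(
            f"No released weights for {arch!r}; available: {available}"
        ) from None


# Usage (needs network access and the CSD package):
#   model = CSD_CLIP.from_pretrained(hub_repo_for("vit_large"))
```

The loading code itself then stays the same for every architecture; only the table changes.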

NielsRogge commented 1 month ago

Thanks for the quick updates!

If I understand correctly, you made changes to the existing HF model repo so that this branch works out-of-the-box? I just tried the following (here's a reproducer):

from CSD.model import CSD_CLIP

# load from the hub
model = CSD_CLIP.from_pretrained("tomg-group-umd/CSD-ViT-L")

but that results in:

ModuleNotFoundError: No module named 'CSD.model'
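A common general cause of this error is that the directory containing the `CSD` package is not on `sys.path`, for example when the script runs from a different working directory than the repo root. That is a generic Python cause, not a diagnosis of this particular report. Here is a small sketch to check where (or whether) Python would import the module from; `locate_module` is a hypothetical helper built on `importlib.util.find_spec`.

```python
import importlib.util
import sys
from typing import Optional


def locate_module(name: str) -> Optional[str]:
    """Return the path Python would import `name` from, or None if it is not importable."""
    try:
        spec = importlib.util.find_spec(name)
    except ModuleNotFoundError:
        # find_spec imports parent packages first, so a missing "CSD"
        # raises here when looking up "CSD.model".
        return None
    if spec is None:
        return None
    # Namespace packages (no __init__.py) have no file origin.
    return spec.origin or "<namespace package>"


if __name__ == "__main__":
    origin = locate_module("CSD.model")
    if origin is None:
        print("CSD.model is not importable from these sys.path entries:")
        for entry in sys.path:
            print("  ", entry)
    else:
        print("CSD.model would be imported from", origin)
```
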
learn2phoenix commented 1 month ago

Strange. This works for me:

[screenshot]

I ran exactly the same code:

from CSD.model import CSD_CLIP

# load from the hub
model = CSD_CLIP.from_pretrained("tomg-group-umd/CSD-ViT-L")

Any idea what could be happening here?

learn2phoenix commented 3 weeks ago

@NielsRogge Were you able to make this work, or are you still facing issues? I would like to solve this together and close this PR.

NielsRogge commented 3 weeks ago

It's working fine for me, so if you are up for it, this PR can be merged.

I see the model was downloaded 28 times last month: https://huggingface.co/tomg-group-umd/CSD-ViT-L.

Note that WikiArt is also available at https://huggingface.co/datasets/huggan/wikiart, which means that generating embeddings now takes just the following lines :)

from datasets import load_dataset
from CSD.model import CSD_CLIP
import torch
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

@torch.no_grad()
def embed_batch(batch):
    # convert to RGB so grayscale/CMYK images also yield 3-channel tensors
    images = [image.convert("RGB") for image in batch["image"]]

    # prepare for the model
    images = [transform(image) for image in images]
    images = torch.stack(images, dim=0)

    # embed
    embeddings = model(images)
    batch["embeddings"] = embeddings.numpy()

    return batch

# load from the hub and switch to inference mode
model = CSD_CLIP.from_pretrained("tomg-group-umd/CSD-ViT-L")
model.eval()

# load wikiart
dataset = load_dataset("huggan/wikiart")

# embed
dataset_with_embeddings = dataset.map(embed_batch, batched=True, batch_size=2)

This is just a PoC (the image transforms still need to be checked) and an FYI :) It might be cool to add this to the model card.

learn2phoenix commented 3 weeks ago

Thanks @NielsRogge for this PR. Appreciate the continuous feedback.