NielsRogge closed this PR 3 weeks ago.
@NielsRogge Thanks for suggesting the integration; I agree it is valuable.
The model is available at https://huggingface.co/tomg-group-umd/CSD-ViT-L. Could you update the PR to work with this?
Hi, thanks for linking the model.
Sure, although would it be possible to try out the script above using my branch? I've updated it so that it loads the weights before pushing to the hub.
@NielsRogge
This seems to work with the following changes to the usage you suggested.
Include the following function in CSD/model.py:
```python
import torch


def load_model_state(model, state_dict_path):
    # Load the full checkpoint dictionary (on CPU, so no GPU is required)
    full_state_dict = torch.load(state_dict_path, map_location="cpu")
    # Extract only the model state dictionary
    if "model_state_dict" in full_state_dict:
        model_state_dict = full_state_dict["model_state_dict"]
    else:
        model_state_dict = full_state_dict  # Assume it's already just the model state
    # Load the state dictionary, ignoring missing keys
    model.load_state_dict(model_state_dict, strict=False)
    return model
```
Then load the weights from the Hub:

```python
from huggingface_hub import hf_hub_download

from CSD.model import CSD_CLIP, load_model_state

# Build the architecture
model = CSD_CLIP(name="vit_large", content_proj_head="default")
# Equip the model with the released weights
filepath = hf_hub_download(repo_id="tomg-group-umd/CSD-ViT-L", filename="checkpoint.pth", repo_type="model")
model = load_model_state(model, filepath)
```
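Because `load_model_state` calls `load_state_dict(..., strict=False)`, a checkpoint whose keys don't match the model loads without any error. PyTorch still reports the mismatches in the return value of `load_state_dict`, which has `missing_keys` and `unexpected_keys` fields. The sketch below shows how to surface them. `TinyModel` and the fake checkpoint are hypothetical stand-ins for `CSD_CLIP` and the real weights.

```python
import torch
from torch import nn


# Hypothetical stand-in for CSD_CLIP; the real model is much larger.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)


model = TinyModel()

# Fake checkpoint: it covers only the backbone and carries one stray key.
state_dict = {
    "backbone.weight": torch.zeros(4, 4),
    "backbone.bias": torch.zeros(4),
    "extra.weight": torch.ones(1),
}

# strict=False does not raise on mismatches; it reports them instead.
result = model.load_state_dict(state_dict, strict=False)
print("missing:", result.missing_keys)        # keys the model expected but did not get
print("unexpected:", result.unexpected_keys)  # keys in the checkpoint the model ignored
```

Printing or asserting on these two lists after loading is a cheap way to confirm that the weights were actually applied.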
Could you include the usage instructions in README.md? Also, let's remove the push_to_hub usage; that is something users won't be able to do on our repo.
Could you make these changes? I can then merge the PR.
Cool! OK, yes, I'll remove the `push_to_hub` call. It was only there to show how you could push the weights to https://huggingface.co/tomg-group-umd/CSD-ViT-L.
Would you be able to overwrite the weights in that repo? Calling `push_to_hub` will push a model.safetensors file.
I don't think we need to overwrite. That repo has limited write access, so all's safe there.
Oh ok, to which repository could you push the weights in that case? Do you have a personal HF profile?
Why do we want to push the weights? I have already pushed the CSD weights to this repo so that users can get them from Hugging Face. Am I missing something here? Sorry, my experience with the latest Hugging Face features is limited.
The reason I would overwrite the model repo (which is the purpose of this PR) is that users could then load the model with

```python
model = CSD_CLIP.from_pretrained("your-hf-org-or-username/csd-clip")
```

Downloads would also be tracked (each time someone calls `from_pretrained`, the download counter goes up by one), and people will find it at …

Fair enough. I'm beginning to understand how HF pages work; thanks for pointing me in this direction. Here is what I've done:

- Renamed checkpoint.pth to pytorch_model.bin. This makes `from_pretrained` work.
- Added config.json.
- Downloads should be tracked now, but the counter shows 0 (an improvement over not being tracked at all) even after I successfully downloaded the model through a `from_pretrained` call. Can you help with this?

Overwriting the model repo by the user is still not required. Let me know if you think otherwise.
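The rename step above can also be done programmatically. The sketch below turns a training checkpoint into a folder containing pytorch_model.bin and config.json. It assumes, as `load_model_state` does, that the weights may be wrapped under a `"model_state_dict"` key, which a plain file rename would not unwrap. The helper name `convert_checkpoint` is hypothetical. The config contents (`name`, `content_proj_head`) only mirror the `CSD_CLIP` constructor arguments and are an assumption, not the contents of the real repo's config.json.

```python
import json
import os
import tempfile

import torch
from torch import nn


def convert_checkpoint(checkpoint_path, output_dir, config):
    """Hypothetical helper: write pytorch_model.bin and config.json to output_dir."""
    full_state_dict = torch.load(checkpoint_path, map_location="cpu")
    # Keep only the weights if the checkpoint wraps them (e.g. alongside optimizer state)
    state_dict = full_state_dict.get("model_state_dict", full_state_dict)
    os.makedirs(output_dir, exist_ok=True)
    torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)
    return sorted(os.listdir(output_dir))


# Demo on a toy checkpoint (nn.Linear stands in for the real model)
with tempfile.TemporaryDirectory() as tmp:
    toy = nn.Linear(2, 2)
    ckpt_path = os.path.join(tmp, "checkpoint.pth")
    torch.save({"model_state_dict": toy.state_dict(), "epoch": 3}, ckpt_path)
    repo_dir = os.path.join(tmp, "repo")
    files = convert_checkpoint(
        ckpt_path, repo_dir, {"name": "vit_large", "content_proj_head": "default"}
    )
    reloaded_state = torch.load(os.path.join(repo_dir, "pytorch_model.bin"), map_location="cpu")
    with open(os.path.join(repo_dir, "config.json")) as f:
        saved_config = json.load(f)
    print(files)  # ['config.json', 'pytorch_model.bin']
```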
OK, great :) Some remarks:
Regarding the purpose of this PR: it would let users load the model from the hub automatically, without having to download a checkpoint separately and put it in a folder. Could you clarify where you used a `from_pretrained` call?
Thanks for the review :)
I did not really use `from_pretrained` in the current code, because we have released weights only for the ViT-L model, and I am not sure it makes much sense to have a special case for just this model.
Thanks for the quick updates!
If I understand correctly, you made changes to the existing HF model repo so that this branch works out-of-the-box? I just tried the following (here's a reproducer):
```python
from CSD.model import CSD_CLIP

# load from the hub
model = CSD_CLIP.from_pretrained("tomg-group-umd/CSD-ViT-L")
```

but that results in:

```
ModuleNotFoundError: No module named 'CSD.model'
```
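A `ModuleNotFoundError` for `CSD.model` often means Python cannot see the `CSD` package from the current interpreter, for example when the script runs outside the repository root. That is a guess, not something the thread confirms. The standard-library check below is one way to test it. `check_importable` is a hypothetical helper name.

```python
import importlib.util
import sys


def check_importable(package="CSD"):
    """Report whether a top-level package can be found by this interpreter."""
    # find_spec returns None for a missing top-level package instead of raising
    spec = importlib.util.find_spec(package)
    if spec is None:
        print(f"{package!r} is not importable; sys.path[0] is {sys.path[0]!r}")
        return False
    # Namespace packages (no __init__.py) have no origin, only search locations
    location = spec.origin or list(spec.submodule_search_locations or [])
    print(f"{package!r} found at {location}")
    return True


check_importable("CSD")
```

If this reports the package as missing, running from the repo root, or adding the repo root to `PYTHONPATH`, is the usual first thing to try.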
Strange, this works for me. I ran the exact same code:
```python
from CSD.model import CSD_CLIP

# load from the hub
model = CSD_CLIP.from_pretrained("tomg-group-umd/CSD-ViT-L")
```
Any idea what could be happening here?
@NielsRogge Were you able to make this work, or are you still facing issues? I would like to solve this together and close this PR.
It's working fine for me, so if you are up for it, this PR can be merged.
I see the model has been downloaded 28 times in the last month: https://huggingface.co/tomg-group-umd/CSD-ViT-L.
Do note that wikiart is also available here: https://huggingface.co/datasets/huggan/wikiart, which means that generating embeddings can now be done in just the following lines :)
```python
import torch
from datasets import load_dataset
from torchvision import transforms

from CSD.model import CSD_CLIP

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


@torch.no_grad()
def embed_batch(batch):
    # prepare for the model (convert to RGB so grayscale images also give 3 channels)
    images = [transform(image.convert("RGB")) for image in batch["image"]]
    images = torch.stack(images, dim=0)
    # embed (assumes model(images) returns a single tensor; adjust if forward returns a tuple)
    embeddings = model(images)
    batch["embeddings"] = embeddings.numpy()
    return batch


# load from the hub
model = CSD_CLIP.from_pretrained("tomg-group-umd/CSD-ViT-L")
# load wikiart
dataset = load_dataset("huggan/wikiart")
# embed
dataset_with_embeddings = dataset.map(embed_batch, batched=True, batch_size=2)
```
This is just a PoC (the image transforms still need to be checked) and an FYI :) It might be cool to add this to the model card.
Thanks @NielsRogge for this PR. Appreciate the continuous feedback.
Hi @learn2phoenix,
Thanks for this nice work! I wrote a quick PoC to show that you can easily integrate with the 🤗 hub. With this integration you can automatically load the model using `from_pretrained` (and push it using `push_to_hub`) and track download numbers for your models (similar to models in the Transformers library). You also get nice model cards on a per-model basis. Perhaps most importantly, the weights are stored as `safetensors` instead of pickle. This also greatly improves the discoverability of your model, which is currently hosted on Google Drive and hard to find.
It leverages the PyTorchModelHubMixin class, which lets the model class inherit these methods.
Usage is as follows:
This means people don't need to manually download a checkpoint first in their local environment, it just loads automatically from the hub.
Would you be interested in this integration?
Kind regards,
Niels
Note: please don't merge this PR before pushing the model to the hub :)