prosperolo / GST

Official implementation of "GST: Precise 3D Human Body from a Single Image with Gaussian Splatting Transformers"
https://abdullahamdi.com/gst/
BSD 3-Clause "New" or "Revised" License

Improve HF integration #5

Open NielsRogge opened 1 month ago

NielsRogge commented 1 month ago

Hi @prosperolo,

Thanks for this nice work! Niels here from HF.

I noticed you already use the 🤗 hub for loading the model, which is great!

This PR aims to improve the integration by leveraging the PyTorchModelHubMixin class, which lets the model inherit the from_pretrained and push_to_hub methods.

Usage is as follows:

from scene.hmr2_extension import load_hmr_predictor, GaussianHMRPredictor

# define model
model = load_hmr_predictor(...)

# equip with weights
model.load_state_dict(...)

# push to hub
model.push_to_hub("your-hf-username-or-org/your-model")

# reload
model = GaussianHMRPredictor.from_pretrained("your-hf-username-or-org/your-model")

This means people don't need to manually download a checkpoint to their local environment first; the weights load automatically from the hub.
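To illustrate the mixin idea, here is a toy sketch in plain Python. It is not the huggingface_hub implementation: LocalHubMixin, ToyPredictor and round_trip are hypothetical names, a local directory stands in for the hub, and JSON stands in for a real checkpoint format. It only shows how a class gains save/load methods by inheriting from a mixin.

```python
import json
import tempfile
from pathlib import Path


class LocalHubMixin:
    """Toy stand-in for a hub mixin (hypothetical, not huggingface_hub code)."""

    def save_pretrained(self, directory):
        # Write the object's state to a JSON file inside `directory`.
        path = Path(directory)
        path.mkdir(parents=True, exist_ok=True)
        (path / "state.json").write_text(json.dumps(self.state_dict()))

    @classmethod
    def from_pretrained(cls, directory):
        # Build a fresh instance and restore its state from `directory`.
        state = json.loads((Path(directory) / "state.json").read_text())
        model = cls()
        model.load_state_dict(state)
        return model


class ToyPredictor(LocalHubMixin):
    """Hypothetical model: it defines only its own state handling."""

    def __init__(self):
        self.weights = [0.0, 0.0]

    def state_dict(self):
        return {"weights": self.weights}

    def load_state_dict(self, state):
        self.weights = list(state["weights"])


def round_trip(weights):
    # Save a model to a temporary directory, then reload it via the mixin.
    model = ToyPredictor()
    model.load_state_dict({"weights": weights})
    with tempfile.TemporaryDirectory() as tmp:
        model.save_pretrained(tmp)
        return ToyPredictor.from_pretrained(tmp).weights
```

ToyPredictor never defines save_pretrained or from_pretrained itself; both come from the mixin, which is the same pattern the PR relies on for GaussianHMRPredictor.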

Would you be interested in this integration?

Kind regards,

Niels