Many third-party libraries use factory functions to load or create models, so users only ever work with model instances, not with the classes themselves. For example:
Timm:
import timm
model = timm.create_model('resnet18', pretrained=True)
# User fine tunes...
Torch Hub:
import torch
model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
# User fine tunes...
Now, if we want to let users push these models to the Hub, they either have to write all the saving/pushing logic by hand, or they have to wrap the model instance in another class that mixes in PyTorchModelHubMixin. Wouldn't it be easier to just call a function like save_pretrained_pytorch(model, save_dir)? The mixin could then reuse that function, so we're only writing the code in one place.
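To make the proposal concrete, here is a rough sketch of what such a function and its reuse from the mixin could look like. Everything in it is an assumption for illustration: the exact signature of save_pretrained_pytorch, the file names pytorch_model.bin and config.json, the save_fn parameter, and the HubMixinSketch class (a stand-in, not the real PyTorchModelHubMixin). Only torch.save and state_dict() are existing PyTorch calls.

```python
import json
from pathlib import Path


def save_pretrained_pytorch(model, save_dir, config=None, save_fn=None):
    """Hypothetical standalone saver for any instance exposing state_dict().

    save_fn is an illustrative hook (not part of any existing API); when it
    is omitted, the weights are written with torch.save.
    """
    if save_fn is None:
        import torch  # imported lazily so the sketch has no hard dependency

        save_fn = torch.save

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # File names are assumptions made for this sketch.
    weights_path = save_dir / "pytorch_model.bin"
    save_fn(model.state_dict(), weights_path)

    if config is not None:
        (save_dir / "config.json").write_text(json.dumps(config, indent=2))

    return weights_path


class HubMixinSketch:
    """Stand-in for a mixin such as PyTorchModelHubMixin (not the real class)."""

    def save_pretrained(self, save_dir, config=None, save_fn=None):
        # The mixin only delegates, so the saving logic lives in one place.
        return save_pretrained_pytorch(
            self, save_dir, config=config, save_fn=save_fn
        )
```

With something along these lines, a user who got a model from timm.create_model or torch.hub.load could call save_pretrained_pytorch(model, "my-resnet18") directly on the instance, without wrapping it in a class.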
Keras is another great example of this. Users usually build models with the Functional or Sequential APIs instead of defining classes themselves, so a mixin doesn't help them much.
Additional context
We explore using functions for Keras here - #284
Additional inspiration - #310