The PyTorch domain libraries don't have a standard way to perform Model Versioning. With the term "Model Versioning" we denote the problem of concurrently maintaining multiple versions of pre-trained weights and handling changes to the model code, both in a Backwards Compatible (BC) and in a BC-breaking manner.
Establish a common approach for handling Model Versioning across all domain libraries.
Currently all domain libraries offer comparable APIs for initializing models of popular architectures. Those that already offer pre-trained weights (TorchAudio and TorchVision) have adopted equivalent solutions [1, 2] while those that plan to introduce them for the first time (TorchText) aligned with existing practices [3, 4].
While the model building process is fairly standardized across domains, the model versioning isn't. Fortunately, the needs across all domains are very similar. We all have to support multiple weights, handle model code changes, strive for reproducibility etc. Standardizing the way we do this across domains is important because:
This is not the first time the model versioning problem arises [7] as it has previously been discussed by the domain library maintainers.
In TorchVision, model versioning issues have been handled on a case-by-case basis [8, 9, 10, 11, 12]. We typically try to maintain BC as much as possible, except in cases where the issue is considered a bug or extremely detrimental to the user experience. In the latter cases, we often deploy BC-breaking fixes but try to reduce the effects of the change as much as possible [13]. TorchAudio has only recently introduced pre-trained models and their API is currently in beta. TorchText is currently working on introducing its first pre-trained models, so their solution is in prototype phase. Finally, though PyTorch Core doesn't provide pre-trained models, occasionally there is a need to handle code changes on existing layers and operators. These are typically handled with a mix of version parameters, deprecation warnings and method renamings [14, 15, 16].
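As a sketch of the version-parameter pattern mentioned above (the operator and its names are hypothetical, not an actual PyTorch API), a behavioural change can be gated behind a `version` argument while the old behaviour is kept for BC with a deprecation warning:

```python
import warnings


def my_op(x: int, version: int = 1) -> int:
    """Hypothetical operator whose behaviour was corrected in version 2."""
    if version == 1:
        warnings.warn(
            "version=1 of my_op is deprecated; pass version=2 for the fixed behaviour.",
            DeprecationWarning,
        )
        return x * 2       # old behaviour, kept for backwards compatibility
    if version == 2:
        return x * 2 + 1   # corrected behaviour
    raise ValueError(f"unsupported version: {version}")
```

Callers keep the old (default) behaviour unchanged but are nudged toward the new one; after a deprecation period the default can be flipped.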
This RFC comes with a companion repository. Amendments to this RFC should be made by sending a Pull Request to the repo.
The repository aims to serve as a live RFC document capable of showcasing the proposed API and utilities, providing examples of how to address the most common model-versioning scenarios, and offering actual implementations for some of the real-world models included in the Domain libraries. Here is its structure:
- `README.md` file serves as the main RFC document.
- `examples` folder contains standalone implementations for the most common model versioning scenarios that we've faced before. At the top of each scenario file we include a description and an example with references to a real-world case. We recommend starting from there.
- `dapi_lib` package:
  - `dapi_lib/models/_api.py` contains the majority of the utilities used by the API.
  - `dapi_lib/models/*.py` are implementations of real-world models from different domains.
  - `dapi_lib/datasets` and `dapi_lib/transforms` packages contain code taken from the domain libs and adapted for the needs of this RFC. They are there purely to make the demos run smoothly and should not be considered part of this proposal.
- `*.py` files located at the root of the repo. They aim to show how the API looks from the user's perspective:
  - `image_classification.py` and `image_detection.py` showcase the new API on Vision.
  - `text_encoding.py` gives an example of how Text could structure its models. Note that because TorchText currently doesn't provide pre-trained models on the public repo, we use huggingface's Roberta.
  - `text_to_speech.py` provides an example of implementing the new API on Audio.
- `assets` folder contains a couple of assets necessary for the demos.
- `output` folder will be created at the root of the project after running the demos.
- `third_party` package contains copy-pasted code necessary for the demos; it is not part of this RFC.
- `requirements.txt` file contains a list of all dependencies for running the code in this repo.

The proposed solution must meet the following criteria:
- … (e.g. `resnet50`) along with the available pre-trained weights by reusing as many standard dev tools from the Python ecosystem as possible.

The proposed solution can optionally support the following nice-to-haves:
We propose using separate model builders and a weights parameter for each model version. We plan to keep the existing model builder methods supported by all Domain libraries for constructing models, and to use Enums with dataclass values to pass the information of the pre-trained weights. Each model variant will have its own method and weights. When BC-breaking changes are necessary, we will introduce new builder methods to keep things BC.
High-level API implementation in pseudocode:
```python
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, Optional


@dataclass
class WeightEntry:
    url: str              # Weights URL/path
    transforms: Callable  # Preprocessing transform constructor
    meta: Dict[str, Any]  # Arbitrary meta-data
    # Other customizable fields go here


class ResNet50Weights(Enum):
    ImageNet1K_RefV1 = WeightEntry(
        url="https://path/to/weights.pth",
        transforms=partial(ImageNetPreprocessing, width=224),
        meta={"num_classes": 1000, "Acc@1": 76.130, "classes": [...]},
    )
    CIFAR100_RefV1 = WeightEntry(...)


def resnet50(weights: Optional[ResNet50Weights] = None) -> ResNet:
    # Model construction and weight loading go here
    pass


# When BC-breaking changes are unavoidable, we will provide new builder
# methods to keep things BC.
class ResNet50V2Weights(Enum):
    ImageNet1K_RefV1 = WeightEntry(...)


def resnet50_v2(weights: Optional[ResNet50V2Weights] = None) -> ResNetV2:  # Assume a new class is needed
    pass
```
Example of using the API:
```python
# Model initialization and weights loading
weights = ResNet50Weights.ImageNet1K_RefV1
model = resnet50(weights=weights).eval()

# Fetch the preprocessing transforms and apply them to the image
preprocess = weights.transforms()
batch = preprocess(img).unsqueeze(0)

# Make predictions
prediction = model(batch).squeeze(0).softmax(0)

# Use the meta-data
label = prediction.argmax().item()
class_name = weights.meta['classes'][label]
```
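To make the pattern concrete, here is a self-contained toy version of it (the model, weight names, and metadata are stand-ins invented for this sketch; the actual utilities live in `dapi_lib/models/_api.py`). It also shows how, as a side benefit, the Enum makes the available weights discoverable with standard Python tools:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, Optional


@dataclass
class WeightEntry:
    url: str              # Weights URL/path
    transforms: Callable  # Zero-arg constructor returning the transform
    meta: Dict[str, Any]  # Arbitrary meta-data


def _double(x):  # stand-in for a preprocessing transform
    return x * 2


class ToyWeights(Enum):
    Demo_RefV1 = WeightEntry(
        url="https://path/to/weights.pth",
        transforms=lambda: _double,
        meta={"classes": ["cat", "dog"]},
    )

    # Forward the entry's fields so users can write weights.transforms()
    # and weights.meta, as in the usage example above.
    @property
    def transforms(self) -> Callable:
        return self.value.transforms

    @property
    def meta(self) -> Dict[str, Any]:
        return self.value.meta


class ToyModel:
    """Stub model: only records where its weights came from."""
    def __init__(self, weights_url: Optional[str] = None):
        self.weights_url = weights_url


def toy_model(weights: Optional[ToyWeights] = None) -> ToyModel:
    # A real builder would construct the architecture and load the state dict.
    return ToyModel(weights.value.url if weights is not None else None)


# Discoverability: standard tools enumerate the available weights.
print([w.name for w in ToyWeights])  # ['Demo_RefV1']
```

Note that properties defined on the Enum are not turned into members, so the enum body can carry both the weight entries and the convenience accessors.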
The above approach:
This section provides some guidance on how to handle special cases that were brought up during the review of the RFC. When multiple valid options exist, the DAPI libs should choose the one that meets their needs best:

- … one for `preprocessing` and one for `postprocessing`.
- … `preprocessing` and `postprocessing` on their `forward()`/`__call__()` method. Then we will offer two separate fields for them on the `WeightEntry` class.
- … `backbone` to the composite Model classes, but this topic is beyond the scope of this RFC.

To prove that the proposed API can accommodate all domains, we implemented it on 4 real-world models. To see the demos, run the following commands from the root of the repo:
Check the model implementation and the model usage.

```
$ python -u image_classification.py
golden retriever 0.9381255507469177
```

Check the model implementation and the model usage.

```
$ python -u image_detection.py
Saving picture at ./output/object-detection.jpg
```

Check the model implementation and the model usage.

```
$ python -u text_encoding.py
125000000 torch.Size([1, 5, 768])
```

Check the model implementation and the model usage.

```
$ python -u text_to_speech.py
38 torch.Size([1, 80, 112])
Saving wave at ./output/message.wav
```
Below we link directly to the actual implementations and code examples where we document everything extensively. The best way to see how the proposal works is to check the `examples` folder, where we focus on 3 Model Versioning scenarios that we had to address in the past:
Our proposal consists of the following key components:
We also offer two optional components:
Here we briefly list the alternatives we considered, along with some of the reasons we didn't select them. Note that in all cases we prefer using Enums to strings; to read more on why, check this section.
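As a toy illustration of the Enum-over-strings preference (all names here are made up for the sketch): invalid values can be rejected immediately with a clear error, and the valid options are discoverable programmatically:

```python
from enum import Enum
from typing import Optional


class Pretrained(Enum):  # toy stand-in for a weights enum
    ImageNet1K_RefV1 = "imagenet1k_ref_v1"


def build(weights: Optional[Pretrained] = None) -> Optional[Pretrained]:
    # An Enum lets us reject bad inputs up front; a free-form string with a
    # typo (e.g. "imagnet") might only fail deep inside the weight loader,
    # if at all.
    if weights is not None and not isinstance(weights, Pretrained):
        raise TypeError(f"weights must be a Pretrained member, got {weights!r}")
    return weights


# The valid options are enumerable with standard dev tools:
print([w.name for w in Pretrained])  # ['ImageNet1K_RefV1']
```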
```python
class ResNet50Weights(Enum):
    V1_NoWeights = WeightEntry(...)
    V1_ImageNet1K_RefV1 = WeightEntry(...)


def resnet50(weights: ResNet50Weights = ResNet50Weights.V1_NoWeights) -> nn.Module:
    pass
```
Pros:
Cons:
```python
class ResNet50Weights(Enum):
    ImageNet1K_RefV1 = WeightEntry(...)


def resnet50(version: int = 1, weights: Optional[ResNet50Weights] = None) -> nn.Module:
    pass
```
Pros:
Cons:
- … which `version` is compatible with which `weights` enum.

```python
def resnet50_v2_imagenet_ref1(pretrained: bool = False) -> ResNetV2:
    pass
```
Pros:
Cons:
The key components of the proposal should be independently adapted and implemented by the maintainers of each domain library. The degree to which the optional components are adopted can vary and depends on the needs of each domain. Since we currently don't have an agreed way to handle interdependencies between domain libraries, any of the adopted utilities can live in multiple domain repos, and can move to a common repo if one is introduced in the future.