datumbox / dapi-model-versioning

RFC for Model Versioning across all PyTorch Domain libraries

DAPI Model Versioning RFC

  1. Introduction
    1. Problem definition
    2. Objective
    3. Motivation
  2. Previous work
  3. Repository structure
  4. Design
    1. Specifications
    2. Out of scope
    3. Proposal
    4. Special cases
    5. Demos
    6. Implementation details
    7. Alternatives considered
    8. Release plan
  5. Next steps

Introduction

Problem definition

The PyTorch domain libraries don't have a standard way to perform Model Versioning. With the term "Model Versioning" we denote the problem of concurrently maintaining multiple versions of pre-trained weights and of handling changes to the model code in both a Backwards Compatible (BC) and a BC-breaking manner.

Objective

Establish a common approach for handling Model Versioning across all domain libraries.

Motivation

Currently all domain libraries offer comparable APIs for initializing models of popular architectures. Those that already offer pre-trained weights (TorchAudio and TorchVision) have adopted equivalent solutions [1, 2], while those that plan to introduce them for the first time (TorchText) have aligned with existing practices [3, 4].

While the model building process is fairly standardized across domains, the model versioning process isn't. Fortunately, the needs across all domains are very similar: we all have to support multiple weights, handle model code changes, strive for reproducibility, etc. Standardizing the way we do this across domains is important because:

  1. Code changes affecting the models are common and part of the standard development cycle of all domains.
  2. Updating the model weights with more accurate ones trained using new recipes is a common problem [5]. The same applies to concurrently providing multiple versions of the weights trained on different Datasets (to cover different taxonomies, languages, etc.) [6].
  3. Providing an aligned API will improve the experience of users who combine the Domain libraries in their training recipes.

Previous work

This is not the first time the model versioning problem arises [7] as it has previously been discussed by the domain library maintainers.

In TorchVision, model-versioning issues have been handled on a case-by-case basis [8, 9, 10, 11, 12]. We typically try to maintain BC as much as possible, except in cases where the issue is considered a bug or extremely detrimental to the user experience. In the latter cases, we often deploy BC-breaking fixes but try to reduce the effects of the change as much as possible [13]. TorchAudio has only recently introduced pre-trained models and its API is currently in beta. TorchText is currently working on introducing its first pre-trained models, so its solution is in the prototype phase. Finally, though PyTorch Core doesn't provide pre-trained models, occasionally there is a need to handle code changes on existing Layers and operators. These are typically handled with a mix of version parameters, deprecation warnings and method renamings [14, 15, 16].

Repository structure

This RFC comes with a companion repository. Amendments to this RFC should be made by sending a Pull Request to the repo.

The repository aims to serve as a live RFC document capable of showcasing the proposed API and utilities, providing examples of how to address the most common model-versioning scenarios, and offering actual implementations for some of the real-world models included in the Domain libraries.

Design

Specifications

The proposed solution must meet the following criteria:

  1. Provide an API for supporting multiple pre-trained weights, along with the paradigm for versioning and handling code changes on models.
  2. Describe how the pre-trained weights link to their corresponding meta-data and to the preprocessing transforms that are necessary for using the models.
  3. Facilitate the discoverability of the various model variants (example: resnet50) and of their available pre-trained weights, by reusing as many standard dev tools from the Python ecosystem as possible.
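
As an illustration of (3), a weights Enum is discoverable with standard Python tooling alone. Below is a minimal sketch that uses the ResNet50Weights Enum defined in the Proposal section; the list_weights helper is hypothetical:

def list_weights(weights_enum):
    # Standard Enum iteration is enough to enumerate the available
    # pre-trained weights of a variant; IDE autocompletion works for free.
    for w in weights_enum:
        print(w.name, "->", w.value.meta)

list_weights(ResNet50Weights)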

The proposed solution can optionally support the following nice-to-haves:

  1. Encourage users to use the latest/greatest pre-trained weights and models.
  2. Provide utilities that simplify model-versioning in BC cases and reduce side-effects in BC-breaking cases.

Out of scope

  1. We focus on the main model building methods used for initializing models and loading their weights. Libraries can also offer secondary options for doing so (such as constructing models directly from Classes, supporting model registration mechanisms etc.), but these are out of scope for this proposal. Note that, when possible, we illustrate with code examples that our solution is compatible with such mechanisms (see the sketch after this list), but they are not part of our proposal.
  2. We focus on the public API and not on implementation details. Though the repo contains private helper methods used to construct models, these are not part of the proposal and libraries can adapt them to their needs and current practices.
  3. We try to keep this RFC not overly opinionated; instead, we describe a framework that gives the DAPI libs room to adapt the solution to their needs. As a result, fully specifying implementation details is beyond the scope of this proposal.
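
To illustrate the compatibility claim in (1) above, here is a hypothetical sketch of a model registration mechanism layered on top of the proposed per-variant builders; the register_model decorator and the registry are invented for illustration and are not part of the proposal:

_MODEL_REGISTRY = {}

def register_model(fn):
    # Hypothetical registry: map each builder's name to the builder itself
    _MODEL_REGISTRY[fn.__name__] = fn
    return fn

@register_model
def resnet50(weights=None):
    ...

# Builders remain addressable by name without changing the proposed API:
model = _MODEL_REGISTRY["resnet50"](weights=ResNet50Weights.ImageNet1K_RefV1)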

Proposal

We propose using separate model builders and a weights parameter for each model version. We plan to keep the existing model builder methods that all Domain libraries support for constructing models, and to use Enums with data class values to pass the pre-trained weight information. Each model variant will have its own builder method and weights. When BC-breaking changes are necessary, we will introduce new builder methods to keep things BC.

High-level API implementation in pseudocode:

from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, Optional

# ResNet, ResNetV2 and ImageNetPreprocessing are assumed to be defined
# elsewhere in the library.

@dataclass
class WeightEntry:
    url: str  # Weights URL/path
    transforms: Callable  # Preprocessing transform constructor
    meta: Dict[str, Any]  # Arbitrary Meta-Data
    # Other customizable fields go here

class ResNet50Weights(Enum):
    ImageNet1K_RefV1 = WeightEntry(
        url="https://path/to/weights.pth",
        transforms=partial(ImageNetPreprocessing, width=224),
        meta={"num_classes": 1000, "Acc@1": 76.130, "classes": [...]}
    )
    CIFAR100_RefV1 = WeightEntry(...)

def resnet50(weights: Optional[ResNet50Weights] = None) -> ResNet:
    # Model construction and weight loading go here
    pass

# When BC-breaking changes are unavoidable, we will provide new builder
# methods to keep things BC.
class ResNet50V2Weights(Enum):
    ImageNet1K_RefV1 = WeightEntry(...)

def resnet50_v2(weights: Optional[ResNet50V2Weights] = None) -> ResNetV2:  # Assume a new Class is needed
    pass
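
For completeness, here is one hypothetical body for the builder. The RFC deliberately leaves implementation details to each library; this sketch assumes a torchvision-style ResNet class and Bottleneck block, and fetches the weights with torch.hub.load_state_dict_from_url:

import torch

def resnet50(weights: Optional[ResNet50Weights] = None) -> ResNet:
    # Standard resnet50 configuration (Bottleneck blocks, layers [3, 4, 6, 3])
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    if weights is not None:
        # Download (and cache) the checkpoint, then load it into the model
        state_dict = torch.hub.load_state_dict_from_url(weights.value.url, progress=True)
        model.load_state_dict(state_dict)
    return model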

Example of using the API:

# Model initialization and weights loading
weights = ResNet50Weights.ImageNet1K_RefV1
model = resnet50(weights=weights).eval()

# Fetch the preprocessing transforms and apply them to the image
# (img is assumed to be an already-loaded PIL image)
preprocess = weights.transforms()
batch = preprocess(img).unsqueeze(0)

# Make predictions
prediction = model(batch).squeeze(0).softmax(0)

# Use the meta-data to map the predicted class id to a human-readable label
label = prediction.argmax().item()
class_name = weights.meta['classes'][label]
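
Note that weights.transforms and weights.meta above resolve WeightEntry fields directly on the Enum member. With a plain Enum this would require going through weights.value; below is a minimal sketch of one way to forward the lookup (an implementation detail each library is free to adapt):

class WeightsEnum(Enum):
    def __getattr__(self, name):
        # Called only when normal attribute lookup fails on the member;
        # forward the known WeightEntry fields to the underlying value.
        if name in ("url", "transforms", "meta"):
            return getattr(self.value, name)
        raise AttributeError(name)

class ResNet50Weights(WeightsEnum):
    ImageNet1K_RefV1 = WeightEntry(...)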


Special cases

This section provides some guidance on how to handle special cases that were brought up during the review of the RFC. When multiple valid options exist, the DAPI libs should choose the one that best meets their needs.

Demos

To prove that the proposed API can accommodate all domains, we implemented it for 4 real-world models. To see the demos, run the following commands from the root of the repo:

Image Classification: ResNet50

Check the model implementation and the model usage.

$ python -u image_classification.py
golden retriever 0.9381255507469177

Object Detection: FasterRCNN with ResNet50 FPN

Check the model implementation and the model usage.

$ python -u image_detection.py
Saving picture at ./output/object-detection.jpg


Text Encoding: Roberta Base

Check the model implementation and the model usage.

$ python -u text_encoding.py
125000000 torch.Size([1, 5, 768])

Text to Speech: Tacotron2 + WaveRNN

Check the model implementation and the model usage.

$ python -u text_to_speech.py 
38 torch.Size([1, 80, 112])
Saving wave at ./output/message.wav

Implementation details

Below we link directly to the actual implementations and code examples where we document everything extensively. The best way to see how the proposal works is to check the examples folder, where we focus on 3 Model Versioning scenarios that we had to address in the past:

  1. Multi-weight and Multi-version support (BC)
  2. Updated default Hyper-params (BC-breaking)
  3. Code changes that affect the model behaviour while the architecture remains the same (BC-breaking)
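
To make scenario 3 concrete, here is a hypothetical sketch (MyModel, the weights Enums, _build_mymodel and the fixed_padding flag are all invented for illustration). The behaviour-affecting fix ships behind a new builder, while the original builder keeps reproducing the old results:

def mymodel(weights: Optional[MyModelWeights] = None) -> MyModel:
    # Original builder: preserves the old behaviour so that previously
    # released weights keep reproducing their historical results.
    return _build_mymodel(weights, fixed_padding=False)

def mymodel_v2(weights: Optional[MyModelV2Weights] = None) -> MyModel:
    # New builder: same architecture, corrected behaviour, retrained weights.
    return _build_mymodel(weights, fixed_padding=True)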

Our proposal consists of the following key components:

We also offer two optional components:

Alternatives considered

Here we briefly list the alternatives that we considered, along with some of the reasons why we didn't select them. Note that in all cases, we prefer using Enums to strings; to read more on why, check this section.
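
As a quick illustration of why Enums are preferred over strings (using the resnet50 builder from the Proposal):

# A string-based API would catch typos only at runtime, inside the builder:
model = resnet50(weights="ImagNet1K_RefV1")  # note the typo; fails (at best) when executed

# The Enum-based API makes invalid values impossible to construct, lets IDEs
# autocomplete the options, and carries the WeightEntry meta-data along:
model = resnet50(weights=ResNet50Weights.ImageNet1K_RefV1)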

Single model builder and weights parameter for all code versions

class ResNet50Weights(Enum):
    V1_NoWeights = WeightEntry(...)
    V1_ImageNet1K_RefV1 = WeightEntry(...)

def resnet50(weights: ResNet50Weights = ResNet50Weights.V1_NoWeights) -> nn.Module:
    pass

Pros:

Cons:

Single model builder, two separate arguments for the version and weights

class ResNet50Weights(Enum):
    ImageNet1K_RefV1 = WeightEntry(...)

def resnet50(version: int = 1, weights: Optional[ResNet50Weights] = None) -> nn.Module:
    pass

Pros:

Cons:

Separate model builder for each code version and weights combination

def resnet50_v2_imagenet_ref1(pretrained: bool = False) -> ResNetV2:
    pass

Pros:

Cons:

Release plan

The key components of the proposal should be independently adapted and implemented by the maintainers of each domain library. The degree to which the optional components are adopted can vary and depends on the needs of each domain. Since we currently don't have an agreed way to handle interdependencies between domain libraries, any of the adopted utilities can live in multiple domain repos and can move to a common repo if one is introduced in the future.

Next steps