datumbox / dapi-model-versioning

RFC for Model Versioning across all PyTorch Domain libraries

Model Architecture Configuration specification #9

Open parmeet opened 2 years ago

parmeet commented 2 years ago

One of the common cases in text is to define a base model architecture and create bigger versions just by increasing the number of parameters, e.g. the number of layers, hidden dimensions, etc. Take the XLMR model, for instance. There are four variants of the model, dubbed "xlmr.base", "xlmr.large", "xlmr.xl", and "xlmr.xxl".

One way to provide these models to users is to have 4 different factory functions, one for each of them. But the code is highly redundant, since the only difference here is the input configuration. A better way might be to encode this information directly inside the Weights Enum, so that the user-facing function only needs to specify which weights to use, and internally the model factory function creates the corresponding architecture for the user.
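A minimal sketch of this single-factory idea, assuming a hypothetical WeightsEntry data class with url and meta fields and placeholder example.com URLs. XLMRModel is a stand-in for the real nn.Module that only records its hyper-parameters, and the base/large sizes follow the usual base/large Transformer encoder settings but should be treated as illustrative (xl and xxl would be added the same way). The architecture config is deliberately stored in meta, which is exactly what the question below is about.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict


@dataclass(frozen=True)
class WeightsEntry:
    # Hypothetical weights record: checkpoint location plus a free-form meta dict.
    url: str
    meta: Dict[str, Any] = field(default_factory=dict)


class XLMRModel:
    """Stand-in for the real nn.Module; only stores its hyper-parameters."""

    def __init__(self, num_layers: int, hidden_dim: int, num_heads: int):
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads


class XLMRWeights(Enum):
    # Architecture config encoded inside the weights entry (placeholder URLs).
    BASE = WeightsEntry(
        url="https://example.com/xlmr.base.pt",
        meta={"num_layers": 12, "hidden_dim": 768, "num_heads": 12},
    )
    LARGE = WeightsEntry(
        url="https://example.com/xlmr.large.pt",
        meta={"num_layers": 24, "hidden_dim": 1024, "num_heads": 16},
    )


def xlmr(weights: XLMRWeights) -> XLMRModel:
    # One user-facing factory: the architecture is derived from the chosen weights.
    model = XLMRModel(**weights.value.meta)
    # Downloading and loading the checkpoint from weights.value.url is omitted here.
    return model
```

For example, xlmr(XLMRWeights.LARGE) builds the 24-layer variant without a dedicated xlmr_large() function.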

I wonder if, conceptually, the meta argument is the right place to specify model configuration, or whether it is reserved only for informative attributes?

mthrok commented 2 years ago

I feel the same dilemma.

> have 4 different factory functions, one for each of them

I think this can be useful if users want to build an untrained model.

> I wonder if, conceptually, the meta argument is the right place to specify model configuration

I think you can add a custom attribute like arch or params.
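A minimal sketch of what dedicated attributes could look like, assuming a hypothetical WeightsEntry with separate arch and params fields so that meta stays purely informative. All names, values, and URLs here are illustrative placeholders, not an existing API.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict


@dataclass(frozen=True)
class WeightsEntry:
    url: str
    arch: str                  # hypothetical: which architecture/builder this entry targets
    params: Dict[str, Any]     # hypothetical: kwargs passed to the model constructor
    meta: Dict[str, Any] = field(default_factory=dict)  # informative attributes only


class XLMRWeights(Enum):
    BASE = WeightsEntry(
        url="https://example.com/xlmr.base.pt",
        arch="xlmr",
        params={"num_layers": 12, "hidden_dim": 768},
        meta={"description": "illustrative entry"},
    )
    LARGE = WeightsEntry(
        url="https://example.com/xlmr.large.pt",
        arch="xlmr",
        params={"num_layers": 24, "hidden_dim": 1024},
    )


def build_kwargs(weights: XLMRWeights) -> Dict[str, Any]:
    # Only `params` feeds the constructor; `meta` is never used for building.
    return dict(weights.value.params)
```

The design choice is simply a separation of concerns inside the entry: building information and descriptive metadata live in different fields.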

parmeet commented 2 years ago

> I feel the same dilemma.

> > have 4 different factory functions, one for each of them

> I think this can be useful if users want to build an untrained model.

That's right. An alternative would be to somehow tell the function not to load the weights (in which case the architecture config would still be able to construct the right model, but with randomly initialized weights). Though this might create confusion in users' minds, because the idea with Weights is really to serve multiple pre-trained checkpoints. One alternative would be to call it ModelConstructs or something more generic instead of Weights. Then it might be safe to provide additional keyword arguments in the factory function, like load_weights: bool. Though I really feel I am complicating things here; others may be able to propose a better way of handling this issue :)
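A minimal sketch of the ModelConstructs plus load_weights idea, assuming hypothetical names (ModelConstruct, XLMRConstructs, XLMRModel) and placeholder URLs. The "loading" step is only simulated by recording the URL, so the control flow can be seen without any download.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class ModelConstruct:
    url: str                 # where the pre-trained checkpoint would live
    params: Dict[str, Any]   # architecture config for the constructor


class XLMRConstructs(Enum):
    BASE = ModelConstruct("https://example.com/xlmr.base.pt", {"num_layers": 12, "hidden_dim": 768})
    LARGE = ModelConstruct("https://example.com/xlmr.large.pt", {"num_layers": 24, "hidden_dim": 1024})


class XLMRModel:
    """Stand-in for the real nn.Module."""

    def __init__(self, num_layers: int, hidden_dim: int):
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.loaded_from: Optional[str] = None  # set only when a checkpoint is "loaded"


def xlmr(construct: XLMRConstructs, load_weights: bool = True) -> XLMRModel:
    # The architecture always comes from the construct entry...
    model = XLMRModel(**construct.value.params)
    # ...while the extra flag decides whether the checkpoint is applied.
    if load_weights:
        model.loaded_from = construct.value.url  # placeholder for loading a state dict
    return model
```

Note that the Enum now describes architectures as much as checkpoints, which is the concern raised in the reply below.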

> > I wonder if, conceptually, the meta argument is the right place to specify model configuration

> I think you can add a custom attribute like arch or params.

datumbox commented 2 years ago

In Vision we have no choice but to have a builder per model variant, because we need to maintain backward compatibility (BC); that's not something we can change. In your case, you might be able to keep everything in a single builder method, but that's an implementation detail that is in your control and beyond the scope of this proposal. Note that having one builder has pros and cons.

> But the code is highly redundant, since the only difference here is the input configuration.

We don't actually have much redundant code. We typically end up calling a single private builder method; we just have a separate public interface for each supported model variant. See here.
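A simplified sketch of the private-builder pattern described above. ResNet is a stand-in class, blocks are represented as strings, and the weights argument is left out to keep the focus on structure; the block types and layer counts are the standard ResNet-18/34/50 configurations.

```python
from typing import List


class ResNet:
    """Stand-in for the real model class; only stores its building config."""

    def __init__(self, block: str, layers: List[int]):
        self.block = block
        self.layers = layers


def _resnet(block: str, layers: List[int]) -> ResNet:
    # Single private builder: all shared construction logic lives here.
    return ResNet(block, layers)


# Thin public interfaces, one per supported variant; each only fixes the config.
def resnet18() -> ResNet:
    return _resnet("BasicBlock", [2, 2, 2, 2])


def resnet34() -> ResNet:
    return _resnet("BasicBlock", [3, 4, 6, 3])


def resnet50() -> ResNet:
    return _resnet("Bottleneck", [3, 4, 6, 3])
```

Each public function is a one-liner, so the per-variant cost is mostly the public name, which is what keeps BC intact.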

> I think this can be useful if users want to build an untrained model.

That is correct. If you have arch as a property of the weights entry, then you need to introduce NoWeights enums to build untrained models. This is similar to Alternative 1 of the RFC.
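A minimal sketch of why a NoWeights-style entry becomes necessary once the architecture lives in the weights entry. The names (WeightsEntry, BASE_NO_WEIGHTS, XLMRModel) and the URL are hypothetical; a url of None is used here to mean "no checkpoint".

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class WeightsEntry:
    url: Optional[str]       # None means "no checkpoint available"
    params: Dict[str, Any]   # architecture config carried by the entry


class XLMRWeights(Enum):
    BASE = WeightsEntry("https://example.com/xlmr.base.pt", {"num_layers": 12})
    # Extra "fake" entry that exists only so an untrained base model can be built.
    BASE_NO_WEIGHTS = WeightsEntry(None, {"num_layers": 12})


class XLMRModel:
    """Stand-in for the real nn.Module."""

    def __init__(self, num_layers: int, pretrained: bool):
        self.num_layers = num_layers
        self.pretrained = pretrained


def xlmr(weights: XLMRWeights) -> XLMRModel:
    entry = weights.value
    # Without an entry, there is no way to know which architecture to build.
    return XLMRModel(**entry.params, pretrained=entry.url is not None)
```

Every architecture ends up needing such a duplicate entry, which is the cost being pointed out.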

> I think you can add a custom attribute like arch or params.

On the other hand, introducing a separate arch that is unrelated to the weights/params poses the same issues as Alternative 2 of the RFC.
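A minimal sketch of the separate-arch approach, assuming hypothetical XLMRArch and XLMRWeights enums with placeholder URLs. Because the two are unrelated, nothing prevents a caller from pairing incompatible values, so the compatibility table has to be maintained by hand and checked at runtime.

```python
from enum import Enum
from typing import Any, Dict, Optional


class XLMRArch(Enum):
    BASE = {"num_layers": 12, "hidden_dim": 768}
    LARGE = {"num_layers": 24, "hidden_dim": 1024}


class XLMRWeights(Enum):
    BASE = "https://example.com/xlmr.base.pt"
    LARGE = "https://example.com/xlmr.large.pt"


# Which checkpoint fits which architecture must be tracked manually.
_COMPATIBLE = {XLMRWeights.BASE: XLMRArch.BASE, XLMRWeights.LARGE: XLMRArch.LARGE}


def xlmr(arch: XLMRArch, weights: Optional[XLMRWeights] = None) -> Dict[str, Any]:
    # Runtime validation that a weights-linked design would not need.
    if weights is not None and _COMPATIBLE[weights] is not arch:
        raise ValueError(f"{weights.name} weights do not match the {arch.name} architecture")
    return dict(arch.value)  # stand-in for constructing the model
```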

> Then it might be safe to provide additional keyword arguments in the factory function, like load_weights: bool. Though I really feel I am complicating things here

Let's take a step back. The introduction of Enum weights aims to link a specific pre-trained model to its attributes (weights, transforms, recipe, etc.). Introducing both a weights argument and a load_weights bool is incompatible with the premise of this proposal and overcomplicates things, because you start using the Enum not to describe the weights but to describe the configuration of the architecture. Historically this has been the responsibility of the model builder in Vision, and it has served us very well. Though you have the option to adjust this on your end, I would advise against it due to the issues of Alternatives 1 and 2 linked above.

datumbox commented 2 years ago

Since the topic of attaching the model building config inside the weights keeps coming up in our discussions, I think it's worth writing in detail why I think this is not a good idea. By model building config, I refer to all the params passed to the constructor of the Model class to build it, including things like model hyper-parameters, layer configuration, etc.: https://github.com/datumbox/dapi-model-versioning/blob/c7f9302103267e968f86b8d7fe21931bdd820e3c/dapi_lib/models/resnet.py#L21-L28

As you recall, each model builder method expects the Weights Enum specific to its model, and that's how the two are associated. So instead of the weights data class storing the building config, we "store" it in the model builder method and just link to it:

https://github.com/datumbox/dapi-model-versioning/blob/c7f9302103267e968f86b8d7fe21931bdd820e3c/dapi_lib/models/resnet.py#L46-L47

https://github.com/datumbox/dapi-model-versioning/blob/c7f9302103267e968f86b8d7fe21931bdd820e3c/dapi_lib/models/resnet.py#L63
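A minimal sketch of this association, not the code behind the links above: the builder accepts only members of its own (hypothetical) ResNet50Weights Enum, while the building config stays inside the builder. The URL is a placeholder, and ResNet is a stand-in class.

```python
from enum import Enum
from typing import List, Optional


class ResNet50Weights(Enum):
    # Hypothetical entry; a real value would also carry transforms, meta, etc.
    ImageNet1K_V1 = "https://example.com/resnet50_imagenet1k_v1.pth"


class ResNet:
    """Stand-in for the real model class; only records its building config."""

    def __init__(self, block: str, layers: List[int]):
        self.block = block
        self.layers = layers


def resnet50(weights: Optional[ResNet50Weights] = None) -> ResNet:
    # The builder only accepts weights from its own Enum: that is the association.
    if weights is not None and not isinstance(weights, ResNet50Weights):
        raise TypeError("resnet50 expects a ResNet50Weights member or None")
    # The building config lives here, in the builder, not in the weights entry.
    model = ResNet("Bottleneck", [3, 4, 6, 3])
    # Loading the checkpoint referenced by weights.value would happen here (omitted).
    return model
```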

This is useful because:

  1. There is a many-to-one relationship between the two: a building config can be associated with multiple Weights, but each weights entry is associated with a single building config.
  2. It allows us to provide building configurations even when we don't offer fully trained weights, which can be useful when a model is very expensive to train. This is done by simply passing None as the weights argument of the builder. Since the information on how to build the model lives in the builder rather than in the weights, we can do this without introducing fake Enum values, extra booleans, etc.
  3. It builds upon existing practices and is backward compatible (BC).
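A minimal sketch of points 1 and 2, using a hypothetical ResNet50Weights Enum with two placeholder checkpoints (for example, two training recipes for the same architecture). Both entries map to the single config held by the builder, and passing None yields an untrained model without any extra Enum value or flag.

```python
from enum import Enum
from typing import List, Optional


class ResNet50Weights(Enum):
    # Two hypothetical checkpoints for the same architecture: many weights -> one config.
    ImageNet1K_V1 = "https://example.com/resnet50_v1.pth"
    ImageNet1K_V2 = "https://example.com/resnet50_v2.pth"


class ResNet:
    """Stand-in for the real model class."""

    def __init__(self, block: str, layers: List[int]):
        self.block = block
        self.layers = layers
        self.checkpoint: Optional[str] = None


def resnet50(weights: Optional[ResNet50Weights] = None) -> ResNet:
    model = ResNet("Bottleneck", [3, 4, 6, 3])  # the one building config
    if weights is not None:
        model.checkpoint = weights.value  # placeholder for loading the state dict
    return model
```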

If you want to remove the model building config from the model builder method, one option is to store this information on its own, in a separate Enum, data class, or dictionary. Here is what audio is currently doing: https://github.com/datumbox/dapi-model-versioning/blob/c7f9302103267e968f86b8d7fe21931bdd820e3c/dapi_lib/models/tacotron2.py#L17-L39
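A minimal sketch of the general idea of a standalone config registry, not a reproduction of audio's code at the link above. ModelConfig, MODEL_CONFIGS, Model, and build_model are hypothetical names, and the sizes are illustrative. The key property is that the registry is independent of any weights Enum, so no building info is dumped into the weights.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class ModelConfig:
    num_layers: int
    hidden_dim: int


# Building configs live in their own registry, independent of any weights.
MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "base": ModelConfig(num_layers=12, hidden_dim=768),
    "large": ModelConfig(num_layers=24, hidden_dim=1024),
}


class Model:
    """Stand-in for the real model class."""

    def __init__(self, config: ModelConfig):
        self.config = config


def build_model(config_name: str) -> Model:
    # Builds an untrained model from the registry; loading weights is a separate step.
    return Model(MODEL_CONFIGS[config_name])
```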

Though this can be an interesting idea, I think it is beyond the scope of this RFC and should be handled separately. IMO this is under the control of the Domain libraries, and as long as this info is not dumped into the weights (for the reasons I explained above), they should be able to use a solution that meets their needs.