Closed: vfdev-5 closed this issue 3 years ago.
@datumbox This is one of the reasons why I think the `dataclass` and the `Enum` should be separated. I think the `Enum` is there to facilitate querying information across multiple weight entries, but it is better kept separate from the signature of the factory function.
Another example is speech recognition models for different languages. Typically, models for different languages are logically separated.
```python
@dataclass
class SpeechRecognitionWeight:
    lang: str        # field types assumed; the original left them untyped
    environment: str
    architecture: str

class English(Enum):
    EN_US = SpeechRecognitionWeight(lang='en_US', ...)  # remaining fields elided
    EN_GB = SpeechRecognitionWeight(lang='en_GB', ...)

    def get_best() -> SpeechRecognitionWeight:
        ...

class French(Enum):
    FR_FR = SpeechRecognitionWeight(lang='fr_FR')
    FR_CA = SpeechRecognitionWeight(lang='fr_CA')

def factory_function(weight: Optional[SpeechRecognitionWeight] = None):
    ...

en_model = factory_function(English.EN_US.value)
fr_model = factory_function(French.FR_FR.value)
```
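To illustrate the querying benefit mentioned above, here is a minimal runnable sketch. The field types, the `word_error_rate` metric, and the `get_best` heuristic are my assumptions for illustration, not part of the original proposal:

```python
from dataclasses import dataclass
from enum import Enum

@dataclass(frozen=True)
class SpeechRecognitionWeight:
    lang: str
    word_error_rate: float  # hypothetical quality metric used for ranking

class English(Enum):
    EN_US = SpeechRecognitionWeight(lang="en_US", word_error_rate=5.2)
    EN_GB = SpeechRecognitionWeight(lang="en_GB", word_error_rate=6.1)

    @classmethod
    def get_best(cls) -> SpeechRecognitionWeight:
        # Iterate over all entries of the Enum and pick the best metric.
        return min((e.value for e in cls), key=lambda w: w.word_error_rate)

best = English.get_best()
print(best.lang)  # -> en_US
```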
@mthrok seems good.
This is a sizeable change I need to make all over the RFC and the code, so @vfdev-5 and @parmeet could you please let me know if you prefer to leave the implementation as is or adopt Moto's proposal?
EDIT: I modified the title of this issue to avoid using the word "Domains" which we commonly use for the domain libraries.
I just realised that separating the dataclass from the Enum, as described by @mthrok above, removes a fundamental property of this API which is the 1-1 link between a specific model builder method and its weights class.
Currently it's impossible to do the following:
```python
# Pass incorrect model weights to the wrong builder
model = resnet50(weights=ResNet101Weights.ImageNet1K_RefV1)
```

Because a) typing will complain and b) the model builder will throw a runtime exception.
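A minimal sketch of that runtime guard, with the weight data and model loading replaced by placeholder strings:

```python
from enum import Enum

class ResNet50Weights(Enum):
    ImageNet1K_RefV1 = "resnet50-imagenet1k-v1"

class ResNet101Weights(Enum):
    ImageNet1K_RefV1 = "resnet101-imagenet1k-v1"

def resnet50(weights=None):
    # The builder knows its own weights Enum, so it can reject foreign ones.
    if weights is not None and not isinstance(weights, ResNet50Weights):
        raise ValueError(f"Unexpected weights type: {type(weights).__name__}")
    return f"resnet50 loaded with {weights.value if weights else 'random init'}"

resnet50(ResNet50Weights.ImageNet1K_RefV1)    # fine
try:
    resnet50(ResNet101Weights.ImageNet1K_RefV1)  # wrong builder
except ValueError as e:
    print(e)  # -> Unexpected weights type: ResNet101Weights
```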
After adopting the proposal things will look like:
```python
@dataclass
class WeightEntry:
    url: str
    transforms: Callable
    meta: Dict[str, Any]
    latest: bool

    def state_dict(self, progress: bool) -> Dict[str, Any]:
        pass

class ResNet50Weights(Enum):
    ImageNet1K_RefV1 = WeightEntry(
        url="https://path/to/weights.pth",
        transforms=partial(ImageNetPreprocessing, width=224),
        meta={"num_classes": 1000, "Acc@1": 76.130, "classes": [...]},
        latest=True,
    )
    CIFAR100_RefV1 = WeightEntry(
        # Weights data go here
    )

    @classmethod
    def get_latest(cls) -> List:
        pass

def resnet50(weights: Optional[WeightEntry] = None) -> ResNet:
    pass
```
Unfortunately, the builder can no longer tell that the wrong weights were loaded, the registration mechanism can't automatically register the Enum to the public API, and mechanisms that rely on the builder knowing its weights type (such as getting the latest weights on registration) are no longer possible. This doesn't invalidate the entire proposal, but it does remove a significant part of the proposed API.
There is a middle-ground solution to the above which allows us to separate the WeightEntry from the Enum, but it doesn't fix the use case that @vfdev-5 brought up. Here is the variation:
```python
class Weight(Enum):
    # Optional base class that contains weight-specific methods.
    # It also serves as a way to identify a Weight Enum, which is useful
    # for the registration API. It's an internal detail and can be skipped
    # if the DAPI lib doesn't need the above.
    @classmethod
    def check_type(cls, obj: Any) -> None:
        pass

    @classmethod
    def get_latest(cls) -> List:
        pass

@dataclass
class WeightEntry:
    url: str
    transforms: Callable
    meta: Dict[str, Any]
    latest: bool

    def state_dict(self, progress: bool) -> Dict[str, Any]:
        pass

class ResNet50Weights(Weight):
    ImageNet1K_RefV1 = WeightEntry(
        url="https://path/to/weights.pth",
        transforms=partial(ImageNetPreprocessing, width=224),
        meta={"num_classes": 1000, "Acc@1": 76.130, "classes": [...]},
        latest=True,
    )
    CIFAR100_RefV1 = WeightEntry(
        # Weights data go here
    )

def resnet50(weights: Optional[ResNet50Weights] = None) -> ResNet:
    ResNet50Weights.check_type(weights)
    pass
```
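For completeness, here is a runnable sketch of that middle ground. The fields are simplified to `url` and `latest`, and the method bodies are my assumptions about what `check_type` and `get_latest` would do:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, List

class Weight(Enum):
    # Member-less base Enum: safe to subclass, and an isinstance() check
    # against it identifies any weights Enum for the registration API.
    @classmethod
    def check_type(cls, obj: Any) -> None:
        if obj is not None and not isinstance(obj, cls):
            raise ValueError(f"Expected {cls.__name__}, got {type(obj).__name__}")

    @classmethod
    def get_latest(cls) -> List["Weight"]:
        return [e for e in cls if e.value.latest]

@dataclass(frozen=True)
class WeightEntry:
    url: str
    latest: bool

class ResNet50Weights(Weight):
    ImageNet1K_RefV1 = WeightEntry(url="https://path/to/weights.pth", latest=True)
    CIFAR100_RefV1 = WeightEntry(url="https://path/to/cifar.pth", latest=False)

ResNet50Weights.check_type(ResNet50Weights.ImageNet1K_RefV1)   # passes
print([e.name for e in ResNet50Weights.get_latest()])  # -> ['ImageNet1K_RefV1']
```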
Now we maintain the 1-1 mapping but separate the Enum from the WeightEntry. The Weight base class exists to hold inherited methods and to enable the registration mechanism to identify weights, but it can be considered optional. Unfortunately, this proposal still doesn't cover Victor's case.
I don't have an ultra-strong opinion on keeping the 1-1 mapping, but I would like to hear more arguments against it. Moreover, I think we should show more benefits of option 1 if it is to be adopted. For option 2, things are much simpler and the benefits are mainly in code readability/aesthetics, so I'm happy to make the change if you all agree. I would love to hear your thoughts.
Note that multi-language support (or any other attribute) is possible both in the current API and in option 2. You can keep all weights of all languages in a single Enum:
```python
class ModelXYZWeights(Enum):
    EN_US = ...
    EN_GB = ...
    FR_FR = ...
    FR_CA = ...

def modelxyz(weights: Optional[ModelXYZWeights] = None) -> ModelXYZ:
    pass
```
Or you can have a base class (so that the model builder keeps the 1-1 mapping) and then subclass it as many times as you want. Language-specific info goes directly in the Enum entry, so you have full freedom to configure things as you want:
```python
class ModelXYZWeights(Enum):
    pass

class EnglishModelXYZWeights(ModelXYZWeights):
    EN_US = ...
    EN_GB = ...

class FrenchModelXYZWeights(ModelXYZWeights):
    FR_FR = ...
    FR_CA = ...

def modelxyz(weights: Optional[ModelXYZWeights] = None) -> ModelXYZ:
    pass
```
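This pattern works because Python only forbids subclassing an Enum that already has members; an empty base is subclassable, and every member of a subclass still passes an `isinstance` check against the base. A runnable sketch with placeholder string payloads:

```python
from enum import Enum

class ModelXYZWeights(Enum):
    # No members here, so Python allows subclassing this Enum.
    pass

class EnglishModelXYZWeights(ModelXYZWeights):
    EN_US = "xyz-en_US"
    EN_GB = "xyz-en_GB"

class FrenchModelXYZWeights(ModelXYZWeights):
    FR_FR = "xyz-fr_FR"
    FR_CA = "xyz-fr_CA"

def modelxyz(weights=None):
    # A single isinstance check against the base covers every language.
    if weights is not None and not isinstance(weights, ModelXYZWeights):
        raise ValueError("Not a ModelXYZ weights entry")
    return weights.value if weights else "random init"

print(modelxyz(FrenchModelXYZWeights.FR_CA))  # -> xyz-fr_CA
```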
Combining the points from @parmeet in #9: having multiple factory functions for different architectures is cumbersome, especially with pretrained weights. I have started to think it might be better to have a single function for pretrained models across the different architectures, and leave the existing factory functions for instantiating untrained models with the ability to customize some parameters that are not architecture-related (like a dropout param). I can write some code snippets after getting off the train.
@mthrok I know that things can be interconnected but to the degree possible, let's try not to split the same discussion across issues. I haven't managed to reply to the other issue but for Vision separate builders is absolutely crucial for BC. I'll reply to that in a bit but please let me know your thoughts on the above points.
Edit: I have now caught up and responded at https://github.com/datumbox/dapi-model-versioning/issues/9#issuecomment-932280040. Let's keep the discussion around multiple factory methods there to avoid missing information.
I think separating the Enum would allow expanding the scope of the pre-trained model APIs in some interesting ways, for example for use with local checkpoints. It would also allow having default factory functions in WeightEntry, for instance for transforms. This is a common case for text: different variants of XLMR use the same SPM tokenizer, so it would be nice to be able to specify a default value that is used across all the variants.
Assuming I have a default transform factory function in WeightEntry, I could write something like this to load a local checkpoint:
```python
my_local_weights = WeightEntry(url="path/to/local/model_weights.pt")
model = resnet50(weights=my_local_weights)
# continue training using model or use for evaluation
```
With the current form it would look like:
```python
my_local_weights = ResNet50Weights.ImageNet1K_RefV1
my_local_weights.url = "path/to/local/model_weights.pt"
model = resnet50(weights=my_local_weights.weights)
# continue training using model or use for evaluation
```
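One way to get the default-transforms behaviour described above is a dataclass field with a `default_factory`. A sketch, where `default_tokenizer` is a hypothetical stand-in for the shared SPM tokenizer, not a real API:

```python
from dataclasses import dataclass, field
from typing import Callable

def default_tokenizer() -> Callable[[str], list]:
    # Hypothetical stand-in for the tokenizer shared by all model variants.
    return lambda text: text.split()

@dataclass
class WeightEntry:
    url: str
    # Every entry, including ad-hoc local ones, gets the shared transform
    # unless the caller overrides it.
    transforms: Callable = field(default_factory=default_tokenizer)

my_local_weights = WeightEntry(url="path/to/local/model_weights.pt")
print(my_local_weights.transforms("hello world"))  # -> ['hello', 'world']
```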
Sounds good. I'll separate the Enum from the data class. At this point I'm going to go for a modified version of option 2, since it maintains the link between the model builder and the Enum (which is crucial for some of the features we want to support at Vision) but also allows passing custom WeightEntries directly. If such a link is not necessary in other libraries, you can declare your signatures as follows, which allows you to receive custom WeightEntry records:
```python
def resnet50(weights: Optional[Union[ResNet50Weights, WeightEntry]] = None) -> ResNet:
    ResNet50Weights.check_type(weights)
    pass

# Now you can actually do:
my_local_weights = WeightEntry(url="path/to/local/model_weights.pt")
model = resnet50(weights=my_local_weights)
```
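A runnable sketch of that relaxed signature, with model loading replaced by returning the resolved URL and a simplified `WeightEntry`:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

@dataclass(frozen=True)
class WeightEntry:
    url: str

class ResNet50Weights(Enum):
    ImageNet1K_RefV1 = WeightEntry(url="https://path/to/weights.pth")

def resnet50(weights: Optional[Union[ResNet50Weights, WeightEntry]] = None) -> str:
    # Accept either a registered Enum member or an ad-hoc WeightEntry.
    if isinstance(weights, ResNet50Weights):
        entry = weights.value
    elif isinstance(weights, WeightEntry) or weights is None:
        entry = weights
    else:
        raise ValueError(f"Unexpected weights: {weights!r}")
    return entry.url if entry else "random init"

my_local_weights = WeightEntry(url="path/to/local/model_weights.pt")
print(resnet50(weights=my_local_weights))  # -> path/to/local/model_weights.pt
```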
A question on this weights API and its extension to other domains. For example, in medical AI people can train resnet50 on their own medical datasets and would probably like to load those weights without any ImageNet or CIFAR-10 weights. A rather simple suggestion to such users could be to create something new, without any relationship to the implemented `ResNet50Weights` that reference ImageNet/CIFAR-10. @datumbox, what are your thoughts on that?