datumbox / dapi-model-versioning

RFC for Model Versioning across all PyTorch Domain libraries

API extension to other areas of research or datasets #7

Closed: vfdev-5 closed this issue 3 years ago

vfdev-5 commented 3 years ago

A question on this weights API and its extension to other domains. For example, in medical AI, people can train a resnet50 on their medical datasets and would probably like to have something like:

@dataclass
class ResNet50Weights(Enum):
    MedNist_RefV1 = (
        url="...",  # Weights URL/path
        # Other customizable fields go here
    )

without any ImageNet or CIFAR-10 weights.

A rather simple suggestion for such users could be to create something new, like:

@dataclass
class MedResNet50Weights(Enum):
    MedNist_RefV1 = (
        url="...",  # Weights URL/path
        # Other customizable fields go here
    )

without any relationship to the implemented ResNet50Weights, which references ImageNet/CIFAR-10.

@datumbox what are your thoughts on that?

mthrok commented 3 years ago

@datumbox This is one of the reasons why I think the dataclass and the Enum should be separated. I think the Enum is used to facilitate querying information across multiple weight entries, but it is better kept separate from the signature of the factory function.

Another example is speech recognition models for different languages. Typically, models for different languages are logically separated.

@dataclass
class SpeechRecognitionWeight:
    lang: str
    environment: str
    architecture: str

class English(Enum):
    EN_US = SpeechRecognitionWeight(lang='en_US', ...)
    EN_GB = SpeechRecognitionWeight(lang='en_GB', ...)

    @classmethod
    def get_best(cls) -> SpeechRecognitionWeight:
        ...

class French(Enum):
    FR_FR = SpeechRecognitionWeight(lang='fr_FR')
    FR_CA = SpeechRecognitionWeight(lang='fr_CA')

def factory_function(weight: Optional[SpeechRecognitionWeight] = None):
    ...

en_model = factory_function(English.EN_US.value)
fr_model = factory_function(French.FR_FR.value)
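
For illustration, the factory function could do something along these lines (a rough sketch; it assumes the dataclass also carries a url field, which the snippet above omits):

def factory_function(weight: Optional[SpeechRecognitionWeight] = None) -> torch.nn.Module:
    model = _build_untrained_model()  # hypothetical helper that constructs the architecture
    if weight is not None:
        # fetch and cache the checkpoint behind the (assumed) url field
        state_dict = torch.hub.load_state_dict_from_url(weight.url, progress=True)
        model.load_state_dict(state_dict)
    return model
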
datumbox commented 3 years ago

@mthrok seems good.

This is a sizeable change I need to make all over the RFC and the code, so @vfdev-5 and @parmeet could you please let me know if you prefer to leave the implementation as is or adopt Moto's proposal?

EDIT: I modified the title of this issue to avoid using the word "Domains" which we commonly use for the domain libraries.

datumbox commented 3 years ago

I just realised that separating the dataclass from the Enum, as described by @mthrok above, removes a fundamental property of this API, which is the 1-1 link between a specific model builder method and its weights class.

Current Approach

Currently it's impossible to do the following:

# Pass incorrect model weights to the wrong builder
model = resnet50(weights=ResNet101Weights.ImageNet1K_RefV1)

Because a) typing will complain and b) the model builder will throw a runtime exception.
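
For reference, the runtime check inside the current builder could look roughly like this (a sketch of the idea, not the exact RFC code):

def resnet50(weights: Optional[ResNet50Weights] = None) -> ResNet:
    # the builder knows its own weights Enum, so it can reject anything else at runtime
    if weights is not None and not isinstance(weights, ResNet50Weights):
        raise ValueError(f"Expected ResNet50Weights but got {type(weights).__name__}")
    ...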

Option 1

After adopting the proposal things will look like:

@dataclass
class WeightEntry:
    url: str
    transforms: Callable
    meta: Dict[str, Any]
    latest: bool

    def state_dict(self, progress: bool) -> Dict[str, Any]:
        pass

class ResNet50Weights(Enum):
    ImageNet1K_RefV1 = WeightEntry(
        url="https://path/to/weights.pth",
        transforms=partial(ImageNetPreprocessing, width=224),
        meta={"num_classes": 1000, "Acc@1": 76.130, "classes": [...]},
        latest=True
    )
    CIFAR100_RefV1 = WeightEntry(
        # Weights data go here
    )

    @classmethod
    def get_latest(cls) -> List:
         pass

def resnet50(weights: Optional[WeightEntry] = None) -> ResNet:
    pass

Unfortunately, the builder can now no longer tell that the wrong weights were passed, the registration mechanism can't automatically register the Enum to the public API, and mechanisms that rely on the builder knowing its weights type (such as getting the latest weights on registration) will no longer be possible. This doesn't invalidate the entire proposal, but it does remove a significant part of the proposed API.
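
For example, under the signature above this hypothetical mix-up passes both typing and any isinstance check, because the builder only sees a generic WeightEntry:

# ResNet-101 weights passed to the ResNet-50 builder; nothing complains until
# load_state_dict fails at runtime (if it fails at all)
model = resnet50(weights=ResNet101Weights.ImageNet1K_RefV1.value)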

Option 2

There is a middle-ground solution to the above which allows us to separate the WeightEntry from the Enum, but it doesn't fix the use case that @vfdev-5 brought up. Here is the variation:

class Weight(Enum):
    # Optional base class that contains weight specific methods.
    # It also serves as a way to identify a Weight Enum, which is useful for the registration API.
    # It's an internal detail and can be skipped if the DAPI lib doesn't need the above.
    @classmethod
    def check_type(cls, obj: Any) -> None:
         pass

    @classmethod
    def get_latest(cls) -> List:
         pass

@dataclass
class WeightEntry:
    url: str
    transforms: Callable
    meta: Dict[str, Any]
    latest: bool

    def state_dict(self, progress: bool) -> Dict[str, Any]:
        pass

class ResNet50Weights(Weight):
    ImageNet1K_RefV1 = WeightEntry(
        url="https://path/to/weights.pth",
        transforms=partial(ImageNetPreprocessing, width=224),
        meta={"num_classes": 1000, "Acc@1": 76.130, "classes": [...]},
        latest=True
    )
    CIFAR100_RefV1 = WeightEntry(
        # Weights data go here
    )

def resnet50(weights: Optional[ResNet50Weights] = None) -> ResNet:
    ResNet50Weights.check_type(weights)
    pass

Now we maintain the 1-1 mapping but separate the Enum from the WeightEntry. The Weight base class is added to hold inherited methods and to enable the registration mechanism to identify weights, but it can be considered optional. Unfortunately, the problem with this proposal is that it still doesn't cover Victor's case.
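
For illustration, check_type in the sketch above could be implemented roughly as follows (an assumption, not final code):

class Weight(Enum):
    @classmethod
    def check_type(cls, obj: Any) -> None:
        # None means "no pretrained weights"; anything else must be a member of this Enum
        if obj is not None and not isinstance(obj, cls):
            raise TypeError(f"Expected {cls.__name__} but got {type(obj).__name__}")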

I don't have an ultra-strong opinion on keeping the 1-1 mapping, but I would like to hear more arguments against it. Moreover, I think we should somehow show more benefits of Option 1 if it is to be adopted. For Option 2, things are much simpler and the benefits are mainly code readability/aesthetics, so I'm happy to make the change if you all agree. I would love to hear your thoughts.

datumbox commented 3 years ago

Note that multi-language support (or any other attribute) is possible both in the current API and in Option 2.

You can keep all weights of all languages in a single Enum:

class ModelXYZWeights(Enum):
    EN_US = ...
    EN_GB = ...
    FR_FR = ...
    FR_CA = ...

def modelxyz(weights: Optional[ModelXYZWeights] = None) -> ModelXYZ:
     pass

Or you can have a base class (so that the model builder keeps its 1-1 mapping) and then subclass it as many times as you want. Language-specific info goes directly in the Enum entry, so you have full freedom to configure things as you want:

class ModelXYZWeights(Enum):
    pass

class EnglishModelXYZWeights(ModelXYZWeights):
    EN_US = ...
    EN_GB = ...

class FrenchModelXYZWeights(ModelXYZWeights):
    FR_FR = ...
    FR_CA = ...

def modelxyz(weights: Optional[ModelXYZWeights] = None) -> ModelXYZ:
     pass
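
Hypothetical usage, showing that both language-specific subclasses satisfy the base-class annotation of the builder (subclassing works here because the base ModelXYZWeights Enum defines no members of its own):

en_model = modelxyz(weights=EnglishModelXYZWeights.EN_US)
fr_model = modelxyz(weights=FrenchModelXYZWeights.FR_FR)
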
mthrok commented 3 years ago

Combining the points from @parmeet in #9: having multiple factory functions for different architectures is cumbersome, especially with pretrained weights. I started to think it might be better to have one function for pretrained models across the different architectures, and to leave the existing factory functions for instantiating untrained models, with the ability to customize some parameters that are not architecture-related (like a dropout param). I can write some code snippets after getting off the train.

datumbox commented 3 years ago

@mthrok I know that things can be interconnected, but to the degree possible let's try not to split the same discussion across issues. I haven't managed to reply to the other issue yet, but for Vision, separate builders are absolutely crucial for BC. I'll reply to that in a bit, but please let me know your thoughts on the above points.

Edit: I have now caught up and responded at https://github.com/datumbox/dapi-model-versioning/issues/9#issuecomment-932280040. Let's keep the discussion around multiple factory methods there to avoid missing information.

parmeet commented 3 years ago

> @mthrok seems good.
>
> This is a sizeable change I need to make all over the RFC and the code, so @vfdev-5 and @parmeet could you please let me know if you prefer to leave the implementation as is or adopt Moto's proposal?
>
> EDIT: I modified the title of this issue to avoid using the word "Domains" which we commonly use for the domain libraries.

I think separating the Enum would allow us to expand the scope of the pre-trained model APIs in some interesting ways, for example for use with local checkpoints. It would also allow having default factory functions in WeightEntry, for instance for transforms. This is a common case for text: different variants of XLM-R use the same SPM tokenizer, so it would be nice to be able to specify a default value that is used across all the variants.

Assuming I have a default transform factory function in WeightEntry, I could write something like this to load a local checkpoint:

my_local_weights = WeightEntry(url="path/to/local/model_weights.pt")
model = resnet50(weights=my_local_weights)
# continue training using model or use for evaluation

With the current form, it would look like:

my_local_weights = ResNet50Weights.ImageNet1K_RefV1
my_local_weights.url = "path/to/local/model_weights.pt"
model = resnet50(weights=my_local_weights.weights)
# continue training using model or use for evaluation
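
For the first snippet to work with just a url, the defaults would presumably have to live on WeightEntry itself, e.g. something like this (the field defaults are assumptions):

@dataclass
class WeightEntry:
    url: str
    transforms: Optional[Callable] = None  # could default to the shared SPM-based transform for all XLM-R variants
    meta: Dict[str, Any] = field(default_factory=dict)
    latest: bool = False
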
datumbox commented 3 years ago

Sounds good. I'll separate the Enum from the data class.

At this point I'm going to go for a modified version of Option 2, since it maintains the link between the model builder and the Enum (which is crucial for some of the features we want to support in Vision) but also allows passing custom WeightEntries directly. If such a link is not necessary in other libraries, you can declare your signatures as follows, and this will allow you to receive custom WeightEntry records:

def resnet50(weights: Optional[Union[ResNet50Weights, WeightEntry]] = None) -> ResNet:
    ResNet50Weights.check_type(weights)
    pass

# Now you can actually do:
my_local_weights = WeightEntry(url="path/to/local/model_weights.pt")
model = resnet50(weights=my_local_weights)
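
With the Union signature, check_type would presumably need to be relaxed to accept raw WeightEntry objects as well, along the lines of (again, just a sketch):

@classmethod
def check_type(cls, obj: Any) -> None:
    if obj is None or isinstance(obj, (cls, WeightEntry)):
        return
    raise TypeError(f"Expected {cls.__name__} or WeightEntry but got {type(obj).__name__}")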