datumbox / dapi-model-versioning

RFC for Model Versioning across all PyTorch Domain libraries

Enum and dataclasses #2

Closed: mthrok closed this issue 3 years ago

mthrok commented 3 years ago

(Opening a new issue so as not to mix this into the ongoing discussion.)

I feel that separating the dataclass from the Enum makes the code more readable. Current version:

```python
@dataclass
class ResNet50Weights(Enum):
    ImageNet1K_RefV1 = (
        url="https://path/to/weights.pth",  # Weights URL/path
        transforms=partial(ImageNetPreprocessing, width=224),  # Preprocessing transform constructor
        meta={"num_classes": 1000, "Acc@1": 76.130, "classes": [...]}  # Arbitrary Meta-Data
        # Other customizable fields go here
    )
```
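Proposed version:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict


@dataclass
class ResNet50Config:
    url: str
    transforms: Any
    meta: Dict[str, Any]
    latest: bool


class ResNet50Weights(Enum):
    ImageNet1K_RefV1 = ResNet50Config(...)
    ...
```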
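A minimal runnable sketch of the proposed split, assuming a hypothetical `image_net_preprocessing` function in place of the README's `ImageNetPreprocessing`, and placeholder URL and metadata values. Each member's value is a keyword-constructed `ResNet50Config`, so its fields are read through `.value`:

```python
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict


@dataclass
class ResNet50Config:
    url: str
    transforms: Callable
    meta: Dict[str, Any]
    latest: bool


def image_net_preprocessing(width: int) -> str:
    # Hypothetical stand-in for a real preprocessing transform constructor.
    return f"preprocess@{width}"


class ResNet50Weights(Enum):
    # The member's value is a config object built with named arguments.
    ImageNet1K_RefV1 = ResNet50Config(
        url="https://path/to/weights.pth",
        transforms=partial(image_net_preprocessing, width=224),
        meta={"num_classes": 1000, "Acc@1": 76.130},
        latest=True,
    )


weights = ResNet50Weights.ImageNet1K_RefV1
print(weights.value.url)  # https://path/to/weights.pth
print(weights.value.transforms())  # preprocess@224
```

Because every field is passed by keyword, the member definition documents itself; the trade-off is the extra `.value` hop when reading fields.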
datumbox commented 3 years ago

Thanks @mthrok.

Your question/proposal is valid; just note that the code you are quoting comes from the README and is pseudocode, not real code. It was there for illustration purposes, to keep things brief.

Here is the actual implementation:

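```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict


@dataclass
class Weights(Enum):
    url: str
    transforms: Callable
    meta: Dict[str, Any]
    latest: bool

    # method definitions go here
```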

And here is the actual usage: https://github.com/datumbox/dapi-model-versioning/blob/21f0d8eb65d294790af3e38c81672f21e62a2784/dapi_lib/models/resnet.py#L46-L53

As you can see, with this approach the parameters are passed positionally rather than by name.
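To illustrate why the parameters end up positional, here is a sketch assuming the `@dataclass`-decorated `Weights(Enum)` pattern with placeholder values (the identity `transforms` lambda is hypothetical). `Enum` passes a tuple value to the member's `__init__` as positional arguments, so the dataclass-generated `__init__(url, transforms, meta, latest)` fills the fields in declaration order and they become plain attributes on the member:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict


@dataclass
class Weights(Enum):
    url: str
    transforms: Callable
    meta: Dict[str, Any]
    latest: bool


class ResNet50Weights(Weights):
    # Enum unpacks this tuple into the dataclass-generated
    # __init__(url, transforms, meta, latest), purely by position.
    ImageNet1K_RefV1 = (
        "https://path/to/weights.pth",  # url
        lambda img: img,  # transforms (hypothetical identity transform)
        {"num_classes": 1000},  # meta
        True,  # latest
    )


weights = ResNet50Weights.ImageNet1K_RefV1
print(weights.url)  # https://path/to/weights.pth
print(weights.latest)  # True
```

The upside is direct attribute access (`member.url`); the downside is that nothing at the definition site says which tuple slot is which, and `member.value` is still the raw tuple.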

The alternative implementation would look like the following. Note that the base Weights class now becomes optional, since its methods could live in the flat namespace instead. For the reasons I mention at https://github.com/datumbox/dapi-model-versioning/issues/1#issuecomment-928172738, I think it's still valuable to keep it (I'm open to feedback and happy to change it if many people think it should go):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict


class Weights(Enum):
    # method definitions go here
    pass


@dataclass
class WeightEntry:
    url: str
    transforms: Any
    meta: Dict[str, Any]
    latest: bool
```

Usage:

```python
class ResNet50Weights(Weights):
    ImageNet1K_RefV1 = WeightEntry(url=..., transforms=..., meta=..., latest=True)
```
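One possible shape for the optional base class, as a sketch only: the helper names `from_str` and `latest_entry`, and the `__getattr__` forwarding to `WeightEntry` fields, are hypothetical and not part of the proposal above. It shows shared methods living on `Weights` while each entry stays a keyword-constructed `WeightEntry` with placeholder values:

```python
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Callable, Dict


@dataclass
class WeightEntry:
    url: str
    transforms: Callable
    meta: Dict[str, Any]
    latest: bool


class Weights(Enum):
    # Shared helpers live on the base class (all names are hypothetical).

    @classmethod
    def from_str(cls, name: str) -> "Weights":
        # Look a member up by its name, e.g. "ImageNet1K_RefV1".
        try:
            return cls[name]
        except KeyError:
            raise ValueError(f"Invalid weights {name!r} for {cls.__name__}") from None

    @classmethod
    def latest_entry(cls) -> "Weights":
        # Return the first member whose entry is flagged as latest.
        for member in cls:
            if member.value.latest:
                return member
        raise ValueError(f"No latest weights for {cls.__name__}")

    def __getattr__(self, name: str) -> Any:
        # Forward WeightEntry field names (e.g. .url) to the member's value.
        # The name is checked first so internal lookups such as _value_
        # never touch self.value and cannot recurse.
        if name in {f.name for f in fields(WeightEntry)}:
            return getattr(self.value, name)
        raise AttributeError(name)


class ResNet50Weights(Weights):
    ImageNet1K_RefV1 = WeightEntry(
        url="https://path/to/weights.pth",
        transforms=lambda img: img,  # hypothetical identity transform
        meta={"num_classes": 1000},
        latest=True,
    )


weights = ResNet50Weights.from_str("ImageNet1K_RefV1")
print(weights.url)  # https://path/to/weights.pth
print(ResNet50Weights.latest_entry() is weights)  # True
```

The `__getattr__` hook is optional; without it, fields are read as `member.value.url`. Either way the member definition keeps named parameters.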
datumbox commented 3 years ago

@mthrok We may be having a duplicate discussion about this in #7. Shall we close this issue and continue the chat there?