drivendataorg / zamba

A Python package for identifying 42 kinds of animals, training custom models, and estimating distance from camera trap videos
https://zamba.drivendata.org/docs/stable/
MIT License

Add hardware_dependent_fields to ZambaBaseModel's Config #154

Open pjbull opened 3 years ago

pjbull commented 3 years ago

Currently, we publish configs with our official models so that it is easy for other people to use them. We do this by removing the properties that are hardware dependent (e.g., device: cuda or batch_size: 4) and resaving the files. This is handled by a single function that strips those properties from each config passed to it:

https://github.com/drivendataorg/zamba/blob/33c1c2a20feec695139f9c6af2ce05d67ed291a0/zamba/models/publish_models.py#L15-L49

We can simplify and generalize this code by adding a hardware_dependent_fields property to the config and a hardware_independent_dict method to the base model. The method serializes the config and removes every field listed in hardware_dependent_fields, both on that model and on any nested models that inherit from the same base model. Here's a working implementation with a small example config (in the example, the property is named hardware_dependent and the method dict_hardware_independent):

from pydantic import BaseModel

class CustomBase(BaseModel):
    class Config:
        hardware_dependent = []

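    def dict_hardware_independent(self):
        """Serialize to a dict, dropping hardware dependent fields recursively."""
        full_dict = self.dict()

        # replace each directly nested CustomBase child with its own
        # hardware independent dict (children inside lists or dicts are
        # not visited; __fields__ and .dict() are the pydantic v1 API)
        for field_name in self.__fields__:
            field_value = getattr(self, field_name, None)
            if isinstance(field_value, CustomBase):
                full_dict[field_name] = field_value.dict_hardware_independent()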

        # remove hardware dependent fields on this model
        for f in self.Config.hardware_dependent:
            full_dict.pop(f)

        return full_dict

class SubModelC(CustomBase):
    x: str

    class Config:
        hardware_dependent = ['x']

class SubModelA(CustomBase):
    a: str
    b: str
    c: SubModelC

    class Config:
        hardware_dependent = ['a']

instance_a = SubModelA(a="a", b="b", c=SubModelC(x='x'))

instance_a.dict_hardware_independent()
# > {'b': 'b', 'c': {}}
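
The example above only descends into fields whose value is itself a CustomBase instance, so nested models held in a list or dict would keep their hardware dependent fields. Below is a rough sketch of how the same recursion could cover those containers. To keep it dependency-free it uses standard-library dataclasses instead of pydantic, and a class-level tuple instead of Config; ConfigBase, LoaderConfig, TrainConfig, and all field names other than device and batch_size are hypothetical and are not part of zamba.

```python
from dataclasses import dataclass, field, fields
from typing import Any, ClassVar, Dict, List, Tuple


def _strip(value: Any) -> Any:
    """Recursively serialize a value, descending into lists, tuples, and dicts."""
    if isinstance(value, ConfigBase):
        return value.dict_hardware_independent()
    if isinstance(value, (list, tuple)):
        return [_strip(v) for v in value]
    if isinstance(value, dict):
        return {k: _strip(v) for k, v in value.items()}
    return value


@dataclass
class ConfigBase:
    # Hypothetical stand-in for Config.hardware_dependent in the issue's example.
    # ClassVar attributes are not dataclass fields, so fields() skips this one.
    hardware_dependent: ClassVar[Tuple[str, ...]] = ()

    def dict_hardware_independent(self) -> Dict[str, Any]:
        # Keep every field that is not declared hardware dependent on this class.
        return {
            f.name: _strip(getattr(self, f.name))
            for f in fields(self)
            if f.name not in self.hardware_dependent
        }


@dataclass
class LoaderConfig(ConfigBase):
    num_workers: int = 0
    crop: int = 224
    hardware_dependent: ClassVar[Tuple[str, ...]] = ("num_workers",)


@dataclass
class TrainConfig(ConfigBase):
    device: str
    batch_size: int
    model_name: str
    loaders: List[LoaderConfig] = field(default_factory=list)
    hardware_dependent: ClassVar[Tuple[str, ...]] = ("device", "batch_size")


config = TrainConfig(
    device="cuda",
    batch_size=4,
    model_name="example",
    loaders=[LoaderConfig(num_workers=8), LoaderConfig(num_workers=2, crop=128)],
)
result = config.dict_hardware_independent()
print(result)
```

In a pydantic version, the same idea would amount to applying a helper like _strip to each field value rather than checking only for direct CustomBase children. Whether zamba's configs actually hold nested models in containers is not something this sketch assumes either way.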