vocalpy / vak

A neural network framework for researchers studying acoustic communication
https://vak.readthedocs.io
BSD 3-Clause "New" or "Revised" License

BUG/CLN: Refactor model abstraction so we don't subclass LightningModel, to fix loss logging #737

NickleDave closed this issue 6 months ago

NickleDave commented 10 months ago

When we subclass LightningModule, logging seems to break; see #726.

I can at least get better logging (it still does some weird things) if I remove all the magic subclassing of vak.models.base.Model and instead define model families that each subclass LightningModule separately, e.g. FrameClassificationModel(lightning.LightningModule).

This can actually work fine for us: we define one class per model family, and we refactor our current logic for converting definitions into models so that it instantiates a model's components and then passes those components to the model-family class when instantiating it.
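A minimal sketch of that per-family approach follows, written in plain Python so it runs without Lightning installed. `FrameClassificationFamily`, `build_components`, and the `Toy*` classes are hypothetical names for illustration; in vak the family class would subclass `lightning.LightningModule`, and the components would be real torch objects. The definition's attributes (`network`, `loss`, `optimizer`, `metrics`, `default_config`) are an assumption loosely modeled on vak model definitions.

```python
from typing import Any


class FrameClassificationFamily:
    """Hypothetical stand-in for a model-family class like FrameClassificationModel.

    In vak this would subclass lightning.LightningModule; it is a plain class
    here so the sketch runs without Lightning. It receives already-built
    components instead of building them itself.
    """

    def __init__(self, network: Any, optimizer: Any, loss: Any, metrics: dict):
        self.network = network
        self.optimizer = optimizer
        self.loss = loss
        self.metrics = metrics


def build_components(modeldef: type) -> tuple:
    """Hypothetical helper: instantiate each component a definition names."""
    config = getattr(modeldef, "default_config", {})
    network = modeldef.network(**config.get("network", {}))
    loss = modeldef.loss(**config.get("loss", {}))
    # A real torch optimizer takes the network's parameters first; the toy
    # optimizer below takes the network itself to mimic that dependency.
    optimizer = modeldef.optimizer(network, **config.get("optimizer", {}))
    metrics = {name: metric_cls() for name, metric_cls in modeldef.metrics.items()}
    return network, optimizer, loss, metrics


# Toy components standing in for torch modules, losses, optimizers, and metrics.
class ToyNetwork:
    def __init__(self, num_classes: int = 2):
        self.num_classes = num_classes


class ToyLoss:
    pass


class ToyOptimizer:
    def __init__(self, network: ToyNetwork, lr: float = 0.001):
        self.network = network
        self.lr = lr


class ToyAccuracy:
    pass


class ToyModelDef:
    """A definition only names which component classes to use, plus defaults."""

    network = ToyNetwork
    loss = ToyLoss
    optimizer = ToyOptimizer
    metrics = {"acc": ToyAccuracy}
    default_config = {"network": {"num_classes": 5}, "optimizer": {"lr": 0.01}}


# Build the components first, then pass them to the family class.
model = FrameClassificationFamily(*build_components(ToyModelDef))
```

The design choice here is composition over inheritance: the definition never becomes a LightningModule itself, so only the per-family class touches Lightning.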

NickleDave commented 6 months ago

I think for now we can fix this by changing vak.models.base.Model so that it does not subclass LightningModule and instead just gives us what is basically a dataclass instance, with the attributes we want, that we can pass to model-family-specific LightningModules.

My read of this code now is that I got in over my head with metaprogramming and "came in ass-first thinking [I] invented sliced bread" (to paraphrase Andy Partridge). It's not at all obvious to me why I can't un-spaghetti this code so that it is just a dataclass, e.g. we get back a ModelDefinition instance with the attributes we want on a per-model basis.
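The "just a dataclass" idea could look roughly like the sketch below. The `ModelDefinition` fields, the `CONFIG_KEYS` set, and the validation rules are hypothetical choices for illustration, not vak's actual API; the point is that validation happens once in `__post_init__` rather than through metaprogramming, and the instance is plain data that does not subclass LightningModule.

```python
from dataclasses import dataclass, field
from typing import Any

# Config sections a definition may supply defaults for (an assumption for this sketch).
CONFIG_KEYS = {"network", "loss", "optimizer", "metrics"}


@dataclass(frozen=True)
class ModelDefinition:
    """Hypothetical sketch: a plain container for one model's parts.

    Stores component classes plus default hyperparameters. It does not
    subclass LightningModule; a model-family LightningModule would be
    built from these parts instead.
    """

    name: str
    network: type
    loss: type
    optimizer: type
    metrics: dict[str, type] = field(default_factory=dict)
    default_config: dict[str, dict[str, Any]] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Validate once, at construction time.
        for attr in ("network", "loss", "optimizer"):
            value = getattr(self, attr)
            if not isinstance(value, type):
                raise TypeError(f"{attr} must be a class, got {value!r}")
        unknown = set(self.default_config) - CONFIG_KEYS
        if unknown:
            raise ValueError(f"unknown config sections: {sorted(unknown)}")


# Usage with toy component classes (hypothetical placeholders).
class Net:
    pass


class Loss:
    pass


class Optim:
    pass


toynet_def = ModelDefinition(
    name="ToyNet",
    network=Net,
    loss=Loss,
    optimizer=Optim,
    default_config={"optimizer": {"lr": 0.003}},
)
```

Because the instance is frozen plain data, it can be passed around and inspected freely, and a model-family LightningModule can be instantiated from its fields without any class-level magic.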

NickleDave commented 6 months ago

I think a possible fix is something like this:

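import functools

from lightning import LightningModule

# `definition` (with its `validate` function), `is_valid_family`, and
# `from_attributes` are vak-internal helpers assumed to be in scope.


def model(family: type[LightningModule]):
    """Decorator factory. The returned decorator creates a new class
    representing a model that belongs to a family of models,
    given a class representing the definition of the model.
    ``family`` is the LightningModule subclass for that family.
    """
    is_valid_family(family)  # assumed to raise if ``family`` is not valid
    family_cls = family

    def _model(modeldef: type):
        definition.validate(modeldef)  # assumed to raise if invalid
        definition_cls = modeldef

        class Model:
            # Class variables. A class body can't read an enclosing-function
            # variable that it also assigns (``modeldef = modeldef`` raises
            # NameError), so bind them from differently named locals.
            modeldef = definition_cls
            family = family_cls

            def from_config_dict(self, config_dict) -> LightningModule:
                # config_dict is not used yet; components come from the definition
                network, optimizer, loss, metrics = from_attributes(self.modeldef)
                return self.family(network, optimizer, loss, metrics)

        # A class's __dict__ is a read-only mappingproxy, so skip update_wrapper's
        # default __dict__ update; this still copies __name__, __doc__, etc.
        Model = functools.update_wrapper(Model, modeldef, updated=())
        return Model()

    return _model

# usage: put @model(family=FrameClassificationModel) above a definition class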

It's weird that we return an instance of a class we just defined, whose only method is from_config_dict, but this seems like the easiest shim I could put in place right now.