patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Abstract/final pattern does not support abstract __init__ #647

Open · mjo22 opened this issue 6 months ago

mjo22 commented 6 months ago

The following is not allowed under the abstract/final pattern (i.e. with strict=True):

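import abc

import equinox as eqx
from jaxtyping import Array
from typing_extensions import override

some_function = lambda arg1: ...

class AbstractModule(eqx.Module, strict=True):

    # Every concrete subclass must accept (arg1, arg2) ...
    @abc.abstractmethod
    def __init__(self, arg1, arg2):
        raise NotImplementedError

    # ... because this shared constructor relies on that signature.
    @classmethod
    def some_custom_constructor(cls, arg3):
        # Compute arg1 and arg2 from arg3 (placeholder computation)
        arg1, arg2 = arg3, arg3
        return cls(arg1, arg2)

class Module1(AbstractModule, strict=True):

    arg1: Array
    arg2: Array

    @override
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2

class Module2(AbstractModule, strict=True):

    # Same __init__ arguments, but a different set of fields.
    not_arg1: Array
    arg2: Array

    @override
    def __init__(self, arg1, arg2):
        self.not_arg1 = some_function(arg1)
        self.arg2 = arg2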

Here, the custom constructor requires __init__ to have a particular signature. However, subclasses may not want to set exactly the same fields, even when their __init__ takes the same arguments. Is there a reason this does not fit the abstract/final design pattern? Or should it be supported with strict=True?
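For comparison, the same shape of design is accepted by Python's own abc module when Equinox is not involved. The sketch below drops eqx.Module entirely and uses plain classes; the body of some_function and the way arg1 and arg2 are derived from arg3 are hypothetical placeholders, not anything taken from Equinox. It shows an abstract __init__ pinning down the constructor signature while each concrete subclass stores different attributes.

```python
import abc


def some_function(x):
    # Hypothetical stand-in for `some_function` from the issue.
    return 2 * x


class AbstractModule(abc.ABC):
    # Abstract __init__: fixes the signature every subclass must accept.
    @abc.abstractmethod
    def __init__(self, arg1, arg2):
        raise NotImplementedError

    @classmethod
    def some_custom_constructor(cls, arg3):
        # Hypothetical derivation of arg1 and arg2 from arg3.
        arg1, arg2 = arg3 + 1, arg3 - 1
        # Works for any subclass whose __init__ takes (arg1, arg2).
        return cls(arg1, arg2)


class Module1(AbstractModule):
    def __init__(self, arg1, arg2):
        # Stores the arguments directly.
        self.arg1 = arg1
        self.arg2 = arg2


class Module2(AbstractModule):
    def __init__(self, arg1, arg2):
        # Same arguments, different attributes.
        self.not_arg1 = some_function(arg1)
        self.arg2 = arg2
```

Here Module1.some_custom_constructor(3) gives arg1 == 4 and arg2 == 2, Module2.some_custom_constructor(3) gives not_arg1 == 8 and arg2 == 2, and AbstractModule(1, 2) raises TypeError because __init__ is still abstract.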

patrick-kidger commented 6 months ago

I think this should be supported! I'd be happy to take a PR adjusting this.

mjo22 commented 6 months ago

Okay sounds good! I will try to find some time to put one together.

mjo22 commented 4 months ago

Hello, just wanted to let you know that unfortunately I don't think I'll be able to get to this anytime soon (sorry about that). My workaround in the meantime is simply to relax my strict=True requirement.