patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Equinox seems to break cooperative multiple inheritance #832

Closed NeilGirdhar closed 2 months ago

NeilGirdhar commented 2 months ago
from typing import override

import equinox as eqx
import jax.numpy as jnp
from jax import Array

class AcceptsStreams(eqx.Module):
    def __post_init__(self) -> None:
        if hasattr(super(), '__post_init__'):
            super().__post_init__()  # pyright: ignore

class InputNode(AcceptsStreams):
    total_value_error: Array = eqx.field(init=False)

    @override
    def __post_init__(self) -> None:
        print("PI InputNode")  # noqa: T201
        super().__post_init__()
        self.total_value_error = jnp.zeros(())

class DeductionSource(AcceptsStreams):
    pass

class DistillationBase(DeductionSource, AcceptsStreams):
    @override
    def __post_init__(self) -> None:
        print("PI DistillationBase")  # noqa: T201
        super().__post_init__()

class InputPerception(DistillationBase, InputNode):
    pass

print(", ".join(x.__qualname__ for x in InputPerception.__mro__))  # noqa: T201
# InputPerception, DistillationBase, DeductionSource, InputNode, AcceptsStreams, Module, object
InputPerception()
# ValueError: The following fields were not initialised during __init__: {'total_value_error'}

For some reason, InputNode.__post_init__ is never called. If DeductionSource is removed, then it is.
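For reference, the expected call order can be checked without Equinox. The following is a minimal sketch that reuses the class names from the report on plain Python classes (no eqx.Module, no fields) and calls __post_init__ by hand. The `calls` list and `mro_names` are hypothetical helpers added only to record what happens.

```python
# Plain-Python stand-in for the hierarchy in the report (assumption: no Equinox,
# no dataclass fields; __post_init__ is invoked manually at the end).
calls = []  # hypothetical helper: records which __post_init__ methods ran


class AcceptsStreams:
    def __post_init__(self) -> None:
        # Terminates the cooperative chain: only forward if a parent defines it.
        parent = super()
        if hasattr(parent, "__post_init__"):
            parent.__post_init__()


class InputNode(AcceptsStreams):
    def __post_init__(self) -> None:
        calls.append("InputNode")
        super().__post_init__()


class DeductionSource(AcceptsStreams):
    pass


class DistillationBase(DeductionSource, AcceptsStreams):
    def __post_init__(self) -> None:
        calls.append("DistillationBase")
        super().__post_init__()


class InputPerception(DistillationBase, InputNode):
    pass


InputPerception().__post_init__()
mro_names = [c.__qualname__ for c in InputPerception.__mro__]
print(mro_names)
print(calls)
```

In this sketch, DeductionSource defines no __post_init__ of its own, so the lookup from DistillationBase continues along the MRO to InputNode. That is the behaviour the report expects from the eqx.Module version as well.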

patrick-kidger commented 2 months ago

Thank you for the report! This should be fixed in #834.

FWIW I generally recommend against co-operative multiple inheritance (I never see this done correctly except in code I own entirely), so these days I generally prefer the abstract/final pattern instead. IMO this is more readable + is certainly more robust. This is probably contributing to not having bumped into this edge-case before.
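As a rough illustration of what an abstract/final layout might look like, here is a sketch in plain Python using abc and typing.final rather than Equinox's own tooling. It assumes the pattern means that abstract classes only declare an interface, while concrete classes are final and own all fields and initialisation. The names AbstractNode and ConstantNode are hypothetical.

```python
from abc import ABC, abstractmethod
from typing import final


class AbstractNode(ABC):
    # Abstract: declares the interface only; no fields, no initialisation code.
    @abstractmethod
    def value(self) -> float:
        ...


@final  # a hint for type checkers that this class is not meant to be subclassed
class ConstantNode(AbstractNode):
    # Concrete and final: owns its fields and its own initialisation.
    def __init__(self, constant: float) -> None:
        self.constant = constant

    def value(self) -> float:
        return self.constant


node = ConstantNode(1.5)
result = node.value()
```

Because no concrete class is ever subclassed here, there is no chain of initialisers to keep cooperative, which is the robustness argument being made.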

NeilGirdhar commented 2 months ago

> IMO this is more readable + is certainly more robust.

That would mean that every concrete class would need to duplicate the initialization code and init-variables of its superclasses. Besides the repeated code, it also breaks separation of concerns.

Personally, I think it's better to just use inheritance properly. It's a shame that Python inheritance confuses people.

> I generally recommend against co-operative multiple inheritance (I never see this done correctly except in code I own entirely)

It's pretty straightforward in Equinox. You define __post_init__ whenever you need it, and make sure that it calls super and forwards all keyword arguments.
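The convention described here (call super and forward all keyword arguments) can be sketched on plain Python classes. This is an illustration under assumptions, not Equinox code: __post_init__ is called by hand, and the class names and the `scale` and `name` keywords are hypothetical.

```python
class Base:
    def __post_init__(self, **kwargs) -> None:
        # End of the chain: forward only if a parent still defines the hook.
        parent = super()
        if hasattr(parent, "__post_init__"):
            parent.__post_init__(**kwargs)


class Scaled(Base):
    def __post_init__(self, *, scale: float = 1.0, **kwargs) -> None:
        # Consume the keyword this class cares about, forward the rest.
        super().__post_init__(**kwargs)
        self.scale = scale


class Named(Base):
    def __post_init__(self, *, name: str = "", **kwargs) -> None:
        super().__post_init__(**kwargs)
        self.name = name


class Both(Scaled, Named):
    pass


obj = Both()
obj.__post_init__(scale=2.0, name="x")  # manual call, for illustration only
```

Each class takes the keywords it knows about and passes everything else up the MRO, so Scaled and Named can be combined without knowing about each other.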

All eqx.Modules can have:

> This is probably contributing to not having bumped into this edge-case before.

Yeah, I realize that inheritance is unpopular :smile:

patrick-kidger commented 2 months ago

(For the record --

> the initialization code and init-variables of its superclasses

I would argue that this shouldn't exist at all: an abstract class cannot be initialized by definition.)

NeilGirdhar commented 2 months ago

> I would argue that this shouldn't exist at all: an abstract class cannot be initialized by definition.

Right, you're basically pushing for what some people call "interface inheritance" only, which is a paradigm that many people consider to be very safe in any programming language.

Nothing wrong with that, but I'd like to have some classes that have data members too. Consider an optional base class that tracks errors. It overrides certain methods and saves its results in its member variables. If I did things with ABCs only, I'd have to add a bunch of code in every concrete class that inherits from it (at the very least to write to and initialize those member variables).
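A sketch of the kind of base class meant here, written with standard-library dataclasses rather than eqx.Module (Equinox modules are immutable after initialisation, so the in-place update below is an assumption that only holds for this plain-Python stand-in). The names TracksErrors, Predictor, record and predict are hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class TracksErrors:
    # Data member owned and initialised by the base class itself.
    total_error: float = field(init=False)

    def __post_init__(self) -> None:
        parent = super()
        if hasattr(parent, "__post_init__"):
            parent.__post_init__()
        self.total_error = 0.0

    def record(self, error: float) -> None:
        self.total_error += abs(error)


@dataclass
class Predictor(TracksErrors):
    weight: float = 1.0

    def predict(self, x: float, target: float) -> float:
        y = self.weight * x
        self.record(y - target)  # bookkeeping lives entirely in the base class
        return y


p = Predictor(weight=2.0)
y = p.predict(3.0, target=5.0)
```

Predictor never mentions total_error: the field and its initialisation stay in TracksErrors, which is the separation of concerns argued for above.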