patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

pylint unable to detect ``__init__`` signature correctly when base class has ``__init__`` #891

Status: Open (opened by ji8er 2 weeks ago)

ji8er commented 2 weeks ago

I use a pattern with a single abstract base class that defines __init__. Several classes derive from it; each overrides a different method of the abstract base class (here __call__) but does not override __init__.

Although the objects are instantiated correctly at runtime, pylint draws a red squiggly underline beneath the second argument of the constructor call. For example:

```python
import abc

import equinox as eqx
import jax.numpy as jnp
import jaxtyping as jt


class AbstractClass(eqx.Module, strict=True):
    # Two named axes ("n k") to match the 2-D array created in __init__.
    w1: jt.Float[jt.Array, "n k"]

    def __init__(self, a, b):
        self.w1 = jnp.zeros((a, b))

    @abc.abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError


class ConcreteClass(AbstractClass, strict=True):
    # Only __call__ is overridden; __init__ is inherited from AbstractClass.
    def __call__(self, to_add):
        return jnp.sum(self.w1) + to_add


c = ConcreteClass(2, 33333)  # pylint reports UNEXPECTED POSITIONAL ARGUMENT here
c(4)
```

Why does this happen, and is there a way to resolve it?

Thanks for the great package, btw!

patrick-kidger commented 2 weeks ago

By default, dataclasses have rather unusual behaviour around __init__ methods: a dataclass subclass that does not define its own __init__ always gets a freshly generated one, even if that overrides a user-provided __init__ in the base class. I found this ergonomically unpleasant, in precisely the use case you have here! So I disabled this behaviour: Equinox only automatically creates an __init__ method if no parent class has a user-provided one.
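A stdlib-only sketch of the difference. The first half shows standard dataclass behaviour. The second half uses a hypothetical `module_like` decorator that mimics the rule described above; it is an illustration under that assumption, not Equinox's actual implementation. All class names are hypothetical, and plain lists stand in for JAX arrays.

```python
import dataclasses


# Standard dataclasses: a decorated subclass without its own __init__ gets a
# freshly generated one built from the fields, even if a parent wrote its own.
@dataclasses.dataclass
class Base:
    w1: list

    def __init__(self, a, b):
        # User-provided: dataclass() keeps it because it is defined on Base.
        self.w1 = [0] * (a * b)


@dataclasses.dataclass
class Child(Base):
    # No __init__ here, so dataclass() generates Child.__init__(self, w1),
    # which shadows Base.__init__(self, a, b).
    pass


# Hypothetical decorator sketching the rule described above (NOT Equinox's
# code): only generate __init__ if no parent has a user-provided one.
def module_like(cls):
    inherited_user_init = any(
        "__init__" in vars(base)
        and not getattr(vars(base)["__init__"], "_auto_generated", False)
        for base in cls.__mro__[1:]
        if base is not object
    )
    own_init = "__init__" in vars(cls)
    cls = dataclasses.dataclass(cls, init=not inherited_user_init)
    if not own_init and not inherited_user_init:
        # Tag the generated __init__ so subclasses don't treat it as user-provided.
        cls.__init__._auto_generated = True
    return cls


@module_like
class AbstractLike:
    w1: list

    def __init__(self, a, b):
        self.w1 = [0] * (a * b)


@module_like
class ConcreteLike(AbstractLike):
    # No __init__ generated here: AbstractLike.__init__(a, b) is inherited.
    pass
```

With these definitions, `Child(2, 3)` raises `TypeError`, because the generated `Child.__init__` accepts only the single field `w1`, whereas `ConcreteLike(2, 3)` runs the inherited `__init__(a, b)`.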

In this case, I suspect pylint is assuming the standard dataclass behaviour. So I think the best option is simply to add the appropriate comment to disable that pylint error on the line where it is raised.

Taking a step back to look at Equinox's design: it diverges from dataclasses in exactly two ways, both of them in the __init__ method. The first is the one mentioned above. The second is that we allow mutating self during __init__, but not afterwards. In contrast, a normal frozen dataclass raises an error if you write self.w1 = ..., even inside __init__! I judged the ergonomic benefits of these two choices to be worth the occasional need to disable a static lint or typing error.
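To make the second divergence concrete using only the standard library: in a frozen dataclass, even a user-written `__init__` cannot assign `self.w1 = ...`, and the usual escape hatch is `object.__setattr__`. The class names are hypothetical, and tuples stand in for JAX arrays.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Frozen:
    w1: tuple

    def __init__(self, a, b):
        # The frozen __setattr__ applies inside __init__ too, so this line
        # raises dataclasses.FrozenInstanceError.
        self.w1 = (0,) * (a * b)


@dataclasses.dataclass(frozen=True)
class FrozenWorkaround:
    w1: tuple

    def __init__(self, a, b):
        # Bypass the frozen __setattr__ explicitly during initialisation.
        object.__setattr__(self, "w1", (0,) * (a * b))
```

`Frozen(2, 3)` fails, while `FrozenWorkaround(2, 3)` succeeds and still rejects assignments after construction. Equinox, per the comment above, lets you write the plain `self.w1 = ...` form from the original example during `__init__` and only freezes the instance afterwards.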


Thank you!