patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

pyright can't tell that Module fields are frozen #862

Closed: garymm closed this issue 1 month ago

garymm commented 1 month ago
```python
from dataclasses import dataclass

import equinox as eqx

class A(eqx.Module):
    f: int | str

class B(A):
    f: str   # pyright complains here

@dataclass(frozen=True)
class ADC:
    f: int | str

class BDC(ADC):
    f: str   # no pyright error here
```

Pyright reports an error for B.f but not BDC.f.

```
"f" overrides symbol of same name in class "A"
  Variable is mutable so its type is invariant
    Override type "str" is not the same as base type "int | str" reportIncompatibleVariableOverride
```

I haven't investigated yet, but pyright is probably picking up this call to typing.dataclass_transform: https://github.com/patrick-kidger/equinox/blob/804d82e0fe65f85dd9f9a21f03d2784c164bbf8d/equinox/_module.py#L630

The typing.dataclass_transform docs say:

The decorated class, metaclass, or function may accept the following bool arguments which type checkers will assume have the same effect as they would have on the @dataclasses.dataclass decorator: init, eq, order, unsafe_hash, frozen, match_args, kw_only, and slots. It must be possible for the value of these arguments (True or False) to be statically evaluated.

Perhaps one fix would be for _ModuleMeta to accept a frozen argument that defaults to True and raise an exception if a user tries to set it to False.

There is also a frozen_default argument to dataclass_transform, but it is only available in Python >= 3.12.

I can't tell for sure, but this may be a backport of the 3.12 version: https://github.com/Netflix/metaflow/blob/822449b8c910b83b0773eb9f814053fd971a47bb/metaflow/_vendor/typing_extensions.py#L2341

I guess the easiest possible fix is to set frozen_default=True and then wait for everyone to upgrade to Python 3.12.

If you let me know what you prefer, I can take a stab at a PR.

patrick-kidger commented 1 month ago

I've tried this before. Unfortunately, adding dataclass_transform(frozen=True) causes a more serious breakage: every self.foo = bar assignment in __init__ gets flagged as an error.

Equinox takes a hybrid approach in which everything is frozen except during __init__. This has turned out to have great ergonomics, but there's no way to communicate it to static type checkers.
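The hybrid behaviour can be approximated at runtime roughly as follows. This is a simplified, hypothetical sketch, not Equinox's actual implementation: `HybridMeta` wraps construction so that `__setattr__` is permitted while `__init__` runs and rejected afterwards. Nothing here is visible to a static type checker, which is exactly the limitation described above.

```python
class HybridMeta(type):
    """Hypothetical metaclass: attributes are writable only while __init__ runs."""

    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls)
        # Open the "initialisation window" before calling __init__.
        object.__setattr__(obj, "_initializing", True)
        try:
            obj.__init__(*args, **kwargs)
        finally:
            # Close it afterwards, even if __init__ raised.
            object.__setattr__(obj, "_initializing", False)
        return obj


class HybridModule(metaclass=HybridMeta):
    _initializing: bool

    def __setattr__(self, name, value):
        if not self._initializing:
            raise AttributeError(f"cannot assign to field {name!r} after __init__")
        object.__setattr__(self, name, value)


class Linear(HybridModule):
    weight: float
    bias: float

    def __init__(self, weight: float, bias: float) -> None:
        # Plain assignments work here because we are still inside __init__.
        self.weight = weight
        self.bias = bias
```

After construction, `Linear(1.0, 0.5).weight = 2.0` raises `AttributeError`. A fuller version would also guard `__delattr__`.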

garymm commented 1 month ago

But in dataclasses self.foo = bar is allowed during __init__, right? So is that a bug in the type checker?

patrick-kidger commented 1 month ago

Nope. That's not allowed in frozen dataclasses. (There have been proposals to change this; it really reduces the usability of frozen dataclasses.)
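For comparison, here is the standard-library behaviour being referred to. In a frozen dataclass with a hand-written `__init__`, `self.foo = bar` goes through the generated `__setattr__` and raises `dataclasses.FrozenInstanceError` even during `__init__`; the usual workaround is `object.__setattr__`. `Plain` and `Workaround` are illustrative names:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Plain:
    foo: int

    def __init__(self, bar: int) -> None:
        self.foo = bar  # raises FrozenInstanceError, even inside __init__


@dataclasses.dataclass(frozen=True)
class Workaround:
    foo: int

    def __init__(self, bar: int) -> None:
        object.__setattr__(self, "foo", bar)  # bypasses the frozen __setattr__


try:
    Plain(1)
except dataclasses.FrozenInstanceError as err:
    print("Plain:", err)

print(Workaround(1).foo)  # 1
```

Because the class defines its own `__init__`, `dataclass` keeps it rather than generating one, so the frozen check applies to the hand-written assignments.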

garymm commented 1 month ago

OK, so I guess it's not feasible to fix this while preserving the desired functionality.