stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Why init instead of __init__? #17

Closed: rohan-mehta-1024 closed this issue 11 months ago

rohan-mehta-1024 commented 11 months ago

I was wondering why it's considered idiomatic in haliax to instantiate new modules with a @staticmethod init instead of the built-in object initializer, __init__. The latter seemed more natural to me, so I've been using it when defining my own modules. Nothing has broken so far, but now I'm trying to define a transformer block:

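import equinox as eqx
import haliax as hax
import jax.random as jr

# FeedForward, Attention and the axes Embed, Mlp, Key, Head are defined elsewhere in my code

class TBlock(eqx.Module):
    ffn: FeedForward
    attn: Attention

    def __init__(self, key):
        k1, k2 = jr.split(key, 2)
        # give each submodule its own PRNG key instead of reusing `key`
        self.ffn = FeedForward(Embed, Mlp, k1)
        self.attn = Attention(Key, Embed, Head, k2)

    def __call__(self, x: hax.NamedArray, padding_mask) -> hax.NamedArray:
        x = x + self.attn(x, padding_mask)
        x = x + self.ffn(x)
        return x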

and then use it with Stacked:

self.blocks = hnn.Stacked.init(Layer, TBlock)(key=jr.split(k3, Layer.size))

I get AttributeError: type object 'TBlock' has no attribute 'init',

presumably because Stacked expects a

@staticmethod
def init(key):
    pass

instead of my

def __init__(self):
    pass

Is there any way to get this to work with the typical, idiomatic Python initialization? And why was this choice made in the library design (since, as far as I can tell, Equinox follows the traditional Python convention here)?
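To see why the error above appears, here is a rough sketch of what a Stacked-style helper has to do. It is not haliax's implementation: stack_init, Block and InitOnlyBlock are hypothetical names, and plain jax.vmap over a NamedTuple stands in for haliax's named-axis machinery. The assumption is that the helper needs a function that builds one layer from a key, so it looks up module_cls.init and vmaps it over a batch of keys; a class that only defines __init__ has no such attribute.

```python
from typing import NamedTuple

import jax


def stack_init(module_cls, num_layers: int, *, key: jax.Array):
    """Hypothetical Stacked-style helper: build num_layers copies of a layer.

    Each array field of the result gains a leading axis of size num_layers.
    """
    # Look up the static constructor first; this raises AttributeError for
    # classes that only define __init__, mirroring the error in the issue.
    init_fn = module_cls.init
    keys = jax.random.split(key, num_layers)
    # vmap the single-layer constructor over the batch of keys
    return jax.vmap(lambda k: init_fn(key=k))(keys)


class Block(NamedTuple):
    # Hypothetical layer with a static `init`, as the helper expects
    weight: jax.Array

    @staticmethod
    def init(*, key: jax.Array) -> "Block":
        return Block(jax.random.normal(key, (4, 4)))


class InitOnlyBlock:
    # Hypothetical layer with only the usual __init__, so there is no `init`
    # for the helper to find.
    def __init__(self, key: jax.Array):
        self.weight = jax.random.normal(key, (4, 4))


stacked = stack_init(Block, 3, key=jax.random.PRNGKey(0))
print(stacked.weight.shape)  # (3, 4, 4)

try:
    stack_init(InitOnlyBlock, 3, key=jax.random.PRNGKey(0))
except AttributeError as err:
    print(err)  # type object 'InitOnlyBlock' has no attribute 'init'
```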

dlwh commented 11 months ago

In general, I wanted there to be a "boring" constructor that just takes the fields, which makes it easier to initialize modules from other sources: loading from Hugging Face (HF) checkpoints, using different initialization strategies, etc.
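A minimal sketch of the pattern described above, using plain JAX rather than haliax. Assumptions: a typing.NamedTuple stands in for an eqx.Module (both give you a constructor that simply takes the fields, and both are JAX pytrees), and the layer name Linear and its init signature are hypothetical, not part of haliax's API. Random initialization lives in a separate static init, so the plain constructor stays free for weights that come from elsewhere.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Linear(NamedTuple):
    """Hypothetical toy layer standing in for an eqx.Module.

    The generated constructor just takes the fields: the "boring" constructor.
    """

    weight: jax.Array  # shape (out_features, in_features)
    bias: jax.Array  # shape (out_features,)

    @staticmethod
    def init(in_features: int, out_features: int, *, key: jax.Array) -> "Linear":
        # One possible random-initialization strategy, kept out of the constructor.
        scale = 1.0 / jnp.sqrt(in_features)
        weight = jax.random.normal(key, (out_features, in_features)) * scale
        bias = jnp.zeros((out_features,))
        return Linear(weight, bias)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.weight @ x + self.bias


# Random initialization goes through `init` ...
random_layer = Linear.init(3, 2, key=jax.random.PRNGKey(0))

# ... while specific weights (e.g. loaded from a checkpoint, or hand-picked
# for an experiment) go straight into the plain constructor.
chosen_layer = Linear(weight=jnp.eye(2, 3), bias=jnp.zeros(2))
print(chosen_layer(jnp.array([1.0, 2.0, 3.0])))  # [1. 2.]
```

Under this sketch, loading weights never has to work around an __init__ that insists on a PRNG key.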

We could revisit this, but that's why I did it the way I did.

rohan-mehta-1024 commented 11 months ago

Ah, OK. The more I thought about this, the more it seemed like the right idea, so I'll change my code to be consistent with it. I haven't encountered this pattern before in ML libraries, but it does make sense to be able to, e.g., initialize a module with a very specific weight matrix. Incidentally, this strengthens my belief that haliax could be used to make a really good interpretability library!