google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Scan with split parameters #1771

Open srush opened 2 years ago

srush commented 2 years ago

Description of the model to be implemented

If I have a model with 6 stacked layers (think transformer style), I would like to have their setup be batched by Flax linen. For example, in setup I could write:

        self.layers = [SeqInternal() for _ in range(self.n_layers)]  

However, it seems more natural to do this with scan:

        SeqInternalStack = nn.scan(SeqInternal,
                                   split_rngs={"params": True, "dropout": True},
                                   in_axes=0,
                                   out_axes=0,
                                   variable_axes={"params": 0},
                                   length=self.n_layers)
        self.layers = SeqInternalStack()

In theory (?) this would batch the setup call for each of my SeqInternal layers.

However, when I try to do this, things are really slow. Is there an issue with splitting on params for scan? I don't see it in any examples.

jheek commented 2 years ago

I would like to have their setup be batched by Flax linen. So for example in setup I could call.

Scan isn't batching, it's looping. It will generate an XLA loop, which mainly saves compilation time, but it does block some optimizations and can add significant overhead.

Another problem could be that it isn't exactly doing what you expect. SeqInternal now becomes the scan body (with different params in each iteration) but you do have to take a carry argument and return a tuple of (carry, scan_outputs). Did you handle this in SeqInternal?
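
Roughly, the body needs to look like this (untested sketch; ScanBody and the Dense transform here are just placeholders, not code from this issue):

    import flax.linen as nn

    class ScanBody(nn.Module):
        features: int

        @nn.compact
        def __call__(self, carry, _):
            # nn.scan calls this once per step: take the carry (plus a per-step
            # input, unused here) and return (new_carry, per_step_output).
            carry = nn.Dense(self.features)(carry)
            return carry, None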

Btw, we are planning to add a version of scan that acts as a kind of Sequential combinator so you don't have to write this scan body. cc @levskaya

srush commented 2 years ago

Hi @jheek

Thanks for the response. Yes, I set up SeqInternal to take a carry and return null scan outputs.

When I say batching, I mean I want my setup method to be batched. In this particular example, all of my SeqInternal layers have a complex setup and a fast call. I know the call needs to be sequential, but I feel like the setup can and should be batched.

I don't mind the extra compile time, but it seemed to just run very slowly even after the first pass. Not sure if I accidentally compiled something in, or if it was re-jitting each time?
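
(For reference, a quick way to check for re-jitting, assuming the standard jax compile-logging flag:)

    import jax

    # Logs a message every time XLA compilation is triggered, so unexpected
    # recompiles of the same function show up immediately.
    jax.config.update("jax_log_compiles", True)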

srush commented 2 years ago

Here's an example, adapted from https://srush.github.io/annotated-s4/#an-ssm-neural-network.

    import flax.linen as nn
    from flax.linen.initializers import lecun_normal

    class SeqInternal(nn.Module):
        N: int = 64  # state size; default is illustrative

        def setup(self):
            self.B = self.param("B", lecun_normal(), (self.N, 1))
            # would love this to be vmap'ped on bind
            self.K = slowfunction(self.B)

        def __call__(self, c, ignore):
            # pretty cheap function, structured to work with scan
            return f(self.K, c), None

jheek commented 2 years ago

I don't think there's a super clean solution for this right now.

cc @levskaya, I think this is another argument for scan-over-layers. Because the carry is forced to be statically shaped and the activations are normally discarded during init, we can make scan-over-layers a vmap-over-layers during init, probably behind an optional kwarg.

The quick hack to do this would be:

    import jax
    import flax.linen as nn
    from jax import lax, random

    class SeqInternalStack(nn.Module):
        n: int  # num layers

        @nn.compact
        def __call__(self, x):
            def init_fn(rng):
                # vmap the per-layer init over n rng keys to build stacked params
                return jax.vmap(SeqInternal().init, in_axes=(0, None, None))(
                    random.split(rng, self.n), x, None)
            stack_params = self.param("stack", init_fn)
            # scan threads the carry through the layers, slicing params along axis 0
            c, _ = lax.scan(lambda c, params: SeqInternal().apply(params, c, None),
                            x, stack_params)
            return c
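
Roughly how you'd use it (untested sketch; assumes SeqInternal from the earlier comment with its placeholder functions filled in, names and shapes illustrative only):

    import jax.numpy as jnp
    from jax import random

    x = jnp.ones((64,))                            # initial carry
    model = SeqInternalStack(n=6)
    variables = model.init(random.PRNGKey(0), x)   # init_fn vmaps the per-layer inits
    y = model.apply(variables, x)                  # lax.scan loops over the 6 layers
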
srush commented 2 years ago

Neat, thanks! I'll have to parse a bit why this hack works, but it's neat that you can do it.

jheek commented 2 years ago

See our design note on lifting, https://flax.readthedocs.io/en/latest/design_notes/lift.html, for this trick (and its limitations).