srush opened 2 years ago
I would like to have my layers' `setup` be batched by Flax linen. So for example in `setup` I could call.
Scan isn't batching, it's looping. It generates an XLA loop, which mainly saves compilation time, but it does block some optimizations and can have significant overhead.
Another problem could be that it isn't doing exactly what you expect. `SeqInternal` now becomes the scan body (with different params in each iteration), but it has to take a carry argument and return a tuple of `(carry, scan_outputs)`. Did you handle this in `SeqInternal`?
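For reference, the contract `lax.scan` expects from its body is `(carry, x) -> (carry, y)`; a minimal sketch (not from this thread):

```python
import jax.numpy as jnp
from jax import lax

def body(carry, x):
    # The carry must keep the same shape/dtype on every step;
    # the second return value is stacked into the scan outputs.
    carry = carry + x
    return carry, None  # None is a valid "no per-step output"

final, ys = lax.scan(body, jnp.float32(0.0), jnp.arange(4, dtype=jnp.float32))
```

Returning `None` as the second element, as done here, is the usual way to signal that you only care about the carry.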
Btw, we are planning to add a version of scan that acts as a kind of Sequential combinator so you don't have to write this scan body. cc @levskaya
Hi @jheek
Thanks for the response. Yes, I set up `SeqInternal` to take a carry and return `None` for the scan outputs.
When I say batching, I mean I want my `setup` method to be batched. In this particular example all of my `SeqInternal` layers have a complex `setup` and a fast `__call__`. I know the `__call__` needs to be sequential, but I feel like the `setup` can and should be batched.
I don't mind the extra compile time, but it seemed to run very slowly even after the first pass. I'm not sure if I accidentally compiled something in, or if it was re-jitting each time?
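One way to check whether re-jitting is actually happening (not from this thread; `jax_log_compiles` is a standard JAX debug flag) is to log every compilation:

```python
import jax
import jax.numpy as jnp

# Log every XLA compilation so repeated traces become visible.
jax.config.update("jax_log_compiles", True)

@jax.jit
def f(x):
    return x * 2.0

f(jnp.ones(3))       # first call with this shape: traced and compiled
f(jnp.ones(3))       # same shape/dtype: cache hit, no new compilation
f(jnp.ones((3, 1)))  # new shape: triggers a fresh compilation
```

If a compile log line appears on every call, something about the arguments (shape, dtype, or a non-array Python value) is changing between calls.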
Here's an example, adapted from https://srush.github.io/annotated-s4/#an-ssm-neural-network.
```python
import flax.linen as nn
from flax.linen.initializers import lecun_normal

class SeqInternal(nn.Module):
    N: int

    def setup(self):
        self.B = self.param("B", lecun_normal(), (self.N, 1))
        # would love this to be vmap'ped on bind
        self.K = slowfunction(self.B)

    def __call__(self, c, ignore):
        # pretty cheap function, structured to work with scan
        return f(self.K, c), None
```
I don't think there's a super clean solution for this right now.
cc @levskaya, I think this is another argument for scan-over-layers. Because the carry is forced to be statically shaped and the activations are normally discarded during init, we can make scan-over-layers a vmap-over-layers during init. Probably behind an optional kwarg.
The quick hack to do this would be:
```python
class SeqInternalStack(nn.Module):
    n: int  # num layers

    @nn.compact
    def __call__(self, x):
        def init_fn(rng):
            # Batch the expensive setup: one rng per layer, x broadcast.
            return jax.vmap(SeqInternal().init, in_axes=(0, None, None))(
                jax.random.split(rng, self.n), x, None
            )

        stack_params = self.param("stack", init_fn)
        c, _ = lax.scan(
            lambda c, params: SeqInternal().apply(params, c, None),
            x, stack_params,
        )
        return c
```
Neat, thanks! I'll have to work through why this hack works, but it's neat that you can do it.
See our design note on lifting for this trick (and its limitations): https://flax.readthedocs.io/en/latest/design_notes/lift.html
Description of the model to be implemented
If I have 6 stacked layers of a model (think transformer-style), I would like to have their `setup` be batched by Flax linen. So for example in `setup` I could call. However, it seems more natural to do this with scan.
In theory (?) this would batch the `setup` call for each of my `SeqInternal` layers.
However, when I try to do this, things are really slow. Is there an issue with splitting on params for scan? I don't see it in any examples.