google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Possible idea for Sequential implementation? #148

Closed · danielsuo closed this 4 years ago

danielsuo commented 4 years ago

I wasn't sure if there was some philosophical reason not to have Sequential (many of these points might apply!), but I've found this simple abstraction useful. It's... not the prettiest, I admit.

class Sequential(nn.Module):
    def apply(self, x, modules, args):
        results = x
        for module, arg in zip(modules, args):
            results = module(results, **arg)
        return results

And here is how you might use it:

model_def = Sequential.partial(modules=[Identity, Plus], args=[{}, {"z": 2}])
ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, 2)])
model = nn.Model(model_def, params)
model(np.array([1,2]))

> DeviceArray([3, 4], dtype=int32)

As a complete example with the dummy modules:

import jax
import jax.numpy as np  # assuming "np" is jax.numpy here, as in early JAX examples
from flax import nn     # pre-Linen flax.nn API

class Identity(nn.Module):
    def apply(self, x):
        return x

class Plus(nn.Module):
    def apply(self, x, z):
        return x + z

class Sequential(nn.Module):
    def apply(self, x, modules, args):
        # Thread x through each module in order, passing that module's kwargs.
        results = x
        for module, arg in zip(modules, args):
            results = module(results, **arg)
        return results

model_def = Sequential.partial(modules=[Identity, Plus], args=[{}, {"z": 2}])
ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, 2)])
model = nn.Model(model_def, params)
model(np.array([1, 2]))
danielsuo commented 4 years ago

We've also found the module below useful. Anyway, we won't be offended at all if you don't like these; your reactions help us understand the design motivations better, and we can always keep them out of tree.

class Ensemble(nn.Module):
    def apply(self, x, modules, args):
        return [module(x, **arg) for (module, arg) in zip(modules, args)]
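
A hypothetical usage sketch (added for illustration, not from the original comment), reusing the Identity/Plus dummies and the pre-Linen flax.nn calls shown above:

# Sketch: every module sees the same input; the outputs are collected in a list.
ensemble_def = Ensemble.partial(modules=[Identity, Plus], args=[{}, {"z": 2}])
ys, params = ensemble_def.init_by_shape(jax.random.PRNGKey(0), [(1, 2)])
ensemble = nn.Model(ensemble_def, params)
ensemble(np.array([1, 2]))
# expected: something like [DeviceArray([1, 2], ...), DeviceArray([3, 4], ...)]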
levskaya commented 4 years ago

So, I've written a ton of combinator / concatenative style NN code, both in the original "stax" NN prototype and in "trax", which also uses combinators for everything. Combinator code is certainly the most concise way to construct NNs, and for straight-through models with layers taking a single input and a single output it's an elegant approach.

The real problem shows up with more complicated dataflow, where you start needing Parallel combinators (what you call Ensemble there, I think) and then have to worry about nested tuple packing and unpacking, or you switch to a stack-based inter-layer convention like traditional point-free "concatenative" languages. The problem with these conventions is that they lead to a "write-only" language: it's easy to write concise code, but the code is no longer self-documenting, all the dataflow is hidden from view, and readers have to bounce all over the layer libraries to reason out what the heck is going on. That's why we mostly avoid such constructs in flax; we prefer the slightly more explicit argument passing so that it's always clear what data is being passed, in a single local read-through.
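
As an illustration of that explicit style (a sketch added here for comparison, not code from the thread; ExplicitPipeline is a hypothetical name), here is the same Identity/Plus pipeline with the dataflow spelled out:

class ExplicitPipeline(nn.Module):
    def apply(self, x):
        # Same computation as Sequential([Identity, Plus]) above, but each call
        # and its arguments are visible in one local read-through.
        x = Identity(x)
        x = Plus(x, z=2)
        return x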

That said, we're not ideological about such things, and we encourage users to set up combinators if that's how they want to rig their code. Based on a lot of personal experience trying to maintain such code, though, I'm just not sure we want to encourage it by having combinators in-tree.

danielsuo commented 4 years ago

Thank you for the thorough explanation! Makes sense to me. Closing!