google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Best practice: template module in Flax? #457

Closed. mortner31 closed this issue 4 years ago.

mortner31 commented 4 years ago

Hello guys, first of all, thanks for your awesome work. As a former TF 0.6 user (a long time ago), I find JAX/Flax really great. No more hidden magic such as "compile, fit, close your eyes".

Now, the thing is, I use Siamese nets and triplet losses a lot, every day. And since I come from a C/C++ background, I usually try to template things.

I went through the discussion in https://github.com/google/flax/issues/208 (major evolution of the API). I think it makes sense to get rid of the mix of classes, subclasses, class instances and objects, since the gotchas around "init", "call", "shared" and "partial" are not obvious at first glance, despite the nice documentation.

The purpose of what follows is just to present the way a twisted and old mind like mine uses Flax. Maybe it can fuel the discussion, become part of an advanced example in the documentation, or serve as an example of "things you don't want to do with Flax". I think the template issue is a common one, and I would really appreciate some advice on the most canonical way to do this in Flax.

Example of code

I start with a simple layer:


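import jax
import jax.numpy as jnp
import flax
from flax import nn  # pre-Linen API (flax.nn), later deprecated in favor of flax.linen


class MyCNNBlock(nn.Module):

    def apply(self, x, n_filters=32, kernel_size=3, use_bn=True,
              dtype=jnp.float32, train=False):

        # 2D convolution, stride 1, 'SAME' padding
        x = nn.Conv(x, n_filters, (kernel_size, kernel_size), (1, 1),
                    padding='SAME', bias=True, dtype=dtype)

        # optional batch norm; running averages are used at eval time
        if use_bn:
            x = nn.BatchNorm(x, use_running_average=not train,
                             momentum=0.9, epsilon=1e-5, dtype=dtype)

        # ReLU
        x = jnp.maximum(x, 0)

        return x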

My template code looks like this. Note that one of the keyword arguments of the apply function is an nn.Module subclass.

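class MySiameseNetTemplate(nn.Module):

    def apply(self, x, module=None, train=False, shared=True):

        # `module` is an nn.Module subclass, e.g. one built with .partial
        assert module is not None and issubclass(module, nn.Module)

        # cut the input in two along the channel axis (NHWC layout)
        x1, x2 = jnp.split(x, 2, axis=3)

        if shared:
            # one set of parameters, reused for both branches
            base = module.shared(train=train)
            y1 = base(x1)
            y2 = base(x2)
        else:
            # two independent submodules, each with its own parameters
            y1 = module(x1, train=train)
            y2 = module(x2, train=train)

        return y1, y2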
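To make the `shared` flag concrete without relying on the (since removed) `flax.nn` API, here is a framework-free sketch in plain Python. `ToyDense` and `siamese_forward` are hypothetical illustrations, not Flax APIs: with `shared=True` both halves go through one parameter set, with `shared=False` each half gets its own, which is the distinction `module.shared(...)` versus two separate `module(...)` calls makes above.

```python
import random


class ToyDense:
    """Hypothetical stand-in for a layer: y = w * x + b, with its own parameters."""

    def __init__(self, seed):
        rng = random.Random(seed)
        self.params = {"w": rng.uniform(-1.0, 1.0), "b": rng.uniform(-1.0, 1.0)}

    def __call__(self, x):
        return [self.params["w"] * v + self.params["b"] for v in x]


def siamese_forward(x, shared=True):
    """Split x into two halves and run each half through a base layer."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    if shared:
        base = ToyDense(seed=0)  # a single parameter set, applied twice
        branches = [base, base]
    else:
        # two independent parameter sets
        branches = [ToyDense(seed=0), ToyDense(seed=1)]
    y1, y2 = branches[0](x1), branches[1](x2)
    # count how many distinct parameter sets were actually used
    n_param_sets = len({id(b) for b in branches})
    return y1, y2, n_param_sets


y1, y2, n = siamese_forward([1.0, 2.0, 1.0, 2.0], shared=True)
```

With identical halves, the shared variant produces identical outputs from a single parameter set; the unshared variant generally does not.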

Then, when I create an instance of the template module, I use partial, since it lets me normalize the signature:

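def create_model(prng_key, use_bn=True, shared=True):

    # two single-channel images stacked along axis 3
    input_shape = (100, 64, 64, 2)
    model_dtype = jnp.float32

    # bind hyperparameters so that only (x, train) remain to be passed
    module_base = MyCNNBlock.partial(use_bn=use_bn, n_filters=64)
    module = MySiameseNetTemplate.partial(train=True, module=module_base,
                                          shared=shared)

    # nn.stateful collects the BatchNorm statistics created during init;
    # nn.stochastic provides an RNG for stochastic layers, if any
    with nn.stateful() as init_state:
        with nn.stochastic(prng_key):
            _, initial_params = module.init_by_shape(
                prng_key, [(input_shape, model_dtype)])
            model = nn.Model(module, initial_params)

    return model, init_state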

Another template I use is the following sequence template:

class MySequenceNetTemplate(nn.Module):

    def apply(self, x, list_of_modules=None, train=False):

        # apply each nn.Module subclass in order, threading `train` through
        for m in list_of_modules:
            assert issubclass(m, nn.Module)
            x = m(x, train=train)

        return x
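Both templates only work because `.partial` leaves every block with the same call signature, `(x, train=...)`, so the template never needs to know a block's hyperparameters. The sketch below shows the same idea in plain Python, using the standard library's `functools.partial` as a swapped-in stand-in for Flax's `Module.partial`; `scale_block`, `shift_block` and `sequence_template` are hypothetical.

```python
import functools


def scale_block(x, factor=1.0, train=False):
    """Hypothetical block with its own hyperparameter `factor`."""
    return [factor * v for v in x]


def shift_block(x, offset=0.0, train=False):
    """Hypothetical block that only shifts in training mode, to show `train` is threaded."""
    return [v + offset for v in x] if train else list(x)


def sequence_template(x, blocks, train=False):
    """Apply blocks in order; each one must accept (x, train=...)."""
    for block in blocks:
        x = block(x, train=train)
    return x


# Binding hyperparameters up front normalizes every block to (x, train=...).
blocks = [functools.partial(scale_block, factor=3.0),
          functools.partial(shift_block, offset=1.0)]

out_train = sequence_template([1.0, 2.0], blocks, train=True)
out_eval = sequence_template([1.0, 2.0], blocks, train=False)
```

Here `out_train` is `[4.0, 7.0]` and `out_eval` is `[3.0, 6.0]`: the template passes only `train`, while each block keeps its own bound hyperparameters.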

What about serialization / export of the network structure / hyperparameters?

To serialize the network structure, which I want to dump alongside the optimizer state checkpoints, I wrote a lightweight factory template. Basically, it lets me fall back on regular "OOP" for serialization purposes.

The basic objective is to serialize only plain Python objects to YAML. This is of primary importance to me: I generate production networks, so I need to track the configuration in git, and I need to avoid complex-type issues that jeopardize the ability to refactor the code base / API (for instance, pickled objects can no longer be loaded once their classes are renamed).

Note that I would certainly not expect Flax to handle this part. I read the discussion regarding hyperparameters https://github.com/google/flax/discussions/194 with interest, and I agree that KISS is the way to go. Depending on their use cases, each end user will end up with their own way of handling the hyperparameter structure. Again, the sole reason for the code that follows is to fuel the discussion.

The abstract class looks like this:

class MyModuleDescription:

    def create_partial_module(self, train=False, name=None):
        # should return an nn.Module subclass with hyperparameters bound via .partial
        raise NotImplementedError

    def export_to_dict(self):
        # should return a dict of simple Python types
        raise NotImplementedError

    def export_to_yaml(self):
        # use marshmallow for instance (not my case, I will not elaborate on this)
        raise NotImplementedError

And two examples of concrete classes:

import yaml  # PyYAML, assumed here for the YAML export


class MyCNNBlockDescription(MyModuleDescription):

    def __init__(self,
                 n_filters=3,
                 kernel_size=3,
                 use_bn=True,
                 name=None):

        super().__init__()

        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.use_bn = use_bn
        self.name = name

    def create_partial_module(self, train=False, name=None):
        if name is None:
            name = self.name

        return MyCNNBlock.partial(n_filters=self.n_filters,
                                  kernel_size=self.kernel_size,
                                  use_bn=self.use_bn,
                                  train=train,
                                  name=name)

    def export_to_dict(self):
        # simple: every attribute is already a plain Python type
        return dict(self.__dict__)

    def export_to_yaml(self):
        # quite obvious here, only leaves
        return yaml.safe_dump(self.export_to_dict())

And for the sequence template:

import yaml


class MySequenceDescription(MyModuleDescription):

    def __init__(self, list_descriptions=None, name=None):

        super().__init__()

        list_descriptions = list(list_descriptions or [])
        for f in list_descriptions:
            assert isinstance(f, MyModuleDescription)

        self.list_descriptions = list_descriptions
        self.name = name

    def create_partial_module(self, train=False, name=None):
        if name is None:
            name = self.name

        list_m = []
        for f in self.list_descriptions:
            assert isinstance(f, MyModuleDescription)
            list_m.append(f.create_partial_module(train=train))

        return MySequenceNetTemplate.partial(list_of_modules=list_m,
                                             train=train, name=name)

    def export_to_dict(self):
        # nested: recurse into each child description
        return {'list_descriptions': [f.export_to_dict()
                                      for f in self.list_descriptions],
                'name': self.name}

    def export_to_yaml(self):
        # recursive call to export_to_dict(), then dump the plain types
        return yaml.safe_dump(self.export_to_dict())

Conclusion

I am eager to get some comments on how to make all of this more Flax-like. Any comments or criticism on how the core Flax devs would have written this would be greatly appreciated! Again, congrats to all the XLA, JAX and Flax developers, your work is great.

mortner31 commented 4 years ago

Closing this, as it is outdated; I will try the Linen API first.