Hello guys, first of all, thanks for your awesome work. As a former TF 0.6 user (a long time ago), I find Jax/Flax really great. No more hidden magic such as "compile, fit, close your eyes".
Now, the thing is that I use siamese nets and triplet losses a lot, every day. And since I come from a C/C++ background, I usually try to template things.
I went through the discussion in https://github.com/google/flax/issues/208 (major evolution of the API). I think it makes sense to get rid of the mix of classes, subclasses, class instances and objects, since the gotchas around `init`, `call`, `shared` and `partial` are not obvious at first glance, despite the nice documentation.
The purpose of what follows is just to present the way a twisted old mind like mine uses Flax. Maybe it can fuel the discussion, become an advanced example in the documentation, or serve as an example of "things you don't want to do with Flax". I think the template issue is a common one, and I would really appreciate advice on the most canonical way to do this in Flax.
Example of code
I start with a simple layer:
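```python
import jax.numpy as jnp
from flax import nn  # pre-Linen Flax API (flax.nn)

class MyCNNBlock(nn.Module):
    def apply(self, x, n_filters=32, kernel_size=3, use_bn=True,
              dtype=jnp.float32, train=False):
        # 2D convolution, stride 1, 'SAME' padding (NHWC inputs)
        x = nn.Conv(x, n_filters, (kernel_size, kernel_size), (1, 1),
                    padding='SAME', bias=True, dtype=dtype)
        # use_bn is needed because MyCNNBlockDescription passes it
        if use_bn:
            x = nn.BatchNorm(x, use_running_average=not train,
                             momentum=0.9, epsilon=1e-5, dtype=dtype)
        x = jnp.maximum(x, 0)  # ReLU
        return x
```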
My template code looks like this. Note that one of the keyword arguments of the `apply` function is an `nn.Module` subclass.
```python
class MySiameseNetTemplate(nn.Module):
    def apply(self, x, module=None, train=False, shared=True):
        assert issubclass(module, nn.Module)
        # cut the input in two along the channel axis (NHWC)
        x1, x2 = jnp.split(x, 2, axis=3)
        if shared:
            # a single parameter set is reused by both branches
            base = module.shared(train=train)
            y1 = base(x1)
            y2 = base(x2)
        else:
            # each call creates its own parameters
            y1 = module(x1, train=train)
            y2 = module(x2, train=train)
        return y1, y2
```
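To make the difference between the two branches of `shared` concrete without depending on a particular Flax version, here is a framework-free analogy in plain Python. It is a sketch under my own assumptions, not how Flax implements sharing: the "parameters" are just a scale factor stored in a dict, and `make_params`, `apply_block` and `siamese` are hypothetical names.

```python
def make_params(seed):
    # hypothetical "init": a different seed gives a different parameter set
    return {"scale": float(seed)}


def apply_block(params, x):
    # hypothetical "layer": multiply every element by the learned scale
    return [params["scale"] * v for v in x]


def siamese(x, shared=True):
    # cut the input in two halves, mirroring jnp.split(x, 2, ...)
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    if shared:
        p1 = p2 = make_params(1)               # one parameter set, two branches
    else:
        p1, p2 = make_params(1), make_params(2)  # independent parameter sets
    return (apply_block(p1, x1), apply_block(p2, x2)), (p1, p2)
```

With `shared=True` both branches hold the very same parameter object, so identical halves give identical embeddings, which is what a siamese net with a triplet loss relies on.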
Then, when I create an instance of the template module, I use `partial`, since it allows normalizing the signature.
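The signature-normalization idea can be illustrated with the standard library's `functools.partial`. This is only an analogy: Flax's `Module.partial` returned a new `Module` subclass rather than a `functools.partial` object. `cnn_block` and `run_template` are hypothetical stand-ins for a layer and a template.

```python
import functools


def cnn_block(x, n_filters=32, kernel_size=3, train=False):
    # hypothetical stand-in for a layer: record the config it was called with
    return x + [("cnn", n_filters, kernel_size, train)]


def run_template(x, module, train=False):
    # the template only relies on the normalized signature module(x, train=...)
    return module(x, train=train)


# bind the hyperparameters up front; what remains is module(x, train=...)
small = functools.partial(cnn_block, n_filters=8, kernel_size=1)
large = functools.partial(cnn_block, n_filters=64, kernel_size=5)
```

Both `small` and `large` can be handed to the same template, which never needs to know their hyperparameters.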
Another template I use is the following sequence template:
```python
class MySequenceNetTemplate(nn.Module):
    def apply(self, x, list_of_modules=None, train=False):
        # apply each module in order, forwarding the train flag
        for m in list_of_modules:
            assert issubclass(m, nn.Module)
            x = m(x, train=train)
        return x
```
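The same loop works for any list of callables that share the normalized `(x, train=...)` signature. The framework-free sketch below uses hypothetical toy layers `scale` and `shift` (plain Python, not Flax) to show that the `train` flag is forwarded to every element of the sequence.

```python
import functools


def scale(x, factor=1.0, train=False):
    # hypothetical layer: multiply; the train flag is accepted but unused
    return [factor * v for v in x]


def shift(x, offset=0.0, train=False):
    # hypothetical layer: add an offset only at training time (toy behaviour)
    return [v + offset for v in x] if train else list(x)


def sequence(x, list_of_modules, train=False):
    # same loop as MySequenceNetTemplate: forward the train flag to each module
    for m in list_of_modules:
        x = m(x, train=train)
    return x


layers = [functools.partial(scale, factor=2.0),
          functools.partial(shift, offset=1.0)]
```

Calling `sequence([1.0, 2.0], layers, train=True)` scales then shifts, while `train=False` skips the toy shift.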
What about serialization / export of the network structure / hparams?
To serialize the network structure, since I want to dump it alongside the optimizer state checkpoints, I wrote a lightweight factory template. Basically, it allows falling back on regular "OOP" for serialization purposes.
The basic objective is to rely only on serializing pure Python objects to YAML. This is of primary importance to me because I generate production networks: I need to track the configuration in git, and I need to avoid complex-type issues that jeopardize the ability to refactor the code base / API (for instance, pickled objects can no longer be loaded once their classes are renamed).
Note that I would certainly not expect Flax to handle this part. I read the discussion regarding hparams (https://github.com/google/flax/discussions/194) with interest, and I do agree that KISS is the way to go. Depending on their use cases, each end user will end up with their own way of handling the hparams structure. Again, the sole reason for the code that follows is to fuel the discussion.
The abstract class looks like this:
```python
class MyModuleDescription:
    def create_partial_module(self, train=False, name=None):
        raise NotImplementedError

    def export_to_dict(self):
        # should return a dict of simple Python types
        raise NotImplementedError

    def export_to_yaml(self):
        # use marshmallow, for instance (not my case, I will not elaborate on this)
        raise NotImplementedError
```
And two examples of concrete classes:
```python
import yaml  # PyYAML (assumed available)


class MyCNNBlockDescription(MyModuleDescription):
    def __init__(self,
                 n_filters=3,
                 kernel_size=3,
                 use_bn=True,
                 name=None):
        super().__init__()
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.use_bn = use_bn
        self.name = name

    def create_partial_module(self, train=False, name=None):
        if name is None:
            name = self.name
        return MyCNNBlock.partial(n_filters=self.n_filters,
                                  kernel_size=self.kernel_size,
                                  use_bn=self.use_bn,
                                  train=train,
                                  name=name)

    def export_to_dict(self):
        # simple: every attribute is already a plain Python value
        return dict(self.__dict__)

    def export_to_yaml(self):
        # quite obvious here, only leaves
        return yaml.safe_dump(self.export_to_dict())
```
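A quick way to check that `export_to_dict` really returns only plain types is a round trip through a text format. The sketch below uses the standard-library `json` module instead of YAML so it runs without extra dependencies; `CNNBlockDescriptionSketch` and its `from_dict` method are hypothetical, a trimmed-down copy of `MyCNNBlockDescription` without the Flax part.

```python
import json


class CNNBlockDescriptionSketch:
    # hypothetical copy of MyCNNBlockDescription, without create_partial_module
    def __init__(self, n_filters=3, kernel_size=3, use_bn=True, name=None):
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.use_bn = use_bn
        self.name = name

    def export_to_dict(self):
        # copy so that callers cannot mutate the description through the dict
        return dict(self.__dict__)

    @classmethod
    def from_dict(cls, d):
        # only works because every key matches an __init__ argument
        return cls(**d)


desc = CNNBlockDescriptionSketch(n_filters=16, kernel_size=5, name="block0")
text = json.dumps(desc.export_to_dict(), sort_keys=True)
restored = CNNBlockDescriptionSketch.from_dict(json.loads(text))
```

If a non-serializable value (a class, a function, a JAX array) ever sneaks into the description, `json.dumps` fails immediately, which is a cheap guard for the "pure Python types only" rule.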
And for the sequence template:
```python
class MySequenceDescription(MyModuleDescription):
    def __init__(self, list_descriptions=None, name=None):
        super().__init__()
        list_descriptions = list_descriptions or []
        for f in list_descriptions:
            assert isinstance(f, MyModuleDescription)
        self.list_descriptions = list_descriptions
        self.name = name

    def create_partial_module(self, train=False, name=None):
        if name is None:
            name = self.name
        list_m = []
        for f in self.list_descriptions:
            assert isinstance(f, MyModuleDescription)
            list_m.append(f.create_partial_module(train=train))
        return MySequenceNetTemplate.partial(list_of_modules=list_m,
                                             train=train, name=name)

    def export_to_dict(self):
        # nested: recurse into the child descriptions
        return {'name': self.name,
                'list_descriptions': [f.export_to_dict()
                                      for f in self.list_descriptions]}

    def export_to_yaml(self):
        # recursive call to export_to_dict(), then dump plain types
        return yaml.safe_dump(self.export_to_dict())
```
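To rebuild a nested description from its exported dict, each node needs some type information. One option, sketched here under my own assumptions, is to store a stable string tag and look it up in a registry, so that Python classes can be renamed without invalidating configs already stored in git (the refactoring concern raised above about pickling). `REGISTRY`, `register`, `Block`, `Sequence`, `to_dict`, `from_dict` and the `"type"` key are all hypothetical names, and JSON again stands in for YAML.

```python
import json

REGISTRY = {}


def register(tag):
    # map a stable string tag to a class; the tag, not the class name, is stored
    def decorator(cls):
        REGISTRY[tag] = cls
        cls.TAG = tag
        return cls
    return decorator


@register("cnn_block")
class Block:
    def __init__(self, n_filters=3, name=None):
        self.n_filters = n_filters
        self.name = name

    def to_dict(self):
        return {"type": self.TAG, "n_filters": self.n_filters, "name": self.name}


@register("sequence")
class Sequence:
    def __init__(self, children=None, name=None):
        self.children = children or []
        self.name = name

    def to_dict(self):
        # recursive export: each child serializes itself
        return {"type": self.TAG, "name": self.name,
                "children": [c.to_dict() for c in self.children]}


def from_dict(d):
    # dispatch on the stored tag and rebuild nested children recursively
    d = dict(d)
    cls = REGISTRY[d.pop("type")]
    if "children" in d:
        d["children"] = [from_dict(c) for c in d["children"]]
    return cls(**d)


net = Sequence([Block(8, "b0"), Block(16, "b1")], name="trunk")
text = json.dumps(net.to_dict())
rebuilt = from_dict(json.loads(text))
```

Renaming `Block` to something else later only requires keeping the `"cnn_block"` tag, so old config files keep loading.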
Conclusion
I am eager to get some comments on how to make all of this more flaxy. Any comments or criticism on how core Flax devs would have written this would be greatly appreciated! Again, congrats to all the XLA, Jax and Flax developers, your work is great.