Hi @phoenix-meadowlark, sorry for the late reply! I'm personally fine with adding them if they make using Flax easier and don't complicate our API, which in this case doesn't seem to be an issue.
I'd also like to hear what @jheek thinks of this.
Currently we only define something as a Module if it actually has variables or rngs. So in this case flatten, global_avg_pool, and zero_pad would become functions rather than Modules (which makes them more widely applicable). They very closely match existing NumPy functions. So far our philosophy has been to write code that looks like NumPy rather than code that looks like PyTorch.
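For illustration, here is a minimal, non-authoritative sketch of what those three could look like as plain `jax.numpy` functions. The names `flatten`, `global_avg_pool`, and `zero_pad`, as well as the NHWC layout assumption, are mine and not a proposed Flax API:

```python
import jax.numpy as jnp

def flatten(x, start_dim=1, end_dim=-1):
    # Collapse dims [start_dim, end_dim] into one, like torch.nn.Flatten.
    end_dim = end_dim % x.ndim
    new_shape = x.shape[:start_dim] + (-1,) + x.shape[end_dim + 1:]
    return jnp.reshape(x, new_shape)

def global_avg_pool(x):
    # Average over all spatial dims of an N...C array (e.g. NHWC -> NC).
    return jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))

def zero_pad(x, padding=((1, 1), (1, 1))):
    # Zero-pad only the spatial dims; batch and feature dims are untouched.
    return jnp.pad(x, ((0, 0),) + tuple(padding) + ((0, 0),))
```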
For DepthwiseConv I'm not quite sure. Using feature_group_count is a nice parameterization because it extends to grouped convolutions as well. Again, I think this depends mostly on what people are used to. An alternative would be to add an explicit example to the Conv docstring showing how to do depthwise (and/or grouped) convolutions.
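As a rough sketch of what such a docstring example could show (the channel count and input shape below are purely illustrative), a depthwise convolution can be expressed with the existing `flax.linen.Conv` by setting `feature_group_count` to the number of input channels:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

features = 32  # number of input channels (illustrative)
x = jnp.ones((1, 28, 28, features))  # NHWC input

# One group per input channel, so each output channel only sees its
# corresponding input channel: a depthwise convolution.
depthwise = nn.Conv(features=features, kernel_size=(3, 3),
                    feature_group_count=features)
params = depthwise.init(jax.random.PRNGKey(0), x)
y = depthwise.apply(params, x)  # shape (1, 28, 28, 32)
```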
Closing this due to inactivity.
I had to translate some TensorFlow models into JAX and found the following layers didn't exist:

- `nn.flatten`: In TensorFlow as `tf.keras.layers.Flatten` and in Torch as `nn.Flatten`. Could also be implemented with `start_dim` and `end_dim` as in Torch.
- `nn.global_avg_pool`: In TensorFlow as `tf.keras.layers.GlobalAveragePoolingND` (and not in Torch).
- `nn.ZeroPad`: In TensorFlow as `tf.keras.layers.ZeroPaddingND` and in Torch as `nn.ZeroPad2d`.
- `nn.DepthwiseConv`: In TensorFlow as `tf.keras.layers.DepthwiseConv2d` and in a note in the Torch `nn.Conv` documentation. Could also be 'implemented' as a note in the documentation, but an actual implementation is quite simple as well.

Would be happy to add any of these and/or the architecture implementations of the MobileNet models as examples (a sketch of such a block follows below).