google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Expanding flax.linen with common ML primitives. #1205

Closed: phoenix-meadowlark closed this issue 2 years ago

phoenix-meadowlark commented 3 years ago

I had to translate some TensorFlow models into JAX and found that the following layers were missing from flax.linen:

from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union

import jax.numpy as jnp
from flax import linen as nn

# Type aliases assumed by the snippets below.
Array = jnp.ndarray
Dtype = Any
PRNGKey = Any
Shape = Iterable[int]


def flatten(inputs: Array) -> Array:
  """Flattens 'inputs', preserving the batch dim."""
  return inputs.reshape((inputs.shape[0], -1))


def global_avg_pool(inputs: Array) -> Array:
  """Averages over the spatial dims of 'inputs'."""
  return jnp.mean(inputs, axis=list(range(1, len(inputs.shape) - 1)))


class ZeroPad(nn.Module):
  padding: Sequence[Union[int, Sequence[int]]]

  @nn.compact
  def __call__(self, inputs: Array) -> Array:
    # If an int is given for a dimension's padding, apply it symmetrically.
    padding = [[p, p] if isinstance(p, int) else p for p in self.padding]
    # Don't pad the batch dim or the features dim.
    padding = [[0, 0]] + padding + [[0, 0]]
    return jnp.pad(inputs, padding)


# Not technically a subclass of nn.Conv since it shouldn't have the 'features'
# or 'feature_group_count' attrs, but very close.
class DepthwiseConv(nn.Module):
  depth_multiplier: int
  kernel_size: Union[int, Iterable[int]]
  strides: Optional[Iterable[int]] = nn.Conv.strides
  padding: Union[str, Iterable[Tuple[int, int]]] = nn.Conv.padding
  input_dilation: Optional[Iterable[int]] = nn.Conv.input_dilation
  kernel_dilation: Optional[Iterable[int]] = nn.Conv.kernel_dilation
  use_bias: bool = nn.Conv.use_bias
  dtype: Dtype = nn.Conv.dtype
  precision: Any = nn.Conv.precision
  kernel_init: Callable[[PRNGKey, Shape, Any], Array] = nn.Conv.kernel_init
  bias_init: Callable[[PRNGKey, Shape, Any], Array] = nn.Conv.bias_init

  @nn.compact
  def __call__(self, inputs: Array) -> Array:
    feature_group_count = inputs.shape[-1]
    features = int(self.depth_multiplier * feature_group_count)
    return nn.Conv(
        features=features,
        kernel_size=self.kernel_size,
        strides=self.strides,
        padding=self.padding,
        input_dilation=self.input_dilation,
        kernel_dilation=self.kernel_dilation,
        feature_group_count=feature_group_count,
        use_bias=self.use_bias,
        dtype=self.dtype,
        precision=self.precision,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init)(inputs)
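
A minimal usage sketch of the DepthwiseConv above (the input shape and depth_multiplier are arbitrary, chosen only for illustration):

import jax

x = jnp.ones((1, 32, 32, 3))  # NHWC input with 3 channels
layer = DepthwiseConv(depth_multiplier=2, kernel_size=(3, 3))
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)  # shape (1, 32, 32, 6) with the default 'SAME' padding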

I'd be happy to add any of these, and/or implementations of the MobileNet architectures as examples.

marcvanzee commented 3 years ago

Hi @phoenix-meadowlark, sorry for the late reply! I'm personally fine with adding them as long as they make using Flax easier and don't complicate our API, which doesn't seem to be a concern here.

I'd also like to hear what @jheek thinks of this.

jheek commented 3 years ago

Currently we only define something as a Module if it actually has variables or rngs. So in this case flatten, global_avg_pool, and zero_pad would become functions rather than Modules (which makes them more widely applicable), and they very closely match existing NumPy functions. So far our philosophy has been to write code that looks like NumPy rather than code that looks like PyTorch.
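
As a rough illustration of that point (a sketch only, with illustrative shapes), all three candidates reduce to one-line jnp expressions:

import jax.numpy as jnp

x = jnp.ones((8, 28, 28, 3))  # illustrative NHWC batch

flat = x.reshape((x.shape[0], -1))                     # flatten, keeping the batch dim
pooled = jnp.mean(x, axis=(1, 2))                      # global average pool over spatial dims
padded = jnp.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]])  # zero-pad the spatial dims only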

For DepthwiseConv I'm not quite sure. Using feature_group_count is a nice parameterization because it extends to grouped convolutions as well. Again, I think this depends mostly on what people are used to. An alternative would be to add an explicit example to the Conv docstring showing how to do depthwise (and/or grouped) convolutions.
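
For instance, a docstring example along these lines (a sketch with illustrative shapes, not a settled API) would cover the depthwise case:

import jax
import jax.numpy as jnp
from flax import linen as nn

x = jnp.ones((1, 32, 32, 3))  # NHWC input with 3 channels

# A depthwise convolution is a grouped convolution with one group per input
# channel, producing depth_multiplier output features per channel.
depth_multiplier = 2
conv = nn.Conv(
    features=depth_multiplier * x.shape[-1],
    kernel_size=(3, 3),
    feature_group_count=x.shape[-1])
params = conv.init(jax.random.PRNGKey(0), x)
y = conv.apply(params, x)  # shape (1, 32, 32, 6)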

marcvanzee commented 2 years ago

Closing this due to inactivity.