google-deepmind / dm-haiku

JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0

`hk.max_pool` window misalignment #125

Open SimonBiggs opened 3 years ago

SimonBiggs commented 3 years ago

When a window is defined as such:

x = hk.max_pool(
    value=x,
    window_shape=(2, 2),
    strides=(2, 2),
    padding="VALID",
    channel_axis=-1,
)

The window gets misaligned. Instead of covering the two spatial axes, it covers the last two axes, so max-pooling is applied along the channel axis. I believe the issue is in the following lines of code:

https://github.com/deepmind/dm-haiku/blob/7964d01f1c0dd907c8ea016ad1d1cc7ae48ac05d/haiku/_src/pool.py#L46-L47

Calling that helper directly gives the following result:

>>> import numpy as np
>>> from haiku._src.pool import _infer_shape

>>> shape = (2, 16, 16, 5)
>>> x = np.ones(shape)
>>> window_size = (2, 2)
>>> channel_axis = -1

>>> _infer_shape(x, window_size, channel_axis)
(1, 1, 2, 2)
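To see why the window ends up on the channel axis, here is a minimal shape-only sketch of how a sequence shorter than `x.ndim` is treated: it is left-padded with 1s, so the given sizes align with the trailing axes. `pad_window` and `pooled_axes` are hypothetical helper names, not part of Haiku. The padding rule is assumed to mirror the linked `_infer_shape` lines.

```python
from typing import Sequence, Tuple


def pad_window(ndim: int, size: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: left-pad a window/stride sequence with 1s up to ndim."""
    if len(size) > ndim:
        raise ValueError(f"window {tuple(size)} has more axes than the input ({ndim})")
    # Missing leading axes get size 1, so the given sizes align with the TRAILING axes.
    return (1,) * (ndim - len(size)) + tuple(size)


def pooled_axes(window: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: indices of the axes the window pools over (size > 1)."""
    return tuple(i for i, w in enumerate(window) if w > 1)


# NHWC input of shape (2, 16, 16, 5) with window_shape=(2, 2):
window = pad_window(4, (2, 2))
print(window)               # (1, 1, 2, 2)
print(pooled_axes(window))  # (2, 3): the W axis and the channel axis, not H and W
```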

Cheers 🙂, Simon

khdlr commented 3 years ago

Wow, this issue has been around for half a year? 😅

I just stumbled across the same issue; the following behaviour is super weird:

import jax.numpy as jnp
import haiku as hk

a = jnp.zeros([1, 8, 8, 8])
pool1 = hk.max_pool(a, 2, 2, 'SAME')            # => (1, 4, 4, 8)
pool2 = hk.max_pool(a, [2, 2], [2, 2], 'SAME')  # => (1, 8, 4, 4)
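The two shapes follow directly from how the stride arguments are expanded. This is a minimal shape-only sketch that assumes the standard `'SAME'` rule, where each axis has output length `ceil(n / stride)`, as in `jax.lax.reduce_window`. `same_output_shape` is a hypothetical helper. The expanded stride tuples are the ones `_infer_shape` produces for a scalar and for a 2-element list on a 4-D input.

```python
import math
from typing import Sequence, Tuple


def same_output_shape(shape: Sequence[int], strides: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of a 'SAME'-padded pool, ceil(n / s) per axis."""
    # Under SAME padding the window size only affects how much padding is added;
    # the output length per axis depends on the stride alone.
    return tuple(math.ceil(n / s) for n, s in zip(shape, strides))


shape = (1, 8, 8, 8)
# Scalar strides=2 with the default channel_axis=-1 expands to (1, 2, 2, 1).
print(same_output_shape(shape, (1, 2, 2, 1)))  # (1, 4, 4, 8)
# strides=[2, 2] is left-padded to (1, 1, 2, 2), so the channel axis gets strided.
print(same_output_shape(shape, (1, 1, 2, 2)))  # (1, 8, 4, 4)
```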

khdlr commented 3 years ago

Welp, I guess the docs do indeed specify that channel_axis is only used when window_shape/strides are given as an integer. It's still super confusing, but it is documented.

mathpluscode commented 1 year ago

I ran into the same confusion; it comes from _infer_shape.

Description

hk.MaxPool calls hk.max_pool, which in turn calls _infer_shape.

However, the MaxPool docstring describes the argument only as

channel_axis: Axis of the spatial channels for which pooling is skipped.

whereas the max_pool docstring says

channel_axis: Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.

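This means that if window_shape or strides are not ints, the channel_axis argument is ignored. A sequence shorter than x.ndim is left-padded with 1s, so it lines up with the trailing axes, and those include the channel axis.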

Reproduction

import jax.numpy as jnp
import haiku as hk

@hk.testing.transform_and_run()
def f(x):
    max_pool = hk.MaxPool(
        window_shape=(2, 2, 2),
        strides=(2, 2, 2),
        padding="VALID",
        channel_axis=-1,
    )
    return max_pool(x)

# x is (batch, d1, d2, d3, channels).
x = jnp.ones((2, 4, 6, 8, 3))
print(f(x).shape)

This prints

(2, 4, 3, 4, 1)
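The printed shape can be checked by hand with the usual `'VALID'` rule, `out = (n - w) // s + 1` per axis. `valid_output_shape` below is a hypothetical helper. The first window is what `_infer_shape` makes of `(2, 2, 2)` for a 5-D input. The second is the full-rank window that was actually intended.

```python
from typing import Sequence, Tuple


def valid_output_shape(
    shape: Sequence[int], window: Sequence[int], strides: Sequence[int]
) -> Tuple[int, ...]:
    """Hypothetical helper: 'VALID' pooling output shape, assuming n >= w on every axis."""
    return tuple((n - w) // s + 1 for n, w, s in zip(shape, window, strides))


x_shape = (2, 4, 6, 8, 3)
# (2, 2, 2) left-padded to 5 axes pools over the last three axes, channels included.
print(valid_output_shape(x_shape, (1, 1, 2, 2, 2), (1, 1, 2, 2, 2)))  # (2, 4, 3, 4, 1)
# Full-rank window that skips the batch and channel axes.
print(valid_output_shape(x_shape, (1, 2, 2, 2, 1), (1, 2, 2, 2, 1)))  # (2, 2, 3, 4, 3)
```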

To get the intended shape, we need to pass either the full shape (1, 2, 2, 2, 1) or the shape without the batch axis (2, 2, 2, 1). The docstring does not make this clear.
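As a user-side workaround, one can build the full-rank tuple explicitly from the spatial sizes and `channel_axis`. A full-rank tuple is then used as-is by `_infer_shape`. This is a minimal sketch under that assumption. `full_rank_window` is a hypothetical helper, not a Haiku API, and it assumes axis 0 is the only batch axis.

```python
from typing import Sequence, Tuple


def full_rank_window(
    ndim: int, spatial: Sequence[int], channel_axis: int = -1
) -> Tuple[int, ...]:
    """Hypothetical helper: put `spatial` on every axis except batch (0) and channel_axis."""
    if not -ndim <= channel_axis < ndim:
        raise ValueError(f"invalid channel_axis {channel_axis} for ndim={ndim}")
    channel_axis %= ndim  # normalise a negative index, e.g. -1 -> ndim - 1
    spatial_axes = [d for d in range(1, ndim) if d != channel_axis]
    if len(spatial) != len(spatial_axes):
        raise ValueError(f"expected {len(spatial_axes)} spatial sizes, got {len(spatial)}")
    window = [1] * ndim
    for axis, size in zip(spatial_axes, spatial):
        window[axis] = size
    return tuple(window)


print(full_rank_window(5, (2, 2, 2)))               # (1, 2, 2, 2, 1)
print(full_rank_window(4, (2, 2), channel_axis=1))  # (1, 1, 2, 2)
```

The result could be passed as both `window_shape` and `strides` (for example, to `hk.MaxPool` in the reproduction) so that the pooling axes no longer depend on how `channel_axis` is interpreted.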

The current _infer_shape behaves as follows:

from typing import Union, Optional, Sequence, Tuple
import jax.numpy as jnp

# _infer_shape as reproduced from haiku/_src/pool.py.
def _infer_shape(
    x: jnp.ndarray,
    size: Union[int, Sequence[int]],
    channel_axis: Optional[int] = -1,
) -> Tuple[int, ...]:
    """Infer shape for pooling window or strides."""
    if isinstance(size, int):
        # Integer size: use it on every non-batch axis except channel_axis.
        if channel_axis and not 0 <= abs(channel_axis) < x.ndim:
            raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
        if channel_axis and channel_axis < 0:
            channel_axis = x.ndim + channel_axis
        return (1,) + tuple(size if d != channel_axis else 1
                            for d in range(1, x.ndim))
    elif len(size) < x.ndim:
        # Assume additional dimensions are batch dimensions.
        # channel_axis is ignored here: the sizes are left-padded with 1s,
        # so they align with the trailing axes.
        return (1,) * (x.ndim - len(size)) + tuple(size)
    else:
        # Full-rank sequence: returned as-is, channel_axis is ignored.
        assert x.ndim == len(size)
        return tuple(size)

x = jnp.ones((2, 4, 6, 8, 3))
print(_infer_shape(x, size=(2, 2, 2)))
print(_infer_shape(x, size=(2, 2, 2, 1)))
print(_infer_shape(x, size=(1, 2, 2, 2, 1)))

This prints

(1, 1, 2, 2, 2)
(1, 2, 2, 2, 1)
(1, 2, 2, 2, 1)

Suggestion

We should at least update the docstring. Otherwise, the following two situations should not happen together