Open SimonBiggs opened 3 years ago
Wow this issue has been around for half a year? :sweat_smile:
I just stumbled across the same issue. The following behaviour is super weird:
```python
import haiku as hk
import jax.numpy as jnp

a = jnp.zeros([1, 8, 8, 8])
pool1 = hk.max_pool(a, 2, 2, 'SAME')  # => (1, 4, 4, 8)
pool2 = hk.max_pool(a, [2, 2], [2, 2], 'SAME')  # => (1, 8, 4, 4)
```
Welp, I guess the docs do indeed specify that `channel_axis` is only used if `window_shape`/`strides` are an integer. It's still super confusing, but actually documented.
I ran into the same confusion. It comes from `_infer_shape`: `MaxPool` calls `max_pool`, which eventually calls `_infer_shape`.

But in `MaxPool` the argument description is only

> channel_axis: Axis of the spatial channels for which pooling is skipped.

while in `max_pool` the description is

> channel_axis: Axis of the spatial channels for which pooling is skipped, used to infer `window_shape` or `strides` if they are an integer.

This means that if `window_shape` or `strides` are not ints, the `channel_axis` argument is ignored.
```python
import jax.numpy as jnp
import haiku as hk


@hk.testing.transform_and_run()
def f(x):
  max_pool = hk.MaxPool(
      window_shape=(2, 2, 2),
      strides=(2, 2, 2),
      padding="VALID",
      channel_axis=-1,
  )
  return max_pool(x)


x = jnp.ones((2, 4, 6, 8, 3))
print(f(x).shape)
```
This prints `(2, 4, 3, 4, 1)`.
To get the shape right, we need to pass either the full shape `(1, 2, 2, 2, 1)` or a shape that omits only the batch axis, `(2, 2, 2, 1)`. The docstring does not make this clear.
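As a user-side workaround, the full-rank window can be built from the spatial sizes before it is handed to `hk.MaxPool`. The sketch below is only an illustration: `full_window` is a hypothetical helper, not part of Haiku, and it assumes axis 0 is the batch axis and that exactly one channel axis exists.

```python
from typing import Sequence, Tuple


def full_window(
    spatial: Sequence[int],
    ndim: int,
    channel_axis: int = -1,
) -> Tuple[int, ...]:
  """Expands spatial-only sizes to one entry per array axis.

  Hypothetical helper (not a Haiku API). Assumes axis 0 is the batch axis.
  """
  channel_axis = channel_axis % ndim  # Normalise negative axes.
  if channel_axis == 0:
    raise ValueError("channel_axis must not be the batch axis (0)")
  if len(spatial) != ndim - 2:
    raise ValueError(
        f"Expected {ndim - 2} spatial sizes, got {len(spatial)}")
  sizes = iter(spatial)
  # Batch and channel axes get size 1, so pooling skips them.
  return tuple(
      1 if axis in (0, channel_axis) else next(sizes)
      for axis in range(ndim))


print(full_window((2, 2, 2), ndim=5))                  # (1, 2, 2, 2, 1)
print(full_window((2, 2, 2), ndim=5, channel_axis=1))  # (1, 1, 2, 2, 2)
```

The result can then be passed as both `window_shape` and `strides`, which sidesteps the implicit left-padding with batch dimensions.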
The current `_infer_shape` behaves as follows:

```python
from typing import Union, Optional, Sequence, Tuple

import jax.numpy as jnp


def _infer_shape(
    x: jnp.ndarray,
    size: Union[int, Sequence[int]],
    channel_axis: Optional[int] = -1,
) -> Tuple[int, ...]:
  """Infer shape for pooling window or strides."""
  if isinstance(size, int):
    if channel_axis and not 0 <= abs(channel_axis) < x.ndim:
      raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
    if channel_axis and channel_axis < 0:
      channel_axis = x.ndim + channel_axis
    return (1,) + tuple(size if d != channel_axis else 1
                        for d in range(1, x.ndim))
  elif len(size) < x.ndim:
    # Assume additional dimensions are batch dimensions.
    return (1,) * (x.ndim - len(size)) + tuple(size)
  else:
    assert x.ndim == len(size)
    return tuple(size)


x = jnp.ones((2, 4, 6, 8, 3))
print(_infer_shape(x, size=(2,2,2)))
print(_infer_shape(x, size=(2,2,2,1)))
print(_infer_shape(x, size=(1,2,2,2,1)))
```
This prints

```
(1, 1, 2, 2, 2)
(1, 2, 2, 2, 1)
(1, 2, 2, 2, 1)
```
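To see why the `MaxPool` call above returns `(2, 4, 3, 4, 1)`, the inferred window can be plugged into the usual `VALID` pooling shape formula, `(n - w) // s + 1` per axis. This is a plain-Python sketch of that arithmetic, not Haiku code, and `valid_pool_shape` is a name made up for illustration.

```python
from typing import Sequence, Tuple


def valid_pool_shape(
    in_shape: Sequence[int],
    window: Sequence[int],
    strides: Sequence[int],
) -> Tuple[int, ...]:
  """Output shape of a VALID-padded pooling op (illustrative helper)."""
  if not len(in_shape) == len(window) == len(strides):
    raise ValueError("in_shape, window and strides must have equal length")
  return tuple((n - w) // s + 1 for n, w, s in zip(in_shape, window, strides))


in_shape = (2, 4, 6, 8, 3)
misaligned = (1, 1, 2, 2, 2)  # What (2, 2, 2) is expanded to.
intended = (1, 2, 2, 2, 1)    # Pool over the three spatial axes only.

print(valid_pool_shape(in_shape, misaligned, misaligned))  # (2, 4, 3, 4, 1)
print(valid_pool_shape(in_shape, intended, intended))      # (2, 2, 3, 4, 3)
```

With the misaligned window the channel axis (size 3) is pooled down to 1 and the first spatial axis (size 4) is left untouched, which matches the output reported above.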
We should at least update the docstring. Otherwise, the following two situations should not happen together
When a window is defined as such:
It gets misaligned and applies the max-pooling along the channel axis. I believe the issue is in the following line of code:
https://github.com/deepmind/dm-haiku/blob/7964d01f1c0dd907c8ea016ad1d1cc7ae48ac05d/haiku/_src/pool.py#L46-L47
Pulling that out gives the following result:
Cheers :slightly_smiling_face:, Simon