GEM layers create a general pooling layer in their constructor, but we didn't pass the layer's kwargs through to it. As a result, the inner pooling layer never received the GEM layer's dtype policy. This caused the GEM layers to fail under a mixed_float16 dtype policy: the inner pooling layer returned float32 while the GEM layer's compute dtype was float16.
The fix is to pass all kwargs on to the general pooling layer.
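A minimal sketch of the pattern, assuming a TensorFlow Keras layer. The class name GeneralizedMeanPooling2D, the p and eps parameters, and the choice of GlobalAveragePooling2D as the inner pooling layer are illustrative assumptions, not this project's actual code. The key line forwards **kwargs (including dtype) to the inner layer so both layers share one dtype policy.

```python
import tensorflow as tf


class GeneralizedMeanPooling2D(tf.keras.layers.Layer):
    """Hypothetical GeM layer: (mean(x ** p)) ** (1 / p) over the spatial axes."""

    def __init__(self, p=3.0, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.p = p
        self.eps = eps
        # Forward kwargs (e.g. dtype="mixed_float16") so the inner pooling
        # layer uses the same dtype policy as this layer. Without this, the
        # inner layer falls back to the global policy (float32 by default).
        self.pool = tf.keras.layers.GlobalAveragePooling2D(**kwargs)

    def call(self, inputs):
        # Clamp to a small positive value so the fractional power is defined.
        x = tf.pow(tf.maximum(inputs, self.eps), self.p)
        x = self.pool(x)
        return tf.pow(x, 1.0 / self.p)


layer = GeneralizedMeanPooling2D(dtype="mixed_float16")
out = layer(tf.ones((1, 4, 4, 3)))
print(out.dtype)
```

With the kwargs forwarded, the inner pooling layer computes in float16 like its parent, so the GEM output stays float16 instead of being silently upcast to float32.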
Note: We also pin the mypy version to avoid the new typing errors it reports. Those errors are addressed by the changes on the dev branch.