tensorflow / similarity

TensorFlow Similarity is a python package focused on making similarity learning quick and easy.
Apache License 2.0

#303 fix dtype policy bug in GEM layers. #304

Closed: owenvallis closed this pull request 1 year ago.

owenvallis commented 1 year ago

The GEM layers create a general pooling layer in `__init__`, but we didn't pass the `kwargs` through to it, so the general pooling layer didn't get the GEM layer's dtype policy. This caused the GEM layers to fail under a mixed-precision (`mixed_float16`) dtype policy: the general pooling layer returns float32, while the GEM layer's compute dtype is float16.

The fix passes all `kwargs` on to the general pooling layer, so it uses the same dtype policy as the GEM layer.
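A minimal sketch of the failure mode and the fix, assuming TensorFlow 2 with `tf.keras`. `GeMPooling2D` and its `forward_kwargs` flag are hypothetical names for illustration, not the library's API. Like the GEM layers described here, it computes a generalized mean, `(mean(x^p))^(1/p)`, and creates an inner `GlobalAveragePooling2D` in `__init__`. With `forward_kwargs=False`, the inner layer falls back to the global policy (float32 by default), so a layer built with `dtype="mixed_float16"` returns float32. With the kwargs forwarded, the output stays float16.

```python
import tensorflow as tf


class GeMPooling2D(tf.keras.layers.Layer):
    """Hypothetical, simplified GeM pooling layer (not the library's class).

    forward_kwargs=False reproduces the bug; the default True applies the fix.
    """

    def __init__(self, p=3.0, forward_kwargs=True, **kwargs):
        super().__init__(**kwargs)
        self.p = p
        # Forwarding kwargs (including `dtype`) gives the inner pooling layer
        # this layer's dtype policy. Without them, the inner layer uses the
        # global policy, which is float32 unless set otherwise.
        inner_kwargs = kwargs if forward_kwargs else {}
        self.gap = tf.keras.layers.GlobalAveragePooling2D(**inner_kwargs)

    def call(self, inputs):
        # Generalized mean: clamp to keep values positive, raise to p,
        # average over the spatial axes, then take the p-th root.
        x = tf.pow(tf.maximum(inputs, 1e-6), self.p)
        x = self.gap(x)
        return tf.pow(x, 1.0 / self.p)


x = tf.random.uniform((2, 4, 4, 3))
fixed = GeMPooling2D(dtype="mixed_float16")
buggy = GeMPooling2D(dtype="mixed_float16", forward_kwargs=False)
print(fixed(x).dtype.name)  # float16
print(buggy(x).dtype.name)  # float32
```

Forwarding every kwarg also forwards `name`, so the inner and outer layers would share a name if one is given. Passing only the dtype policy to the inner layer would be a narrower alternative.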

Note: We also pin the mypy version to suppress the new typing errors; these are addressed by the changes in the dev branch.