nss-ysasaki closed 2 years ago
False alarm. It was a bug in the weight conversion process. The weight shapes I showed above for the GitHub model were wrong; the actual shapes are consistent with the TF-Hub model:
blocks_1
blocks_1/efficientnetv2-m
blocks_1/efficientnetv2-m/blocks_1
blocks_1/efficientnetv2-m/blocks_1/conv2d
blocks_1/efficientnetv2-m/blocks_1/conv2d/kernel:0, (3, 3, 24, 24)
blocks_1/efficientnetv2-m/blocks_1/tpu_batch_normalization
blocks_1/efficientnetv2-m/blocks_1/tpu_batch_normalization/beta:0, (24,)
blocks_1/efficientnetv2-m/blocks_1/tpu_batch_normalization/gamma:0, (24,)
blocks_1/efficientnetv2-m/blocks_1/tpu_batch_normalization/moving_mean:0, (24,)
blocks_1/efficientnetv2-m/blocks_1/tpu_batch_normalization/moving_variance:0, (24,)
I have a model trained using a TF-Hub version of EffNetV2-M (https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_m/feature_vector/2), and I want to load the model with this repository's effnetv2_model.EffNetV2Model. I am doing this because I need the outputs of intermediate layers (for model explanation purposes), which the TF-Hub API does not provide access to.
Then I noticed that the weight shapes of the TF-Hub and GitHub models differ, so I am unable to transfer the model from TF-Hub to GitHub.
block_1 weights in TF-Hub:
What is causing this shape mismatch? Were the models published on TF-Hub trained with hyperparameters different from the ones this repository offers?
The code to extract shapes is (roughly) as follows:
import tensorflow as tf
import tensorflow_hub as hub

# CORE_LAYER_PATH, IMAGE_SIZE, CLASSES, and load_locally are defined elsewhere.
pretrained_model = hub.KerasLayer(
    CORE_LAYER_PATH,
    trainable=True,
    input_shape=[*IMAGE_SIZE, 3],
    load_options=load_locally)

model = tf.keras.Sequential([
    # Convert input images to float32 before they reach the backbone.
    tf.keras.layers.Lambda(
        lambda data: tf.image.convert_image_dtype(data, tf.float32),
        input_shape=[*IMAGE_SIZE, 3]),
    pretrained_model,
    tf.keras.layers.Dense(len(CLASSES), activation='softmax')
])

# Print weight names and shapes
names = [weight.name for layer in model.layers for weight in layer.weights]
weights = {name: weight for name, weight in zip(names, model.get_weights())}

for name in sorted(weights.keys()):
    print(f"{name} ({weights[name].shape})")
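
Once names and shapes have been collected for both models, the comparison can be automated instead of read off by eye. The following is only a minimal sketch, assuming each model's weights have been gathered into a plain name-to-shape dictionary under matching names. The helper `diff_shapes` and the sample entries are hypothetical and not part of this repository or of TF-Hub; the mismatched kernel shape is made up purely to show what a reported difference looks like.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_shapes(a: Dict[str, Shape], b: Dict[str, Shape]) -> List[str]:
    """Return readable differences between two name -> shape mappings."""
    problems = []
    for name in sorted(set(a) | set(b)):
        if name not in a:
            problems.append(f"{name}: missing from first model")
        elif name not in b:
            problems.append(f"{name}: missing from second model")
        elif tuple(a[name]) != tuple(b[name]):
            problems.append(f"{name}: {tuple(a[name])} vs {tuple(b[name])}")
    return problems


# Illustrative entries only; the second kernel shape is invented for the demo.
hub_shapes = {"conv2d/kernel:0": (3, 3, 24, 24), "bn/beta:0": (24,)}
other_shapes = {"conv2d/kernel:0": (3, 3, 24, 48), "bn/beta:0": (24,)}

for line in diff_shapes(hub_shapes, other_shapes):
    print(line)
```

With the `weights` dictionary built above, the input for one side could be `{name: w.shape for name, w in weights.items()}`. An empty result would suggest the shapes agree, though it says nothing about whether the values were converted correctly.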