Closed kaczmarj closed 1 year ago
It looks like the conversion script had to be updated to transpose the TensorFlow weights in fc6.
def convert_tf_to_pytorch(input_path, output_path, num_classes: int) -> None:
    """Convert a TensorFlow VGG16 checkpoint to a TorchVision-style state dict.

    Every ``(tf_prefix, torch_prefix)`` pair in the module-level
    ``tf_to_pytorch_layers`` table is read from the checkpoint as
    ``<tf_prefix>/weights`` and ``<tf_prefix>/biases``, re-laid-out, and stored
    under ``<torch_prefix>.weight`` / ``<torch_prefix>.bias``. The result is
    checked against a freshly constructed ``torchvision.models.vgg16`` (with
    its last classifier layer resized to ``num_classes``) and then written
    with ``torch.save``.

    Args:
        input_path: Stem of the TensorFlow checkpoint path, i.e. without the
            '.index', '.meta' or '.data-00000-of-00001' suffix.
        output_path: Destination passed to ``torch.save``.
        num_classes: Number of outputs of the final classifier layer.

    Raises:
        RuntimeError: If the checkpoint cannot be loaded, or if the converted
            parameter names differ from the TorchVision VGG16 names.
        ValueError: If any converted parameter has a different shape than the
            corresponding TorchVision parameter.
    """
    try:
        ckpt = tf.train.load_checkpoint(input_path)
    except tf.errors.DataLossError as err:
        raise RuntimeError(
            "Error: could not load checkpoint. Did you pass in the stem of the path? "
            "Pass in the path without '.index' or '.meta' or '.data-00000-of-00001'."
        ) from err

    new_state_dict = {}
    for tf_prefix, torch_prefix in tf_to_pytorch_layers:
        tf_weights = f"{tf_prefix}/weights"
        tf_biases = f"{tf_prefix}/biases"
        tf_weight_array = ckpt.get_tensor(tf_weights)
        tf_bias_array = ckpt.get_tensor(tf_biases)

        if "conv" in tf_weights:
            # Move the last two axes to the front; presumably this turns a
            # [height, width, in, out] kernel into PyTorch's [out, in, h, w].
            tf_weight_array = tf_weight_array.transpose([3, 2, 0, 1])
        elif "fc" in tf_weights:
            if tf_weights == "vgg_16/fc6/weights":
                # fc6 is stored as [7, 7, 512, 4096]. Put the channel axis
                # first so that flattening below orders the inputs as
                # (channel, row, col) -> [25088, 4096].
                tf_weight_array = tf_weight_array.transpose([2, 0, 1, 3])
            # Collapse everything but the output axis, then transpose, e.g.
            # [1, 1, 4096, 1000] -> [4096, 1000] -> [1000, 4096]. An explicit
            # reshape (rather than squeeze) keeps the output axis even when it
            # has length 1, so num_classes == 1 still yields a 2-D weight.
            out_features = tf_weight_array.shape[-1]
            tf_weight_array = tf_weight_array.reshape((-1, out_features)).T

        new_state_dict[f"{torch_prefix}.weight"] = torch.from_numpy(tf_weight_array)
        new_state_dict[f"{torch_prefix}.bias"] = torch.from_numpy(tf_bias_array)

    # Test that conversion was (probably) done correctly.
    true_model = torchvision.models.vgg16()
    true_model.classifier[6] = torch.nn.Linear(4096, num_classes)
    true_state_dict = true_model.state_dict()

    # Key views compare like sets, so this check ignores ordering.
    if true_state_dict.keys() != new_state_dict.keys():
        raise RuntimeError(
            "Something went wrong... converted model keys do not match TorchVision"
            " VGG16 keys."
        )

    # Compare shapes by parameter name rather than by position: the key check
    # above does not guarantee that both dicts iterate in the same order.
    for key, true_param in true_state_dict.items():
        if true_param.shape != new_state_dict[key].shape:
            raise ValueError(
                "Shape mismatch between converted parameters and reference parameters."
            )

    torch.save(new_state_dict, output_path)
This issue tracks the addition of the VGG16 TILs model.
weights were pulled from Box: https://stonybrookmedicine.app.box.com/folder/128593362243?v=til-results-new-model
config for the original vgg16 tils model is here: https://stonybrookmedicine.app.box.com/file/757453637020
pixel values are normalized to [-1, 1] (see the code below from the Box folder with the original model).
The weights were then converted from TensorFlow to PyTorch using the conversion script in this repo. Some weight names were then renamed to conform to timm's naming, using the code below: