cvg / glue-factory

Training library for local feature detection and matching
Apache License 2.0

Conversion of trained checkpoint to official lg checkpoint #48

Closed: spagnoloG closed this issue 11 months ago

spagnoloG commented 11 months ago

Hello! First, I would like to thank you for making the training framework public! :)

Let me first explain what I did, and then the issue I am facing:

data:
    name: custom
    data_dir: ""
    metadata_dir: "" 
    train_size: null
    val_size: null
    batch_size: 128
    num_workers: 14
    homography:
        difficulty: 0.7
        max_angle: 359
    photometric:
        name: lg
model:
    name: two_view_pipeline
    extractor:
      name: gluefactory_nonfree.superpoint
      max_num_keypoints: 2048
      detection_threshold: 0.0
      nms_radius: 3
      trainable: False
    ground_truth:
        name: matchers.homography_matcher
        th_positive: 3
        th_negative: 3
    matcher:
      name: matchers.lightglue
      features: superpoint
      depth_confidence: -1
      width_confidence: -1
      filter_threshold: 0.1
      flash: true
train:
    seed: 0
    epochs: 5
    log_every_iter: 100
    eval_every_iter: 500
    lr: 1e-4
    lr_schedule:
        start: 20
        type: exp
        on_epoch: true
        exp_div_10: 10
benchmarks:
    hpatches:
      eval:
        estimator: opencv
        ransac_th: 0.5

With this config, I fine-tuned the model for a mere 5 epochs.

Then I wrote a custom Python script that converts checkpoint_best.tar to the same structure as the official checkpoint (yes, I manually inspected the layers):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch

def extract_checkpoint(checkpoint_path, save_model_path):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    print("Checkpoint Contents:")
    converted_checkpoint = {}
    for key, value in checkpoint.items():
        if key == "model":
            for k, v in value.items():
                if "matcher.transformers." in k:
                    # Extract the transformer layer number and the rest of the key
                    parts = k.split(".")
                    transformer_layer = parts[2]
                    remaining_key = ".".join(parts[3:])

                    # Construct the new key name
                    new_key = f"self_attn.{transformer_layer}.{remaining_key}"
                    converted_checkpoint[new_key] = v
                elif k.startswith("matcher."):
                    # For other matcher parts, just remove 'matcher.' prefix
                    new_key = k.replace("matcher.", "")
                    converted_checkpoint[new_key] = v
                else:
                    print(k)

    torch.save(converted_checkpoint, save_model_path)

if __name__ == "__main__":
    checkpoint_path = "checkpoint_best.tar"
    save_model_path = "lg_finetuned_v4.pth"
    extract_checkpoint(checkpoint_path, save_model_path)

After loading the converted checkpoint into the official LightGlue repository, it did not produce any matches (tested on 1k images), although keypoints were extracted successfully.
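
One plausible explanation can be read off the renaming logic in the script above. The sketch below replays that logic on two illustrative keys (the sub-module names after `self_attn`/`cross_attn` are assumptions, not taken from a real checkpoint): every `transformers.{i}` sub-module, including cross-attention, lands under `self_attn.{i}` with the original `self_attn`/`cross_attn` segment still inside the key. That matches neither the `self_attn.{i}.…` nor the `cross_attn.{i}.…` layout produced by the working script later in this thread. If the loader ignores unmatched keys (for example via `load_state_dict(..., strict=False)`), those attention weights would silently keep whatever values they had before loading.

```python
def first_script_rename(key):
    """Key renaming as performed by the first conversion script above."""
    if "matcher.transformers." in key:
        parts = key.split(".")
        # parts[2] is the layer index; everything after it is kept verbatim,
        # so "self_attn"/"cross_attn" stays inside the new key.
        return f"self_attn.{parts[2]}." + ".".join(parts[3:])
    if key.startswith("matcher."):
        return key.replace("matcher.", "")
    return None  # non-matcher keys (e.g. the extractor) are dropped


# Illustrative keys; the sub-module names (Wqkv, to_qk) are assumptions.
for key in [
    "matcher.transformers.0.self_attn.Wqkv.weight",
    "matcher.transformers.0.cross_attn.to_qk.weight",
]:
    print(key, "->", first_script_rename(key))
```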

I am a bit worried that I took a wrong turn somewhere and would really appreciate your guidance!

Thanks :)

Zhaoyibinn commented 11 months ago

I have encountered a similar problem. I replaced the weights in my training checkpoint with those from the official .pth and trained for several epochs. At the beginning of training, the displayed loss was very high and then decreased quickly (yes, training from the official weights shows a high initial loss). When testing, there were no matching points. After debugging, I found that the mscores had become very low, and filter_threshold had to be lowered to 0.001 to get any usable matches.
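
Why low mscores produce zero matches can be illustrated with a simplified, pure-Python sketch of confidence-thresholded mutual nearest-neighbour filtering, loosely modelled on how LightGlue turns assignment scores into matches. This is not the library's code: the function name `filter_matches` and its input format (log-scores with the dustbin row and column already removed) are assumptions. When every confidence falls below the threshold, all candidates are discarded even if the mutual pairs themselves are correct.

```python
import math


def filter_matches(log_scores, threshold):
    """Mutual nearest-neighbour matching with a confidence threshold.

    log_scores[i][j] is the log-score for pairing keypoint i of image 0
    with keypoint j of image 1 (dustbin row and column already removed).
    Returns a list of (i, j, confidence) tuples.
    """
    n0, n1 = len(log_scores), len(log_scores[0])
    # Best candidate in image 1 for each keypoint of image 0, and vice versa.
    best0 = [max(range(n1), key=lambda j: log_scores[i][j]) for i in range(n0)]
    best1 = [max(range(n0), key=lambda i: log_scores[i][j]) for j in range(n1)]
    matches = []
    for i, j in enumerate(best0):
        if best1[j] != i:
            continue  # not a mutual nearest neighbour
        confidence = math.exp(log_scores[i][j])
        if confidence > threshold:
            matches.append((i, j, confidence))
    return matches


# Correct mutual pairs, but with uniformly low confidences (made-up numbers).
scores = [[math.log(0.01), math.log(0.001)],
          [math.log(0.001), math.log(0.02)]]
print(len(filter_matches(scores, 0.1)))    # 0
print(len(filter_matches(scores, 0.001)))  # 2
```

Lowering the threshold recovers matches in this toy case, but since the thread's eventual fixes were issue #49 and the final conversion script, a collapsed confidence distribution is better treated as a symptom to diagnose than something to tune away.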

May I ask whether you have resolved this issue yet?

spagnoloG commented 11 months ago

@Zhaoyibinn Nope, the problem still persists :/. I loaded the local weights the same way you did in the newly opened issue https://github.com/cvg/glue-factory/issues/49. I will let you know if I find anything interesting (not before the weekend). Thanks for the info! :)

Zhaoyibinn commented 11 months ago

Hello, I have resolved my issue ^ ^. See issue #49 for reference (although it may differ from the problem you encountered).

spagnoloG commented 11 months ago

<3 <3 Thanks, it works!

Here is the final conversion script :D

import torch

def extract_checkpoint(checkpoint_path, save_model_path, n_layers):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    print("Checkpoint Contents:")
    state_dict = checkpoint.get('model', {})

    matcher_dict = {k.split('matcher.', 1)[1]: v for k, v in state_dict.items() if k.startswith('matcher')}

    if matcher_dict:
        for i in range(n_layers):
            patterns = [
                (f"transformers.{i}.self_attn", f"self_attn.{i}"),
                (f"transformers.{i}.cross_attn", f"cross_attn.{i}")
            ]

            for old_key, new_key in patterns:
                matcher_dict = {k.replace(old_key, new_key) if old_key in k else k: v for k, v in matcher_dict.items()}

    print(matcher_dict.keys())

    torch.save(matcher_dict, save_model_path)

if __name__ == "__main__":
    checkpoint_path = "checkpoint_best.tar"
    save_model_path = "lg_finetuned_v7.pth"
    n_layers = 9  # Just like in official repo
    extract_checkpoint(checkpoint_path, save_model_path, n_layers)
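
The renaming in the final script can also be written as a single anchored regular expression over the key names. This is a minimal sketch assuming the glue-factory layout `matcher.transformers.{i}.{self_attn|cross_attn}.…` seen above; `convert_matcher_keys` is a hypothetical helper, not part of glue-factory or LightGlue.

```python
import re

# Matches "transformers.<layer>.<self_attn|cross_attn>." at the start of a key.
_LAYER_RE = re.compile(r"^transformers\.(\d+)\.(self_attn|cross_attn)\.")


def convert_matcher_keys(state_dict):
    """Return the matcher entries of a glue-factory state dict, renamed
    from transformers.{i}.self_attn/cross_attn to self_attn.{i}/cross_attn.{i}."""
    converted = {}
    for key, value in state_dict.items():
        if not key.startswith("matcher."):
            continue  # drop extractor and other non-matcher entries
        key = key[len("matcher."):]
        # transformers.{i}.self_attn.X -> self_attn.{i}.X (same for cross_attn)
        key = _LAYER_RE.sub(r"\2.\1.", key)
        converted[key] = value
    return converted
```

Usage would mirror the script above: `torch.save(convert_matcher_keys(torch.load(checkpoint_path, map_location="cpu")["model"]), save_model_path)`. Because the pattern captures the layer index, it does not need `n_layers`, and anchoring it at `^` means only the leading `transformers.{i}` segment of each key is rewritten. Values are passed through untouched, so the helper works the same on tensors as on any other objects.
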

Zhaoyibinn commented 10 months ago

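Quoting: "After training, the model itself does not output a .pth file. Is this function the only way to get a .pth file? All I have after training is the checkpoint .tar file."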

Hello, the .tar file contains the contents of the .pth file. You can read it directly, or load it and re-save it as a .pth file.

zyxzyx45 commented 2 months ago

Hello, could you share the official pretrained weights? Thank you very much.