SunnyHaze / IML-ViT

Official repository of paper “IML-ViT: Benchmarking Image manipulation localization by Vision Transformer”
MIT License
184 stars 23 forks source link

checkpoint from main_train.py is mismatch with the one need in demo.ipynb #17

Closed linohzz closed 4 months ago

linohzz commented 4 months ago

作者您好,我采用CASIAV2进行训练,CASIAV1作为验证,使用第108轮的权重文件在demo.ipynb上可视化示例图。遇到了以下问题: A9112B81D3B0D5CA0B071F01F476D3AD

图片可视化效果几乎不可见,使用训练集中的数据进行demo,效果也如上。 这是我的训练log.txt 17BB9177863ECBFEE73F2B1F8A69E319

使用您提供的权重文件进行可视化demo时,效果都正常

SunnyHaze commented 4 months ago

因为main_train.py得到的checkpoint是按照字典形式组织的,包括了modeloptimizer等等很多参数,会导致整个checkpoint大小超过1GB。其形式大致如下:

{
    "model": <state_dict of model>,
    "optimizer": <state_dict of model>,
    ......
}

为了节省大小,demo.ipynb默认只读取model参数,也就是说它期待的读取的文件类型就是本身。所以你可以通过修改demo.ipynb中torch.load相关的部分,从字典中提前把model字段读取出来实现,为了验证,可以把model.load_state_dict()中的strict参数改为True,这样只有严格match才能继续运行;或者提前将checkpoint中的重新torch.save为一个仅包含该object的的文件给demo.ipynb读取即可。

这里是我使用的重新save checkpoint的样例`Python脚本:

import torch
model = torch.load("/mnt/data0/xiaochen/workspace/IML-VIT/output_dir/checkpoint-150.pth")
output = model['model']
torch.save(output, "checkpoint-150.pth")

English version:

The issue is that the checkpoint obtained from main_train.py is organized in dictionary format, including many parameters such as model, optimizer, and so on. This leads to a size checkpoint over 1GB. Its format is roughly as follows:

{
    "model": <state_dict of model>,
    "optimizer": <state_dict of model>,
    ......
}

To save space, demo.ipynb by default only reads the model parameters. In other words, it expects to read a PyTorch checkpoint file that only contains an object of . So, you can achieve this by modifying the torch.load() section in demo.ipynb to extract the model field from the dictionary in advance. For validation purposes, you can change the strict parameter in model.load_state_dict() to True, so that only strict matching can proceed, or you can save the from the checkpoint as a file containing only that object in advance and have demo.ipynb read it.

Here is an example Python script I used to re-save the checkpoint:

import torch
model = torch.load("/mnt/data0/xiaochen/workspace/IML-VIT/output_dir/checkpoint-150.pth")
output = model['model']
torch.save(output, "checkpoint-150.pth")
SunnyHaze commented 4 months ago

因为这可能是一个commen issue,所以我这里额外提供一个英文的version。 希望能帮到你。如果有其他问题feel free to reach out!

linohzz commented 4 months ago

已解决,效果符合预期,感谢!

此外:还有两个小疑问: 1、训练过程中:验证数据集时,如下会出现两个F1值 图片 括号中的和括号外的,如上图所示:0.6108(0.5860)这两个值代表什么,0.5860代表整个验证集上平均的F1,0.6108代表这个print_freq 中这一个频率中的平均值吗? 2、是否方便提供评估代码,虽说AUC值对篡改的评估不是很准确,但现在大部分工作还都是使用F1和auc作为对比

SunnyHaze commented 4 months ago
  1. 0.5860是torch.distribute reduce过后的结果,前面的数值只是第一个GPU上得到的结果。
  2. AUC评估代码大差不差,就是用sklearn实现的,参考这个issue:#8, 只是目前实现有点丑,最近有点忙没有精力revise这一部分到可以开源的情况。敬请谅解,后续有时间会给出来的。
linohzz commented 4 months ago
  1. 0.5860是torch.distribute reduce过后的结果,前面的数值只是第一个GPU上得到的结果。

    1. AUC评估代码大差不差,就是用sklearn实现的,参考这个issue:Regarding interpretation of evaluation metrics #8, 只是目前实现有点丑,最近有点忙没有精力revise这一部分到可以开源的情况。敬请谅解,后续有时间会给出来的。

收到,感谢!