facebookresearch / mae

PyTorch implementation of MAE https://arxiv.org/abs/2111.06377

Poor Image Reconstruction After Fine-tuning on MvTec Dataset #168

Closed cestbonsuliu closed 3 months ago

cestbonsuliu commented 12 months ago

I am trying to perform anomaly localization after fine-tuning on the MvTec dataset. Here are some of my settings:

The file structure of the MvTec dataset is as follows:

.
├── train
│   ├── bottle
│   ├── cable
│   ├── capsule
│   ├── carpet
│   ├── grid
│   ├── hazelnut
│   ├── leather
│   ├── metal_nut
│   ├── pill
│   ├── screw
│   ├── tile
│   ├── toothbrush
│   ├── transistor
│   ├── wood
│   └── zipper
└── val
    ├── bottle
    ├── cable
    ├── capsule
    ├── carpet
    ├── grid
    ├── hazelnut
    ├── leather
    ├── metal_nut
    ├── pill
    ├── screw
    ├── tile
    ├── toothbrush
    ├── transistor
    ├── wood
    └── zipper

The "train" folder contains normal samples, while the "val" folder contains abnormal samples.
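With this layout, an ImageFolder-style loader (as used by `main_finetune.py`) derives the class labels from the sorted subfolder names, which is why the fine-tuning command passes `--nb_classes 15`. A minimal pure-Python sketch of that sorted-name-to-index convention (not using torchvision itself):

```python
# Sketch: how an ImageFolder-style loader maps class subfolders to label
# indices -- folder names are sorted alphabetically, then enumerated.
# (torchvision's ImageFolder follows the same convention.)
categories = [
    "bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather",
    "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor",
    "wood", "zipper",
]
class_to_idx = {name: idx for idx, name in enumerate(sorted(categories))}
print(len(class_to_idx))       # 15 classes -> matches --nb_classes 15
print(class_to_idx["bottle"])  # 0
print(class_to_idx["zipper"])  # 14
```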

My fine-tuning command is:

OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=3 main_finetune.py --accum_iter 4 --batch_size 32 --model vit_base_patch16 --finetune checkpoint/mae_pretrain_vit_base.pth --epochs 100 --blr 5e-4 --layer_decay 0.65 --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 --dist_eval --data_path /home/hehui/AnomalyDetection/Dataset/MvTec --nb_classes 15

Then, I tested the following image:

[test image attached]

import torch
import models_mae  # from the MAE repo

pre_checkpoint_url = '/home/hehui/AnomalyDetection/Project/mae/output_dir/checkpoint-99.pth'
pre_checkpoints = torch.load(pre_checkpoint_url, map_location=torch.device('cpu'))

mae_base = models_mae.mae_vit_base_patch16()

# strict=False silently skips any keys that are missing from the checkpoint
mae_base.load_state_dict(pre_checkpoints['model'], strict=False)

loss, pred, mask = mae_base(input)  # input: preprocessed image tensor
print(loss)

pred = mae_base.unpatchify(pred)

display_image(pred)  # display_image: user-defined visualization helper
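One detail worth noting in the snippet above: `load_state_dict(..., strict=False)` hides a mismatch. A checkpoint produced by `main_finetune.py` contains an encoder-only classifier, so none of the MAE decoder weights are in it, and the decoder stays randomly initialized. A small self-contained sketch (plain `torch.nn`, with a hypothetical toy module, not the actual MAE classes) of how to surface that:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in: a model with an "encoder" and a "decoder",
# loading a checkpoint that only contains encoder weights.
class ToyMAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.decoder = nn.Linear(4, 4)

model = ToyMAE()
checkpoint = {"encoder.weight": torch.zeros(4, 4), "encoder.bias": torch.zeros(4)}

# strict=False loads what it can and reports the rest instead of raising.
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)  # ['decoder.weight', 'decoder.bias']
```

Printing `result.missing_keys` after loading a fine-tuned checkpoint into `mae_vit_base_patch16()` would show which weights never came from the checkpoint.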

The obtained results are as follows:

[reconstruction result attached]

Can anyone help me with this issue?

daisukelab commented 12 months ago

You need to read the paper to understand what is being done.

  1. The pre-training makes the encoder and decoder work together to reconstruct masked patches.
  2. The fine-tuning makes only the encoder fit the downstream task data distribution.

So after fine-tuning, the MAE as a whole loses its reconstruction capability: the encoder has adjusted its embedding output to fit the downstream data, regardless of what the decoder expects.

If you pre-train on your dataset, the MAE will learn to reconstruct your data, but the same caveat applies: any fine-tuning afterwards will break the reconstruction. If you need reconstruction, do not fine-tune.

Also make sure your dataset is large enough for pre-training.
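For reference, pre-training (as opposed to fine-tuning) goes through `main_pretrain.py`. A hedged sketch, with flags taken from the repo's PRETRAIN.md; the hyperparameters below are placeholders, not tuned recommendations for MvTec:

```shell
# Sketch: MAE pre-training on a custom dataset. --norm_pix_loss and
# --mask_ratio 0.75 follow the repo defaults; epochs/batch size/blr
# here are illustrative only.
OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=3 \
    main_pretrain.py \
    --batch_size 64 \
    --model mae_vit_base_patch16 \
    --norm_pix_loss \
    --mask_ratio 0.75 \
    --epochs 400 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --data_path /home/hehui/AnomalyDetection/Dataset/MvTec
```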

cestbonsuliu commented 3 months ago

thanks