facebookresearch / swav

PyTorch implementation of SwAV https://arxiv.org/abs/2006.09882

How to use the checkpoint after training? #51

Closed hahmad2008 closed 3 years ago

hahmad2008 commented 3 years ago

I tried to train SwAV on a small dataset and got these generated files:

  • checkpoints
  • stats0.pkl
  • params.pkl
  • train.log

Once training is done, how can I use the model? How do I assign an unseen image to one of the clusters, and how do I retrieve images from the same cluster?

I used this command for training:

python -m torch.distributed.launch --nproc_per_node=1 main_swav.py \
--data_path pics1 \
--epochs 5 \
--base_lr 0.6 \
--final_lr 0.0006 \
--warmup_epochs 0 \
--batch_size 32 \
--size_crops 224 96 \
--nmb_crops 2 6 \
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--use_fp16 true \
--freeze_prototypes_niters 5005 \
--queue_length 3840 \
--epoch_queue_starts 15
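On a small dataset some of these flags may never take effect. As we read main_swav.py, prototype gradients are dropped while the global iteration is below --freeze_prototypes_niters, the queue is only created from epoch --epoch_queue_starts onward, and the training loader drops the last incomplete batch. The helper below (schedule_check is hypothetical, not part of the repo) applies that reading to a single-process run like the command above; treat it as a rough sketch under those assumptions.

```python
# Hypothetical helper: estimate whether the SwAV prototypes are ever unfrozen
# and whether the queue is ever used. Assumptions: a single process
# (--nproc_per_node=1) and a data loader with drop_last=True.

def schedule_check(num_images, batch_size=32, epochs=5,
                   freeze_prototypes_niters=5005, queue_length=3840,
                   epoch_queue_starts=15):
    """Summarize how the training flags interact for a given dataset size."""
    iters_per_epoch = num_images // batch_size  # last partial batch dropped
    total_iters = iters_per_epoch * epochs
    return {
        "iters_per_epoch": iters_per_epoch,
        "total_iters": total_iters,
        # Iterations run from 0 to total_iters - 1; prototypes only get
        # gradients once iteration >= freeze_prototypes_niters.
        "prototypes_ever_updated": total_iters > freeze_prototypes_niters,
        # Epochs run from 0 to epochs - 1; the queue starts at epoch_queue_starts.
        "queue_ever_used": queue_length > 0 and epochs > epoch_queue_starts,
    }


print(schedule_check(1000))
```

For example, with 1,000 images it reports 31 iterations per epoch and 155 in total, so under these assumptions the prototypes would stay frozen for the whole run and the queue would never be used.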
Erfun76 commented 3 years ago

Refer to the instructions for evaluating trained models.

mathildecaron31 commented 3 years ago

Hello @hahmad2008

The trained models are located in the checkpoints folder.
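A minimal sketch of loading one of those checkpoints for inference, assuming the file is a dict that stores the weights under "state_dict" and that the keys carry the "module." prefix added by DistributedDataParallel during training. The model passed in should be built with the same architecture arguments as training (in main_swav.py these come from flags such as --hidden_mlp, --feat_dim and --nmb_prototypes). The helper names strip_prefix and load_for_inference are hypothetical.

```python
# Sketch of loading a SwAV checkpoint for inference. Assumptions (not
# confirmed here): weights live under "state_dict" and keys start with
# "module." because the model was wrapped in DistributedDataParallel.

def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that have it."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def load_for_inference(model, checkpoint_path):
    """Load trained weights into `model` and switch it to eval mode."""
    import torch  # imported lazily so strip_prefix works without PyTorch

    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Fall back to the whole object if it is already a plain state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    model.load_state_dict(strip_prefix(state_dict))
    model.eval()  # use running BatchNorm statistics instead of batch statistics
    return model
```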

You can find the cluster assignment of an image x by running a forward pass of that image through the model:

 embedding, output = model(x)

where output has size [1, K] and K is the number of clusters (3000 by default).

Applying argmax to output gives the cluster assignment of x.
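In PyTorch the assignment is simply output.argmax(dim=1). The sketch below spells the same idea out in plain Python and also covers the retrieval part of the question: assign every image in a collection, group the images by cluster, then look up the cluster of an unseen image. It assumes each image's scores were already computed as above and converted with output[0].tolist(); all function names are hypothetical, and the toy scores use K = 3 instead of 3000.

```python
# Sketch: cluster assignment by argmax over prototype scores, plus a simple
# cluster -> images index for retrieval. Assumes scores are plain lists,
# e.g. output[0].tolist(). Function names are hypothetical.
from collections import defaultdict


def assign_cluster(scores):
    """Index of the highest-scoring prototype (argmax over K scores)."""
    return max(range(len(scores)), key=scores.__getitem__)


def build_cluster_index(scores_by_image):
    """Map cluster id -> sorted list of image names assigned to it."""
    index = defaultdict(list)
    for name, scores in scores_by_image.items():
        index[assign_cluster(scores)].append(name)
    return {cluster: sorted(names) for cluster, names in index.items()}


def same_cluster(query_scores, index):
    """Return the cluster of an unseen image and the indexed images sharing it."""
    cluster = assign_cluster(query_scores)
    return cluster, index.get(cluster, [])


toy_scores = {
    "a.jpg": [0.9, 0.1, 0.0],
    "b.jpg": [0.2, 0.7, 0.1],
    "c.jpg": [0.8, 0.3, 0.1],
}
index = build_cluster_index(toy_scores)
print(same_cluster([0.6, 0.2, 0.1], index))
```

Grouping by argmax gives a coarse, exact-match notion of similarity: every image in the returned list shares the query's cluster, but they are not ranked against each other.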