SmilingWolf / JAX-CV

Repo for my JAX CV experiments. Mostly revolving around the Danbooru20xx dataset

About finetuning #14

Open gudwns1215 opened 1 month ago

gudwns1215 commented 1 month ago

Hi! Thank you for sharing this awesome training code!

I am trying to train my own custom tagger, and I see that the codebase contains both pretraining and finetuning code.

Is there any difference between the uploaded taggers (e.g. https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3) and the pretrained model?

The code has a reset-head option, so I assumed the pretrained model and the tagger share the same body and differ only in the head. It is also possible, though, that the tagger's main body was trained further.

Could I also get some advice on what loss values to expect?

While finetuning I get a validation loss of around 6.x, which seems wrong.

onirhakin commented 3 weeks ago

Can I ask where you found the finetuning training code?

SmilingWolf commented 2 weeks ago

The pretraining code uses the SimMIM pretraining strategy.

All the models on Hugging Face are finetunes, i.e. the final product of training_loop.py.

SimMIM-pretrained models have a slightly different structure from "normal" ones, so they have to be converted when using them for classification training. This is handled in training_loop.py when using the --restore-simmim-ckpt option.

The --reset-head option is generally used together with --restore-params-ckpt and --freeze-model-body to

I usually train the final dense layer first for 3-6 epochs, then the rest of the model if necessary.

I'd like to have a more flexible unfreezing interface at some point, but I have yet to come up with a design I like.
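The two-phase schedule described above (head first, then the rest of the model) can be sketched in plain Python. This is only an illustration of the idea, not how training_loop.py implements --freeze-model-body: the flat parameter dict, the "body/" and "head/" name prefixes, the `sgd_step` helper and the plain SGD update are all assumptions made up for this example.

```python
def sgd_step(params, grads, lr, freeze_body):
    """One plain SGD step over a flat {name: value} dict.

    Hypothetical convention: parameters whose name starts with "body/"
    belong to the model body and are left untouched when freeze_body is True.
    """
    new_params = {}
    for name, value in params.items():
        frozen = freeze_body and name.startswith("body/")
        new_params[name] = value if frozen else value - lr * grads[name]
    return new_params


# Toy parameters and gradients (scalars stand in for weight tensors).
params = {"body/conv/kernel": 1.0, "head/dense/kernel": 0.5}
grads = {"body/conv/kernel": 0.2, "head/dense/kernel": 0.4}

# Phase 1: body frozen, only the final dense layer is updated.
phase1 = sgd_step(params, grads, lr=0.1, freeze_body=True)

# Phase 2: everything is trainable.
phase2 = sgd_step(phase1, grads, lr=0.1, freeze_body=False)

print(phase1)
print(phase2)
```

In phase 1 the body parameter keeps its initial value while the head moves. In phase 2 both are updated.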

Re: loss value, the danbooru dataset currently in use gives about 48.3 validation loss with my best models. A smaller dataset with a smaller number of classes gives about 20-23 val loss depending on the model.

Keep in mind the loss is averaged across batch samples but summed over classes, which might be why the value looks like it makes little sense. That's why I also track MCC and F1Score out of the box. They depend on a threshold, but are a bit easier to grasp.
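To make the effect of that reduction concrete, here is a small pure-Python sketch. It assumes a sigmoid binary cross-entropy per class, which is an assumption for illustration only (the thread does not say which loss the repo uses). The helper names and the class count are hypothetical as well.

```python
import math


def sigmoid_bce(logit, label):
    """Numerically stable binary cross-entropy for a single logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def batch_loss(logits, labels, reduce_classes="sum"):
    """Mean over batch samples of the per-sample loss.

    reduce_classes="sum":  per-sample loss is summed over classes.
    reduce_classes="mean": per-sample loss is averaged over classes.
    """
    per_sample = []
    for sample_logits, sample_labels in zip(logits, labels):
        terms = [sigmoid_bce(z, y) for z, y in zip(sample_logits, sample_labels)]
        total = sum(terms)
        per_sample.append(total if reduce_classes == "sum" else total / len(terms))
    return sum(per_sample) / len(per_sample)


# Hypothetical setup: 2 samples, 10000 classes, an untrained model that
# outputs logit 0 everywhere, and no positive labels.
num_classes = 10000
logits = [[0.0] * num_classes for _ in range(2)]
labels = [[0.0] * num_classes for _ in range(2)]

summed = batch_loss(logits, labels, reduce_classes="sum")
averaged = batch_loss(logits, labels, reduce_classes="mean")

# The two values differ by a factor of num_classes.
print(summed, averaged, summed / averaged)
```

With the class-summed reduction, the reported loss scales with the number of classes, so values from datasets with different label counts are not directly comparable.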