OlafenwaMoses / ImageAI

A python library built to empower developers to build applications and systems with self-contained Computer Vision capabilities
https://www.genxr.co/#products
MIT License

Continue training on classification or transfer weights to expand the dataset regularly #808

Open dudutti opened 1 year ago

dudutti commented 1 year ago

Hi, and to begin: great work, thanks.

I have a question about training a classification model (not detection).

When I expand my dataset with new photos but keep the same number of classes, I want to use transfer learning, like train_from_pretrained_model in Custom Object Detection, but for Custom Training Prediction. I've tried passing it to trainModel(): model_trainer.trainModel(num_experiments=100, batch_size=32, train_from_pretrained_model=Myfilepath)

I've also tried .setTrainConfig(), with the same result: these parameters don't exist in Custom Training Prediction.

FYI: the documentation at https://imageai.readthedocs.io/en/latest/custom/index.html has a mistake in the example parameters of this code: model_trainer.trainModel(num_objects=10, num_experiments=100, enhance_data=True, batch_size=32, show_network_summary=True)

trainModel() does not accept anything other than num_experiments and batch_size; enhance_data=True and show_network_summary=True raise an error: ClassificationModelTrainer.trainModel() got an unexpected keyword argument 'show_network_summary'
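Since the documented keyword arguments and the installed version's trainModel() disagree, one way to see what the installed version really accepts is to inspect its signature before calling it. This is a minimal sketch using only Python's standard inspect module; the helper name split_supported_kwargs and the train_model_stub stand-in are hypothetical and not part of ImageAI.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def split_supported_kwargs(
    func: Callable[..., Any], /, **kwargs: Any
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split kwargs into (supported, rejected) according to func's signature.

    Hypothetical helper, not part of ImageAI. If func declares **kwargs,
    every keyword is treated as supported.
    """
    params = inspect.signature(func).parameters.values()
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return dict(kwargs), {}
    # Only these parameter kinds can be passed by keyword.
    keyword_names = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    }
    supported = {k: v for k, v in kwargs.items() if k in keyword_names}
    rejected = {k: v for k, v in kwargs.items() if k not in keyword_names}
    return supported, rejected


if __name__ == "__main__":
    # Hypothetical stand-in with only two parameters, mimicking the behaviour
    # described in the issue; it is NOT ImageAI's real trainModel().
    def train_model_stub(num_experiments=100, batch_size=32):
        return num_experiments, batch_size

    ok, bad = split_supported_kwargs(
        train_model_stub, num_experiments=100, batch_size=32,
        enhance_data=True, show_network_summary=True,
    )
    print("passing:", ok)
    print("not accepted by this version:", bad)
    train_model_stub(**ok)
```

Calling it on the bound method, split_supported_kwargs(model_trainer.trainModel, ...), works the same way, because inspect.signature drops self for bound methods; the second dict then lists whatever the installed version would reject instead of failing with the unexpected keyword argument error.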

I'm using the latest version of ImageAI from the repository (v3.0.3), installed with pip install imageai --upgrade.

Many thanks. This question is only about continuing training from my last model instead of retraining everything from scratch; that would save about 6 days of GPU computing on a 1080 Ti.

My project is a program for mushroom classification. I have a dataset of 45,000 photos covering 40 families and 100 species. Everything works great, but I need to expand my dataset regularly.

dudutti commented 1 year ago

The cause is a state_dict size mismatch between the original model and the custom-trained model.

I think the solution is to pass strict=False to load_state_dict, but I can't find where to do that. I still have the issue.

https://pytorch.org/tutorials/beginner/saving_loading_models.html#id4

I tried adding strict=False to every load call (self.__model.load_state_dict(state_dict, strict=False)), but it didn't solve the error.
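In PyTorch, load_state_dict(..., strict=False) only tolerates missing and unexpected keys; a key present in both the checkpoint and the model with a different tensor shape still raises a RuntimeError ("size mismatch for ..."). That would explain why adding strict=False to each call did not help. A common workaround is to drop the mismatched entries before loading. This is a sketch of that filtering, assuming the saved file is a plain state_dict (a mapping of parameter names to tensors); the function name select_compatible_weights is hypothetical, and since it only reads the .shape attribute it runs without importing torch.

```python
from typing import Any, Dict, List, Mapping, Tuple


def select_compatible_weights(
    checkpoint: Mapping[str, Any], model_state: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[str]]:
    """Split a checkpoint into entries the model can load and entries it can't.

    Hypothetical helper (not part of ImageAI or PyTorch). Values only need a
    .shape attribute, so torch tensors work directly.
    Returns (compatible, skipped):
      compatible: name exists in model_state and the shapes are equal
      skipped: name missing from the model, or shape differs
    """
    compatible: Dict[str, Any] = {}
    skipped: List[str] = []
    for name, value in checkpoint.items():
        target = model_state.get(name)
        if target is not None and tuple(value.shape) == tuple(target.shape):
            compatible[name] = value
        else:
            skipped.append(name)
    return compatible, skipped


# Typical PyTorch usage (not executed here; `model` and `path` are assumed):
#   checkpoint = torch.load(path, map_location="cpu")
#   compatible, skipped = select_compatible_weights(checkpoint, model.state_dict())
#   print("not loaded:", skipped)
#   model.load_state_dict(compatible, strict=False)  # missing keys are tolerated
```

Layers listed in skipped (typically the final classifier when the class count differs) keep their freshly initialised weights and have to be trained; everything else starts from the checkpoint.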