olgaliak / active-learning-detect

Active learning + object detection
MIT License

ability to target a specific class for tagging #11

Closed abfleishman closed 6 years ago

abfleishman commented 6 years ago

With the kittiwake model, I have now run many iterations and it is performing very well for kittiwakes. I have a second class in the model (egg) which is much rarer (I think I have only found ~5 examples). It would be really great if I could specifically download photos that have predictions for a certain class (egg) so that I could find more of that class. So far I have not seen a single prediction for egg in the random photos that I have tagged.

yashpande commented 6 years ago

Hey Abram,

There is actually a way to do this, although it involves modifying the code:

In line 13 of create_predictions.py, modify:

def calculate_confidence(predictions):
    return min([float(prediction[0]) for prediction in predictions])

to be something like:

def calculate_confidence(predictions):
    if any(prediction[1] == "egg" for prediction in predictions):
        return 0
    return min([float(prediction[0]) for prediction in predictions])

Basically, each element in predictions is of the form: confidence, classname, xmin, xmax, ymin, ymax. You can use this to come up with a custom confidence. In the code above I simply set the confidence to 0 if any prediction in the image has the class egg (make sure the spelling/capitalization is the same as in your config file), so those images are sent first since they have the lowest confidence.
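To make the idea easier to reuse for other classes, here is a minimal sketch of the same function with the class name pulled out into a parameter. The target_class parameter and the sample rows are made up for illustration and are not part of create_predictions.py; the only thing taken from the repo is the assumed element layout (confidence, classname, xmin, xmax, ymin, ymax).

```python
def calculate_confidence(predictions, target_class="egg"):
    """Return the image-level confidence used to rank images for tagging.

    Each prediction is assumed to be a sequence of the form
    (confidence, classname, xmin, xmax, ymin, ymax).
    target_class is a hypothetical parameter added for this sketch.
    """
    # Any prediction of the target class pushes the image to the front of the queue.
    # The comparison is case-sensitive, so the name must match the config file.
    if any(prediction[1] == target_class for prediction in predictions):
        return 0.0
    # Otherwise fall back to the lowest confidence among the predictions.
    return min(float(prediction[0]) for prediction in predictions)


# Made-up example rows, only to show the behaviour.
with_egg = [("0.91", "kittiwake", 10, 50, 20, 60), ("0.40", "egg", 5, 15, 5, 15)]
without_egg = [("0.91", "kittiwake", 10, 50, 20, 60), ("0.75", "kittiwake", 1, 2, 3, 4)]

print(calculate_confidence(with_egg))
print(calculate_confidence(without_egg))
```

Images containing the target class get confidence 0, so they sort ahead of everything else. All other images keep the minimum-confidence ranking from the original function.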

This is not as simple as it could be, and it also requires re-running the prediction code through active_learning_train.sh, so I'm keeping the issue open as a potential future improvement.

Hope this helps!

-Yash

abfleishman commented 6 years ago

Thanks! Definitely helps! I will give this a try. Is there a way to predict without retraining?

yashpande commented 6 years ago

Yes, you can simply run:

python ${python_file_directory}/create_predictions.py cur_config.ini
az storage blob upload --container-name $label_container_name --file $untagged_output --name totag_$(date +%s).csv --account-name $AZURE_STORAGE_ACCOUNT --account-key $AZURE_STORAGE_KEY

Note that this requires cur_config.ini to be up to date and all the $variables to be defined. If they are not (you can test this with something like cat cur_config.ini or echo $untagged_output), then you can do:

set -a
sed -i 's/\r//g' ../config.ini
. ../config.ini
set +a

envsubst < ../config.ini > cur_config.ini

(This assumes that your config file is ../config.ini).
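If you would rather check all the variables in one go instead of echoing them one by one, something like the following sketch could work. The variable names are the ones used in the commands in this thread; the missing_variables helper is hypothetical and not part of the repo. Note that Python only sees exported variables, so this assumes they were loaded with set -a as shown above.

```python
import os

# Variable names taken from the commands in this thread.
REQUIRED_VARIABLES = [
    "python_file_directory",
    "label_container_name",
    "untagged_output",
    "AZURE_STORAGE_ACCOUNT",
    "AZURE_STORAGE_KEY",
]


def missing_variables(names, env=None):
    """Return the names that are unset or empty in env (defaults to os.environ).

    Hypothetical helper for illustration only.
    """
    env = os.environ if env is None else env
    return [name for name in names if not env.get(name)]


if __name__ == "__main__":
    missing = missing_variables(REQUIRED_VARIABLES)
    if missing:
        print("Undefined variables:", ", ".join(missing))
    else:
        print("All variables are defined.")
```

If anything is reported as undefined, re-run the set -a / envsubst steps before calling create_predictions.py.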

This is a lot more hacky, and will only work if you have just run training using the same config.ini file. Otherwise, you will likely run into issues.

-Yash

olgaliak commented 6 years ago

New setting "ideal_class_balance" should address the scenario without retraining. Closing the issue ;)