apple / turicreate

Turi Create simplifies the development of custom machine learning models.
BSD 3-Clause "New" or "Revised" License

Limitations on number of categories #1054

Closed hipwelljo closed 6 years ago

hipwelljo commented 6 years ago

Is there a maximum number of categories that can be trained with Turi Create for image classification? Looking through the documentation on the different pretrained models you can use for transfer learning, Resnet, Squeezenet, and VisionFeaturePrint_Scene all have 1000 categories. Does this mean models created from those can also classify at most 1000 categories? In my case I have 6000 and am wondering whether this will work or be problematic. If so, this information would be great to add to the guide here:
https://apple.github.io/turicreate/docs/userguide/image_classifier/how-it-works.html#pretrained-image-classifiers

Thanks!

TobyRoseman commented 6 years ago

@hipwelljo - The number of labels in the pretrained network has no effect on the maximum number of categories. We haven't done any testing around the maximum number of categories for an image classifier. That being said, 6000 categories is probably going to be problematic.

I agree we should have some guidance in our documentation about the max number of categories. When I get a chance I'll run some tests and update the documentation.

If you try creating an image classifier with 6,000 categories (or any other large number of categories), please let us know how it goes.

For a large number of categories, I would recommend using Squeezenet. The number of features it outputs is much smaller than for Resnet or VisionFeaturePrint_Scene.
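To make the feature-count point concrete, here is a rough sketch. It assumes the classifier trained on top of the extracted features is a multinomial logistic regression with one reference class, so it has (d + 1) * (K - 1) coefficients for d features and K classes. That formula reproduces the "Number of coefficients : 5746741" line in the Turi Create log later in this thread (1000 features, 5742 classes). The 2048-feature size used below for Resnet and VisionFeaturePrint_Scene is an assumption, not something stated in this thread, and `num_coefficients` is a hypothetical helper.

```python
def num_coefficients(num_features: int, num_classes: int) -> int:
    """Coefficients in a multinomial logistic regression that uses one class
    as the reference: (features + bias) weights for each remaining class."""
    return (num_features + 1) * (num_classes - 1)


# 1000 features for SqueezeNet matches the log in this thread;
# 2048 for resnet-50 and VisionFeaturePrint_Scene is an assumption.
FEATURE_SIZES = {
    "squeezenet_v1.1": 1000,
    "resnet-50": 2048,
    "VisionFeaturePrint_Scene": 2048,
}

for name, d in FEATURE_SIZES.items():
    print(f"{name:>26}: {num_coefficients(d, 6000):,} coefficients for 6000 classes")
```

Under these assumptions, a 2048-feature extractor roughly doubles the size of the final classifier compared with SqueezeNet at the same number of classes.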

hipwelljo commented 6 years ago

@TobyRoseman Thanks! I will plan to try Squeezenet with my 6000 categories and report back, perhaps next weekend I could give it a whirl. I'm currently creating a model with Create ML, so using vision feature print, and it's been training for 24 hours now... 😮 Do you suspect it would be quicker with Turi Create using Squeezenet due to the smaller number of features it outputs?

For some context, the data set is handwritten characters. I just found this in the documentation for Create ML feature extractors, regarding vision feature print: "It isn’t suitable for character recognition, for example, where the input images are highly binary in nature". (This would be good to document in Turi Create as well!) Is that the case for Squeezenet too? From what I read, it sounds like Resnet, Squeezenet, and VisionFeaturePrint were trained on real-world photographs, so maybe none of them would be ideal for creating character recognition models? It sounds like Squeezenet is the best of the current options, though?

srikris commented 6 years ago

With that many categories, the bottleneck is likely not going to be the feature extraction phase. It's going to be the fine-tuning phase (which is linear in the number of categories). SqueezeNet may be 2-4x faster, but not much more than that. The amount of data you would need to distinguish between the classes might also be more than if you had just 50-100 classes.
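A back-of-the-envelope sketch of why fine-tuning cost grows linearly with the number of categories. It assumes each pass over the training data scores every example against every non-reference class, as a multinomial logistic regression does, giving roughly N * (d + 1) * (K - 1) multiply-adds per pass. The helper name and this cost model are illustrative assumptions: they ignore memory traffic, the gradient computation's constant factor, and line-search details.

```python
def multiply_adds_per_pass(num_examples: int, num_features: int, num_classes: int) -> int:
    """Rough cost of one forward pass of a multinomial logistic regression:
    every example (with a bias term) is scored against every non-reference class."""
    return num_examples * (num_features + 1) * (num_classes - 1)


# Example and feature counts taken from the Turi Create log later in this thread;
# the example count is held fixed here only to isolate the effect of the class count.
N, d = 49264, 1000
for k in (50, 100, 1000, 5742):
    print(f"{k:>5} classes: {multiply_adds_per_pass(N, d, k):.3e} multiply-adds per pass")
```

With the example count held fixed, going from about 100 classes to about 5700 multiplies the per-pass work by roughly 58x, which is why a cheaper feature extractor only helps so much.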

hipwelljo commented 6 years ago

@TobyRoseman and @srikris I have results for you with nearly 6000 classes using squeezenet. It took about 9 hours to complete 10 iterations, compared to 21 hours with CreateML using vision feature print scene. However, the accuracy at iteration 10 for TC was not as good as the accuracy of iteration 3 of CML, which took about 9 hours to complete. So, it seems it's twice as fast to iterate, but you have to iterate just as long to obtain the same level of accuracy. Maybe more data would help as noted - I had 10 images for each class.

Turi Create with SqueezeNet:

Logistic regression:
--------------------------------------------------------
Number of examples          : 49264
Number of classes           : 5742
Number of feature columns   : 1
Number of unpacked features : 1000
Number of coefficients      : 5746741
Starting L-BFGS
--------------------------------------------------------
+-----------+----------+-----------+--------------+-------------------+---------------------+
| Iteration | Passes   | Step size | Elapsed Time | Training Accuracy | Validation Accuracy |
+-----------+----------+-----------+--------------+-------------------+---------------------+
| 0         | 1        | NaN       | 1410.938340  | 0.000142          | 0.000410            |
| 1         | 4        | 0.000101  | 6993.121902  | 0.000284          | 0.000000            |
| 2         | 6        | 1.000000  | 11684.099158 | 0.008038          | 0.002048            |
| 3         | 7        | 1.000000  | 14630.014436 | 0.056796          | 0.013923            |
| 4         | 8        | 1.000000  | 17736.467060 | 0.026713          | 0.005733            |
| 5         | 9        | 1.000000  | 20273.712870 | 0.072791          | 0.022113            |
| 6         | 10       | 1.000000  | 22618.025692 | 0.112049          | 0.037674            |
| 7         | 11       | 1.000000  | 25096.515562 | 0.178528          | 0.066339            |
| 8         | 12       | 1.000000  | 27513.523471 | 0.241698          | 0.088043            |
| 9         | 13       | 1.000000  | 29956.004078 | 0.290983          | 0.110156            |
| 10        | 14       | 1.000000  | 32559.100921 | 0.343476          | 0.139640            |
+-----------+----------+-----------+--------------+-------------------+---------------------+

Create ML with VisionFeaturePrint_Scene:

+-----------+--------------+-------------------+---------------------+
| Iteration | Elapsed Time | Training Accuracy | Validation Accuracy |
+-----------+--------------+-------------------+---------------------+
| 0         | 2624.198231  | 0.000188          | 0.000000            |
| 1         | 16721.113751 | 0.060564          | 0.001757            |
| 2         | 25586.915525 | 0.287247          | 0.042179            |
| 3         | 31857.741821 | 0.436031          | 0.180316            |
| 4         | 38179.388461 | 0.523650          | 0.226714            |
| 5         | 44416.936162 | 0.565639          | 0.224956            |
| 6         | 50695.630575 | 0.589598          | 0.235852            |
| 7         | 56999.740079 | 0.588754          | 0.245694            |
| 8         | 63322.979367 | 0.606953          | 0.251670            |
| 9         | 69659.051465 | 0.637535          | 0.255185            |
| 10        | 76013.447991 | 0.654571          | 0.260105            |
| 11        | 82391.099810 | 0.661626          | 0.275923            |
| 12        | 88707.731178 | 0.672620          | 0.279438            |
| 13        | 95048.078870 | 0.688981          | 0.284359            |
| 14        | 101366.747...| 0.702996          | 0.296661            |
| 15        | 107704.889...| 0.701833          | 0.297012            |
| 16        | 114031.010...| 0.705267          | 0.290334            |
| 17        | 123009.369...| 0.708813          | 0.284710            |
| 18        | 129335.099...| 0.719563          | 0.299121            |
| 19        | 135653.758...| 0.736974          | 0.297012            |
| 20        | 141968.815...| 0.747312          | 0.295606            |
| 21        | 148326.322...| 0.749207          | 0.301582            |
| 22        | 154664.151...| 0.756131          | 0.301582            |
| 23        | 161016.337...| 0.762603          | 0.301933            |
| 24        | 167352.737...| 0.769583          | 0.302285            |
| 25        | 173658.381...| 0.774198          | 0.298067            |
+-----------+--------------+-------------------+---------------------+
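The hour figures quoted in the results comment can be checked against the Elapsed Time columns (in seconds) of the two logs. This small sketch converts the relevant rows to hours; the dictionaries only copy values from the tables above, and `hours` is a hypothetical helper.

```python
# (elapsed seconds, validation accuracy) copied from the two logs above.
turi_squeezenet = {10: (32559.100921, 0.139640)}
createml_vfp = {3: (31857.741821, 0.180316), 10: (76013.447991, 0.260105)}


def hours(seconds: float) -> float:
    """Convert the logs' Elapsed Time column (seconds) to hours."""
    return seconds / 3600.0


for label, (secs, val_acc) in [
    ("Turi Create + SqueezeNet, iteration 10", turi_squeezenet[10]),
    ("Create ML + VisionFeaturePrint_Scene, iteration 3", createml_vfp[3]),
    ("Create ML + VisionFeaturePrint_Scene, iteration 10", createml_vfp[10]),
]:
    print(f"{label}: {hours(secs):.1f} h, validation accuracy {val_acc:.3f}")

ratio = createml_vfp[10][0] / turi_squeezenet[10][0]
print(f"Time for 10 iterations, Create ML / Turi Create: {ratio:.2f}x")
```

This confirms the reading above: Turi Create with SqueezeNet finishes 10 iterations in about 9 hours versus about 21 hours for Create ML, but Create ML reaches a higher validation accuracy after only 3 iterations in roughly the same wall-clock time.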

srikris commented 6 years ago

@hipwelljo More data will help the accuracy, but you can also experiment with the base model and the max_iterations parameter (I would set it to 25 for all of them).

You can try three different base models in Turi Create, for example:

import turicreate as tc
# train_data: an SFrame with an image column and a 'label' column
model = tc.image_classifier.create(
    train_data, target='label', model='squeezenet_v1.1', max_iterations=25)

As mentioned before, SqueezeNet would be 2-4x faster, but not much more than that. Resnet may be more accurate than VisionFeaturePrint but is probably slower.
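Trying all three base models with the same iteration budget could look like the sketch below. It assumes `train_data` and `test_data` are SFrames with an image column and a `label` column, and that `model.evaluate` returns a dictionary with an `'accuracy'` entry. `compare_base_models` is a hypothetical helper, and `VisionFeaturePrint_Scene` may only be available on newer macOS releases (10.14+). The Turi Create import sits inside the function so the snippet can be loaded without starting any training.

```python
# Base model names passed to the `model` argument of tc.image_classifier.create.
BASE_MODELS = ["resnet-50", "squeezenet_v1.1", "VisionFeaturePrint_Scene"]


def compare_base_models(train_data, test_data, max_iterations=25):
    """Train one image classifier per base model and return test accuracy by name."""
    import turicreate as tc  # imported lazily; training requires Turi Create

    results = {}
    for base_model in BASE_MODELS:
        model = tc.image_classifier.create(
            train_data,
            target="label",
            model=base_model,
            max_iterations=max_iterations,
        )
        results[base_model] = model.evaluate(test_data)["accuracy"]
    return results
```

For example, split a labeled SFrame with `train_data, test_data = data.random_split(0.8)` and then call `compare_base_models(train_data, test_data)`. With thousands of classes, expect each run to take many hours.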

I'll file two issues to help make this better:

I'm closing the issue, but you can continue to post your results here.