kbelenky / open_sorts

Sorting machine for TCGs such as Magic: The Gathering
GNU General Public License v3.0

Report: Comparative Accuracy Analysis #5

Open HanClinto opened 1 year ago

HanClinto commented 1 year ago

Apologies in advance for opening this as an "Issue", but it seems as good a place as any for this discussion to happen. I absolutely love this project, and everything you're doing with it. I hope I can eventually contribute in some way!

I have compared the accuracy of Open Sorts' embedding generator against pHash on a dataset I compiled from card images taken with cell phones and uploaded for sale on an online marketplace. Each listing annotates its image with the expected card. I then used traditional homography alignment to locate the reference card image within the photo and generate a transformation matrix, which I used to create a normalized, head-on image of each card.
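A minimal numpy sketch of the core of that rectification step: estimating the 3x3 homography from four corner correspondences with the direct linear transform (DLT). This is an illustration, not the pipeline used for this analysis; a real pipeline would more likely match features against the reference image and call OpenCV's `cv2.findHomography` (with RANSAC) followed by `cv2.warpPerspective`. The function names, corner coordinates and the 488x680 template size are hypothetical.

```python
import numpy as np

def homography_from_points(src, dst):
    """Estimate H (3x3) such that dst ~ H @ src, from >= 4 point pairs via DLT."""
    src = np.asarray(src, dtype=float)
    dst = np.asarray(dst, dtype=float)
    rows = []
    for (x, y), (u, v) in zip(src, dst):
        # Each correspondence contributes two linear equations in the 9 entries of H.
        rows.append([-x, -y, -1, 0, 0, 0, u * x, u * y, u])
        rows.append([0, 0, 0, -x, -y, -1, v * x, v * y, v])
    # The solution is the right singular vector with the smallest singular value.
    _, _, vt = np.linalg.svd(np.array(rows))
    h = vt[-1].reshape(3, 3)
    return h / h[2, 2]

def apply_homography(h, points):
    """Map Nx2 points through H using homogeneous coordinates."""
    pts = np.hstack([np.asarray(points, dtype=float), np.ones((len(points), 1))])
    mapped = pts @ h.T
    return mapped[:, :2] / mapped[:, 2:3]

# Hypothetical corners of a card found in a phone photo, and a head-on template.
photo_corners = [[10, 20], [300, 15], [320, 400], [5, 410]]
template_corners = [[0, 0], [488, 0], [488, 680], [0, 680]]
H = homography_from_points(photo_corners, template_corners)
print(np.round(apply_homography(H, photo_corners), 3))
```

With exactly four non-degenerate correspondences the system has a one-dimensional null space, so H maps the photo corners onto the template corners; with more points, the same SVD gives a least-squares estimate.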

I created reference databases for each of those embeddings using reference images scraped from Scryfall, and stored them in an Annoy index for fast lookup. I could just as easily have done a brute-force lookup, but I wanted to experiment with ANN libraries to see if they helped at all. :)

I then fed the images from my test dataset (36,271 images) into the two embedding algorithms: open_sorts, and pHash (using the ImageHash library, with a hash_size of 16 and a highfreq_factor of 1), and calculated nearest neighbors. I look at the top-1 and top-5 nearest neighbors and calculate accuracy. I also wanted to experiment with distance metrics, so in the interest of science I'm including the results obtained with several distance metrics as a reference.
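For reference, the distance metrics in the results table below (plus Hamming distance for pHash) can be computed with a simple brute-force search. This is a minimal numpy sketch, assuming float embeddings stored as matrix rows and pHash bits stored as boolean arrays; it is not the code that produced these results, and Annoy's own scaling of some metrics may differ (for example, the angular distance here is the Euclidean distance between normalized vectors).

```python
import numpy as np

def distances(query, refs, metric):
    """Distance from one query vector to every row of refs (smaller = closer)."""
    if metric == "euclidean":
        return np.linalg.norm(refs - query, axis=1)
    if metric == "manhattan":
        return np.abs(refs - query).sum(axis=1)
    if metric == "angular":
        # Euclidean distance between L2-normalized vectors, i.e. sqrt(2 * (1 - cos)).
        q = query / np.linalg.norm(query)
        r = refs / np.linalg.norm(refs, axis=1, keepdims=True)
        return np.linalg.norm(r - q, axis=1)
    if metric == "dot":
        # A larger dot product means more similar, so negate it to act as a distance.
        return -(refs @ query)
    if metric == "hamming":
        # Number of differing bits, e.g. for a flattened 16x16 pHash (256 bits).
        return np.count_nonzero(refs != query, axis=1)
    raise ValueError(f"unknown metric: {metric}")

def nearest(query, refs, metric, k=5):
    """Indices of the k closest reference rows, closest first."""
    d = distances(query, refs, metric)
    return np.argsort(d, kind="stable")[:k]
```

Brute force is exact, so it also serves as a baseline for judging how much an approximate index such as Annoy gives up.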

Results (higher numbers are better):

| Generator | Distance Metric | Top-1 Acc (Card Name) | Top-1 Acc (Card Set / Printing) | Top-5 Acc (Card Name) | Top-5 Acc (Card Set / Printing) |
| --- | --- | --- | --- | --- | --- |
| phash | hamming | 0.9964162646450724 | 0.7790489317711923 | 0.997987594762233 | 0.9880634045485872 |
| open_sorts | dot | 0.9646037215713301 | 0.7682150241212956 | 0.9772570640937285 | 0.9564162646450723 |
| open_sorts | angular | 0.9638594073053067 | 0.7673880082701585 | 0.9757684355616816 | 0.9550379048931771 |
| open_sorts | euclidean | 0.9637767057201929 | 0.7674707098552722 | 0.9755754651964162 | 0.9545692625775327 |
| open_sorts | manhattan | 0.9629772570640938 | 0.7657064093728463 | 0.9758235699517575 | 0.955093039283253 |

Card name accuracy measures whether we determine that this card is indeed a particular name / artwork ("This is a Sol Ring!"), and card set / printing accuracy measures whether we match it to the correct edition of the card ("This is a Sol Ring from Commander 2018!").
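The two accuracy flavors differ only in which field of a retrieved card has to match the ground truth. A minimal sketch, assuming each card record is a dict with hypothetical "name" and "printing" keys (the printing could be a set code plus collector number) and that each query comes with a ranked list of retrieved cards:

```python
def top_k_accuracy(retrieved, truth, key, k):
    """Fraction of queries whose true `key` appears among the first k retrievals.

    retrieved: one ranked list of card dicts per query (closest first).
    truth: the correct card dict for each query.
    key: "name" for card-name accuracy, "printing" for set / printing accuracy.
    """
    hits = sum(
        any(card[key] == true_card[key] for card in ranked[:k])
        for ranked, true_card in zip(retrieved, truth)
    )
    return hits / len(truth)
```

Because a correct printing implies a correct name, printing accuracy can never exceed name accuracy for the same k, which matches the pattern in the table.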

I'll admit I was hoping open_sorts would perform better here, but its accuracy is still quite impressive!

I need to run more tests on this. In particular, I'm curious what happens when the card alignment is less than ideal: I'd like to measure accuracy once we use the corner detector and get the "fuzz" it introduces by jostling the card image around. I strongly suspect that with less-than-ideally-aligned images, open_sorts will outperform phash, but I still need to set up that experiment.

Anyways, I had a mixture of emotions when looking at these results, but figured they were perhaps worth sharing regardless.

Thank you again for the open_sorts project, and for letting us follow along with you here!

kbelenky commented 1 year ago

I think it's great that you posted this as an issue. It lets us have a durable public record of your work, and invites public comment.

Your analysis looks excellent.

Top-1 Acc (Card Name) was my objective when I built the model. Because of the model design, training requirements and the data augmentations I do on the Scryfall images, many of the training images aren't even recognizable to a human as belonging to their set. By the time a card image has been geometrically distorted, exposed to fake lighting conditions and artificial camera noise, and shrunk to the relatively small input image size, the set symbol is just an indistinct blob.

It's only a happy accident that Top-1 Acc (Card Set / Printing) sometimes works at all, and often only for single-printing cards.

Anecdotally, I am a little surprised that Card Name accuracy is as low as 96%. My experience under a narrow set of conditions (a single camera, and the lighting in my basement office) is that it's closer to 100% for all "reasonable" images. That could also just be the bias of a proud creator, however :)

A couple of thoughts on where inaccuracy might be creeping in:

  1. I wonder how "approximate" Annoy is.
  2. I wonder about data labeling error, particularly with respect to misspellings, or foreign language.
  3. Your dataset may contain more damaged cards than mine. I haven't tested mine on any cards with significant damage.
  4. My tests have a very homogeneous background behind the card. It's possible my model is not robust to little bits of cluttered background around the edges.

Re: card alignment: that was actually one of the two primary goals of my training dataset design (the other being variable lighting). Since I built the card recognizer before I built a reliable position detector, I had to make the model robust to large errors in card alignment. The model you're using was trained on images where each card corner could extend as far as 5% past the corresponding corner of the input image, or sit as much as 10% inside it, independently selected for each of the 8 x/y coordinates of the card corners.
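A minimal sketch of sampling card-corner positions within those ranges. The uniform distribution and the function name are assumptions (the comment above only gives the bounds); the jittered corners could then serve as destination points for a perspective warp, for example with OpenCV's `cv2.getPerspectiveTransform` and `cv2.warpPerspective`, to render a training image.

```python
import numpy as np

def jitter_corners(width, height, outside=0.05, inside=0.10, rng=None):
    """Return 4 (x, y) card-corner positions in input-image pixel coordinates.

    Each of the 8 coordinates is perturbed independently: a corner may land up to
    `outside` (as a fraction of the image size) beyond the image edge, or up to
    `inside` within it. Uniform sampling is an assumption made for this sketch.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Nominal corners: top-left, top-right, bottom-right, bottom-left.
    nominal = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=float)
    # Sign that points each coordinate toward the inside of the image.
    inward = np.array([[1, 1], [-1, 1], [-1, -1], [1, -1]], dtype=float)
    size = np.array([width, height], dtype=float)
    # Negative fractions push a coordinate outside the image, positive ones inside.
    frac = rng.uniform(-outside, inside, size=(4, 2))
    return nominal + inward * frac * size

print(jitter_corners(224, 224, rng=np.random.default_rng(0)))
```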

I've got a roadmap in my mind for improvements that I can make to the model. I haven't done them yet because Google made some significant breaking changes to Colab's TPU support, and my current home computer (an i7 2600K with 16 GB of RAM) is just not up to the task. However, I'm getting a new personal machine in a month or two that should unlock home training.

  1. Larger input image size, making the set symbols actually visible.
  2. Revised model design to prioritize set recognition.
  3. A revised training process that emphasizes challenging edge cases.
  4. Slightly less aggressive data augmentation.

But... before I do the improved model training, I'm working on an improved Lego mechanism that should be simpler to build and more reliable at feeding single cards. Right now, for the sorting machine, the unreliability of the recognizer is absolutely dwarfed by the unreliability of the card feeder.

HanClinto commented 1 year ago

Anecdotally, I am a little surprised that Card Name accuracy is as low as 96%. My experience under a narrow set of conditions (a single camera, and the lighting in my basement office) is that it's closer to 100% for all "reasonable" images.

I don't think that's simply creator bias: some of these images are distorted to the point of being unreasonable. Many of the failed images are foils with bad lighting, and stretch the definition of "reasonable image".

I wonder how "approximate" Annoy is.

It's a fair question! When using Annoy, there are two primary parameters to tune: the number of trees (n_trees) when building the index, and search_k at lookup time. Increasing n_trees increases the time required to build the index, the disk space for the index, and the RAM needed when doing lookups in it. Increasing search_k only increases the time required for each lookup at runtime. I did some parameter searching when using pHash, and for my case I settled on n_trees = 100 and search_k = 8000. YMMV, but much below these values I see a drop-off in accuracy, and above them I see no improvement, while they still give me fast lookup times. The lookup times on my machine probably won't mean much on their own, and I don't have a brute-force implementation coded up right now (though it wouldn't take long). When/if I get around to it, I'll try to give you some performance numbers so you can see if it's worth it. My guess is that brute-force search is plenty fast.
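One way to answer "how approximate is Annoy?" is to measure recall against an exact brute-force search on a sample of queries. A minimal numpy sketch, assuming Euclidean embeddings; `approx_ids` would come from something like `AnnoyIndex.get_nns_by_vector(v, k, search_k=8000)` on an index built with `build(100)`, and the function names are hypothetical. The all-pairs distance matrix is queries x references in size, so run it on a sample rather than all 36k queries at once.

```python
import numpy as np

def exact_top_k(queries, refs, k):
    """Brute-force top-k by Euclidean distance; each row holds reference indices."""
    # Squared distances via ||q||^2 - 2 q.r + ||r||^2, for all query/reference pairs.
    d2 = (
        (queries ** 2).sum(axis=1)[:, None]
        - 2.0 * queries @ refs.T
        + (refs ** 2).sum(axis=1)[None, :]
    )
    return np.argsort(d2, axis=1, kind="stable")[:, :k]

def recall_at_k(approx_ids, exact_ids):
    """Average fraction of the exact top-k neighbors that the ANN search also returned."""
    per_query = [len(set(a) & set(e)) / len(e) for a, e in zip(approx_ids, exact_ids)]
    return float(np.mean(per_query))
```

A recall close to 1.0 at k=5 would suggest the accuracy gap isn't coming from the approximate index.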

I haven't confirmed that these n_trees and search_k values are optimal for open_sorts, but I should probably test that next.

Regardless, I'll upload a few examples of the misses here so you can see the sorts of things that are tripping it up.

Forest (PGPX #2018f)

Query image: rectified_E218F3A8-829C-4EE7-98ED-D8CE6C26D809

Correct: (0.6880239642568995) [image 714a0a11-837a-4c44-9d37-d418b60e8f12]
Retrieval #1: Fervent Champion (PELD #124p) (0.7123503088951111) [image 24dc2c6b-1152-4b20-a5bf-d89cf1be8036]
Retrieval #2: Fervent Champion (ELD #124) (0.7054118514060974) [image c52d66db-5570-48a1-99cf-e0417517747b]
Retrieval #3: Forest (PGPX #2018f) (0.6880239844322205) [image 714a0a11-837a-4c44-9d37-d418b60e8f12]
Retrieval #4: Pilgrim's Eye (CLB #333) (0.6819046139717102) [image 32161267-e12b-454f-a7e1-94e078566ffa]
Retrieval #5: Forerunner of the Heralds (RIX #129) (0.6737490892410278) [image 30bc2bd2-adfc-490e-998a-303598e6a942]

Cleansing Nova (PM19 #9s)

Query image: rectified_E5B713AC-8495-4A76-863B-1A8685222937

Correct: (0.7155052257558054) [image d25eb278-8d4e-4e26-a4bc-eb210d3ec1eb]
Retrieval #1: Murasa Behemoth (MH1 #172) (0.7357884049415588) [image 480ddde1-81d3-4939-b232-cb1ced6cfc4d]
Retrieval #2: Cleansing Nova (C20 #83) (0.7310032844543457) [image 9b576ac4-a4b4-4d2c-820c-cf9fcc0c57c0]
Retrieval #3: Darksteel Forge (2XM #248) (0.7238464951515198) [image 421089c4-c8d3-48c5-b313-fb1741546271]
Retrieval #4: Cleansing Nova (KHC #20) (0.7180673480033875) [image 990a7ec3-0ace-4245-a374-690d11a827ae]
Retrieval #5: Cleansing Nova (PM19 #9s) (0.7155052423477173) [image d25eb278-8d4e-4e26-a4bc-eb210d3ec1eb]

Mentor of the Meek (PM19 #27s)

Query image: rectified_3A78A327-1CEC-4579-A281-073BC472A6AC

Correct: (0.44822438521259755) [image 35d6761b-6221-47e7-807d-666924f04dc8]
Retrieval #1: Neonate's Rush (MID #151) (0.6848520636558533) [image dee17e12-e08f-4449-9f49-05f20e0d1670]
Retrieval #2: Baldur's Gate (CLB #345) (0.6670352220535278) [image 2436aa14-9200-4295-8041-b682cf3c4216]
Retrieval #3: Baldur's Gate (PCLB #345s) (0.6558581590652466) [image bcbed9d0-622d-4307-8c2b-39a7c7dfaa42]
Retrieval #4: Loxodon Punisher (MRD #14) (0.6528931856155396) [image df2b90b8-8306-4543-b9f4-3cfd033f5ca5]
Retrieval #5: Canopy Tactician (KHM #378) (0.649572491645813) [image 3eaf48c9-09bc-4d81-a3a5-432219a71754]

Mistcaller (PM19 #62s)

Query image: rectified_D1B17236-63EC-4A27-8C3F-3BCB3F8AF07F

Correct: (0.5829754005046084) [image 0972287a-8c68-4661-8e59-cb0d5c06670d]
Retrieval #1: Rishkar, Peema Renegade (NEC #126) (0.6660542488098145) [image c0bbcf3b-5b7c-4846-b30c-100542ce2204]
Retrieval #2: Luminarch Aspirant (PZNR #24s) (0.6462628841400146) [image 42bdb2bd-fb25-4715-9754-42ab0c5c8cf7]
Retrieval #3: Dismantle (DST #57) (0.646224856376648) [image 9d9915ff-a9b9-429c-bdc1-6a52d9f0e6d4]
Retrieval #4: Font of Vigor (JOU #11) (0.6308972835540771) [image d4ef3a8e-ef8b-417a-a3a9-cb0ce88cb0c9]
Retrieval #5: Satyr Wayfinder (ZNC #81) (0.628696620464325) [image 84cb8203-88e4-4a3b-9334-91b70c747091]

Bone Dragon (PM19 #88s)

Query image: rectified_C93E6138-C0F1-46E4-AFD4-BF7F1CA5CD44

Correct: (0.17490760869623045) [image 004b44af-9b27-4689-a6b6-bcd3ad0aca7e]
Retrieval #1: Elemental // Elemental (TUST #17) (0.689578652381897) [image 70f1f745-9fc2-41a6-9d39-fb4964595cf5]
Retrieval #2: Time Warp (M10 #75) (0.6529827117919922) [image f21c7628-0571-4c09-8763-8e434e5e87d2]
Retrieval #3: Harness the Storm (SOI #163) (0.6417365074157715) [image 5294d359-c599-40ed-9e06-2a3cc8624d6a]
Retrieval #4: Dragon Bell Monk (MB1 #84) (0.634276270866394) [image 5a4db84d-8aa2-4cb0-a30f-5f1abd2136b5]
Retrieval #5: Dragon Bell Monk (IMA #17) (0.6339633464813232) [image d1244c95-a066-412d-bd04-c906ea4b4dd0]

Death Baron (PM19 #90s)

Query image: rectified_96129B3C-2B75-4A37-811E-E9FB3FADD88A

Correct: (0.37648487193644087) [image bece1a3b-1ab5-4b34-9d4a-65fcef40f880]
Retrieval #1: Wetland Sambar (KTK #62) (0.7395081520080566) [image f71a86e0-d15a-4fba-94f6-bfbaade8d837]
Retrieval #2: Robo- (UST #157) (0.6994038224220276) [image b4d54442-caca-412d-a716-032eaa587944]
Retrieval #3: Deadeye Plunderers (XLN #220) (0.6993348002433777) [image 63a7a1a4-aec2-467d-91a1-1a2605718c7c]
Retrieval #4: Sandsteppe Citadel (CP3 #6) (0.6922216415405273) [image 7ce265fa-1b23-4d03-8675-6c9447d31115]
Retrieval #5: Balustrade Spy (IMA #80) (0.6849275827407837) [image d295ef8c-fe8f-49f2-8588-7f5782315fc7]

HanClinto commented 1 year ago

I hope that's not too much image spam. As you can tell from the six sample images I uploaded, they're not exactly best-case scenarios: some bad lighting, some warped cards that look like taco shells, some very dark cards, some foreign languages, some pre-release stamps, etc. It's a smorgasbord of substandard image delights. :D

So I would guess that you're correct: in most practical cases, the accuracy is going to be much higher than 96%. I suspect my dataset just has some real potatoes of card images, because it encompasses card images taken with cell phones in all manner of circumstances.

I like your roadmap of improvements to the model design. I don't have a ton of GPU at my disposal, but I have a bit (a 3090 Ti), and I'm happy to lend it if it would be useful for this purpose.

As far as better ways to emphasize challenging edge cases, I've been building a triplet dataset with this sort of thing in mind, focusing a lot on reprinted cards that have only subtle differences. It's almost a curriculum of cards I would use when training a new store employee (I don't work at an LGS, but if I did, this is the approach I might take).

kbelenky commented 1 year ago

These examples are really fascinating.

One thing I notice from the scores is that they're all below 0.8 for both the target and the best-match. The model was trained with an alpha of 0.2 for the triplet loss, which can be (more or less) interpreted as "any match score below 0.8 is not a match". If we weren't playing a forced-choice matching game, then the model would be saying "I don't know" for these examples.
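Taking that rough interpretation at face value, a lookup can be allowed to say "I don't know" by rejecting any best match whose score falls below 1 - alpha. A minimal sketch; the function name is hypothetical, and it assumes the scores are the similarity values shown in these examples (higher is better).

```python
def best_match_or_none(scores, labels, alpha=0.2):
    """Return the label with the highest score, or None if that score is below 1 - alpha."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    # Below the threshold, report "not a match" instead of forcing a choice.
    return labels[best] if scores[best] >= 1.0 - alpha else None

# Forest example above: the best retrieval scored about 0.71, so it is rejected.
print(best_match_or_none([0.7124, 0.7054, 0.6880],
                         ["Fervent Champion", "Fervent Champion", "Forest"]))  # None
```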

You could go even deeper by incorporating a more dedicated quality signal. I already have one such signal, but I haven't bothered to incorporate it into open_sorts because it presents a new question of, "Ok... you've detected a potato. What do you do with it?"

As for training future versions of the model, I've debated with myself what to do with cards from different languages. The current model was trained without any non-English cards at all.

I suspect model robustness could be improved if cards that are otherwise the same, but printed in different languages are treated as the same card, but it would then come at a cost of limiting the utility of the model. On the flip side, if I label different languages as different cards, I'll probably need to significantly increase the feature space of the model to accommodate the distinctions, which may have consequences for robustness, and will definitely have consequences for training.

And of course, that's all moot until I can set up a working training system. Triplet loss training likes really big batch sizes (at least 1200, and 4000 is really better). I think a consumer PC with 128 GB of RAM should be just big enough if I use a bfloat16 model. My suspicion is that to build models larger than the current MobileNetV2 one, I'm probably going to have to move to a hybrid offline/online triplet loss training system. I've got a plan for how to do that, but it'll take some work.

HanClinto commented 1 year ago

I like the idea of an incorporated quality signal, but I'm not terribly familiar with those. How do you train those, and how is it different from the regular distance metric? But if it works, then that seems like it would be handy. Especially if users happen to feed non-Magic cards into the system (random filler cards, Pokemon cards, etc) then it would make sense to have a robust rejection filter.

I suspect model robustness could be improved if cards that are otherwise the same, but printed in different languages are treated as the same card, but it would then come at a cost of limiting the utility of the model. On the flip side, if I label different languages as different cards, I'll probably need to significantly increase the feature space of the model to accommodate the distinctions, which may have consequences for robustness, and will definitely have consequences for training.

Alternatively, don't do it all in a single model: train the primary image recognition model to focus on artwork and set symbols by treating all languages as the same, and then train a separate model that only does language differentiation?

Re: the large batch sizes, one thing I'm wondering is whether we could use something like GradientTape to do gradient accumulation and get the same effect as large batch sizes, but with less GPU memory. I've got a small set of example code to do just this, and I'll probably try it out on my next training run (I hope to start that in the next week or so).

One thing I notice from the scores is that they're all below 0.8 for both the target and the best-match. The model was trained with an alpha of 0.2 for the triplet loss, which can be (more or less) interpreted as "any match score below 0.8 is not a match". If we weren't playing a forced-choice matching game, then the model would be saying "I don't know" for these examples.

That's fascinating. I'm not familiar with the concept of alpha in triplet loss training, so that's helpful to read about. Thank you!

I flipped through and found a few examples where the retrieval of the incorrect card is > 0.8 -- here are some examples in case you're interested in seeing these. I will note that the number of errors where the incorrect retrieval is > 0.8 is somewhere around 3% of errors -- the vast majority of the time, incorrect identifications are below the alpha threshold. I had to go through over 100 error cases before I found four to paste here.

Snow-Covered Plains (MH1 #250) - Listo ID 582C0AB4-35B2-48D4-A5B3-64E0531B86D4

Query image: rectified_582C0AB4-35B2-48D4-A5B3-64E0531B86D4

Correct: (0.627380720814017) [image 7a961768-6166-4852-b518-23eb4cced47d]
Retrieval #1: Snow-Covered Forest (MH1 #254) (0.8073512315750122) [image 1c59fc48-704b-4187-b9d3-2a2cff6dd54b]
Retrieval #2: Double-Faced Substitute Card (SSTX #7) (0.7443360090255737) [image a4a26a66-5948-4092-b5ef-1fe578415093]
Retrieval #3: Double-Faced Substitute Card (SSTX #8) (0.7186621427536011) [image c2dfee13-6ee0-40f9-9869-8ce1609f839a]
Retrieval #4: Swamp (ZNR #272) (0.7022615671157837) [image 95a58ce4-e07f-4c9c-98ae-3173d6d63cc5]
Retrieval #5: Double-Faced Substitute Card (SKHM #6) (0.6948244571685791) [image 63608547-3bb6-4ada-b6c2-90015670a2f8]

Field of the Dead (PM20 #247s) - Listo ID 1EBFD719-0DCE-4C72-847E-602CEE06066F

Query image: rectified_1EBFD719-0DCE-4C72-847E-602CEE06066F

Correct: (0.7729615295245367) [image 54cec453-21f1-48a1-a395-9ec172b8f13a]
Retrieval #1: Grim Strider (AKH #94) (0.8310505151748657) [image 97ae6769-2b2b-48e3-9503-e9744984743a]
Retrieval #2: Field of the Dead (PM20 #247s) (0.7729616165161133) [image 54cec453-21f1-48a1-a395-9ec172b8f13a]
Retrieval #3: Field of the Dead (M20 #247) (0.7638400793075562) [image 470ca3f4-29aa-4c4c-8ff2-8cdd70c69943]
Retrieval #4: Field of the Dead (PM20 #247p) (0.762450098991394) [image 68619895-e0bc-477e-b419-b188b6515768]
Retrieval #5: Hall of Storm Giants (PAFR #257a) (0.7541294097900391) [image 97a8d685-8fd7-49d8-90ca-f2c4de735d26]

Temple of Abandon (PTHB #244s) - Listo ID 8CF5CF9E-97F7-481B-80B0-E49D8B6C0608

Query image: rectified_8CF5CF9E-97F7-481B-80B0-E49D8B6C0608

Correct: (0.7607066032433067) [image 37ec420c-bae6-40d6-b0d8-30742c14753b]
Retrieval #1: Time Sieve (2XM #223) (0.8154834508895874) [image c2e8b424-0cec-490e-a571-bd051f952adf]
Retrieval #2: Temple of Abandon (THB #244) (0.7875157594680786) [image 7d5f8481-47f7-4531-9dad-686cdfb5d2ad]
Retrieval #3: Temple of Abandon (PTHB #244p) (0.7772296667098999) [image e240d766-06a0-4a6f-ac59-7b2d33d524fe]
Retrieval #4: Temple of Abandon (PTHB #244s) (0.7607066035270691) [image 37ec420c-bae6-40d6-b0d8-30742c14753b]
Retrieval #5: Myrkul's Edict (CLB #135) (0.7440133094787598) [image adc7c427-aa31-4077-8397-74a9a3802ee7]

Historian of Zhalfir (M21 #325) - Listo ID C13E4B13-91AD-4E3F-A865-4A93E4589B15

Query image: rectified_C13E4B13-91AD-4E3F-A865-4A93E4589B15

Correct: (0.7245299202177193) [image ae981da0-f32c-49d5-bcb0-2b9255a4e1fe]
Retrieval #1: Venomous Changeling (MH1 #114) (0.8220256567001343) [image 4c5a1d73-d102-469b-82ca-ec18f616375e]
Retrieval #2: Historian of Zhalfir (M21 #325) (0.7245299816131592) [image ae981da0-f32c-49d5-bcb0-2b9255a4e1fe]
Retrieval #3: Suntail Hawk (10E #50★) (0.6614135503768921) [image 18d4bbec-c20b-437f-88e2-d609fd7c6003]
Retrieval #4: Suntail Hawk (10E #50) (0.6609879732131958) [image b4886566-af41-4d14-8ae1-ce2952db8e42]
Retrieval #5: Snakeskin Veil (STA #120) (0.6458617448806763) [image e909bd04-5394-4888-aa4a-4f90855663bb]

kbelenky commented 1 year ago

Unfortunately, the quality signal that I'm most familiar with might be a trade secret (its existence isn't a trade secret, but its implementation and training regimen may be; I have to do some asking around before I can share it).

I don't think using gradient accumulation will work for increasing triplet loss batch sizes. TensorFlow Addons' implementation of TripletSemiHardLoss uses the large batch for online triplet mining: it can only find triplets that are entirely within the same batch. It would be more effort than I'm willing to invest to develop my own cross-batch online triplet mining (although it could be done). Instead, I'm going to opt for a multistep approach where I use a simple model.predict on a number of small batches to create a large pool of potential triplets, and then pack the best triplets into GPU-sized batches. That will increase the compute load, but reduce the memory requirements.
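A rough numpy sketch of the mining half of that plan, assuming the embeddings for a large pool of images have already been produced by `model.predict` over several small batches. Both function names are hypothetical; the margin, the semi-hard rule (a negative farther than the positive but still within the margin, as in FaceNet) and the fallback to the hardest negative are illustrative choices, not the internals of TripletSemiHardLoss. The pairwise distance matrix grows quadratically with the pool size.

```python
import numpy as np

def mine_semihard_triplets(embeddings, labels, margin=0.2):
    """Return (anchor, positive, negative) index triplets from a pool of embeddings."""
    labels = np.asarray(labels)
    idx = np.arange(len(labels))
    # Pairwise Euclidean distances via ||x||^2 - 2 x.y + ||y||^2, clipped at 0.
    sq = (embeddings ** 2).sum(axis=1)
    dist = np.sqrt(np.maximum(sq[:, None] - 2.0 * embeddings @ embeddings.T + sq[None, :], 0.0))
    triplets = []
    for a in idx:
        positives = idx[(labels == labels[a]) & (idx != a)]
        negatives = idx[labels != labels[a]]
        if len(negatives) == 0:
            continue
        for p in positives:
            d_ap = dist[a, p]
            d_an = dist[a, negatives]
            # Semi-hard negatives: farther than the positive, but still inside the margin.
            semihard = negatives[(d_an > d_ap) & (d_an < d_ap + margin)]
            if len(semihard) > 0:
                n = semihard[np.argmin(dist[a, semihard])]
            else:
                # None qualify: fall back to the closest (hardest) negative.
                n = negatives[np.argmin(d_an)]
            triplets.append((int(a), int(p), int(n)))
    return triplets

def pack_batches(triplets, triplets_per_batch):
    """Split mined triplets into GPU-sized batches (each triplet holds 3 images)."""
    return [triplets[i:i + triplets_per_batch]
            for i in range(0, len(triplets), triplets_per_batch)]
```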

HanClinto commented 1 year ago

Unfortunately, the quality signal that I'm most familiar with might be a trade secret (its existence isn't a trade secret, but its implementation and training regimen may be; I have to do some asking around before I can share it).

👍 Very understandable! I'm dealing with similar things on my end, but sharing as much as I can.

I don't think using gradient accumulation will work for increasing triplet loss batch sizes. TensorFlow Addons' implementation of TripletSemiHardLoss uses the large batch for online triplet mining: it can only find triplets that are entirely within the same batch. It would be more effort than I'm willing to invest to develop my own cross-batch online triplet mining (although it could be done). Instead, I'm going to opt for a multistep approach where I use a simple model.predict on a number of small batches to create a large pool of potential triplets, and then pack the best triplets into GPU-sized batches. That will increase the compute load, but reduce the memory requirements.

That's really clever! I like that.

At that point, it almost seems like it wouldn't be significantly more work to just write one's own triplet mining code. I think this is the approach I'm going to try this week (my GPU doesn't allow for very large batch sizes, so I'm trying to find ways to make GradientTape work). I'll let you know what my results turn up. In my first attempt at triplet training (months ago), I built my own triplets, but I only used card metadata to choose which cards to select as positive / negative samples, and that's probably not as good as scanning the dataset and doing actual loss comparisons to find truly hard / semi-hard triplets.

Also, you mentioned before:

The model was trained with an alpha of 0.2 for the triplet loss...

Is the concept of alpha the same as "margin" as explained in this page? The concept of margin makes sense, but I haven't yet learned how one chooses an intelligent margin value.
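In the standard triplet-loss formulation (as introduced for FaceNet), alpha is exactly that margin. For an anchor a, a positive p, a negative n, an embedding function f and a distance d, the loss is the one below: it is zero once the negative is at least alpha farther from the anchor than the positive, and positive otherwise, so a larger alpha pushes different cards farther apart. This is the textbook form; the specific distance a given library implementation uses may differ.

```latex
\mathcal{L}(a, p, n) = \max\bigl(0,\; d\bigl(f(a), f(p)\bigr) - d\bigl(f(a), f(n)\bigr) + \alpha\bigr)
```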

kbelenky commented 1 year ago

I'm finally having a little bit of time to work on this project again. No results to show, but I've learned some things.

I don't think my idea for doing offline pre-mining of triplets is going to go very far. It might end up being useful at the end of the training process for refining the model to handle extremely tricky pairings, but it's not useful for the bulk of the training. The issue is that the small batch sizes end up giving it the same effective behavior as pure offline triplet mining. Pure offline triplet mining, while technically correct, is just too slow.

I think you were on to something with both of your ideas of "write one's own triplet mining" and gradient accumulation.

My next idea is to write a memory-efficient version of online triplet mining, by splitting a large macro-block into manageable chunks, and then using gradient accumulation to put the chunks back together.

The basic outline looks like this:

For each big batch:

  1. Make predictions with a large batch size, and no gradient tape
  2. Compute all of the losses using traditional online triplet mining
  3. Initialize the gradient accumulator.
  4. Cut the batch into smaller chunks, small enough to run with a gradient tape.
  5. For each small chunk, run a prediction with the gradient tape.
  6. (for each small chunk) Apply the losses that were computed in (2)
  7. (for each small chunk) Accumulate the gradient from the losses into the gradient accumulator.
  8. After all the chunks have been processed, apply the gradient to the optimizer.

Like my last idea, it trades two passes (double the compute) for lower RAM requirements.

At this rate, I might have something to show for it by December.

kbelenky commented 1 year ago

Actually, my proposal above is a terrible idea that won't work.

The problem with it is that the losses are computed outside the gradient tape. The connection between the loss function and the gradients is severed, so the optimizer can't do its work.

Back to the drawing board.

HanClinto commented 1 year ago

I was able to do a bit of training, but only on a subset of the problem (distinguishing one printing of a card from another). I've gotten modest results: I finally have a converging training curve, and it seems usable once I actually write inference code around it. However, it's still struggling to outperform pHash. I think I need to either expand my dataset or find better image augmentations. I'm guessing that's my current limiting factor, because tweaking hyperparameters or simply training longer only results in overfitting, or in reaching the same results I already have, just more slowly.

So, say I've got a dataset of 10,000 images or so. I break it up into groups, where the cards in each group are easily confused with one another. Some groups have 20 images, some have 200, so I end up with a jagged 2-dimensional array of image groups. Each group includes enough classes that one could build meaningful triplets from it. However, I can't load 200 images into a batch at a time: if I set my batch size much above 32, training starts to fail, so the batch size is set to a relatively small 32.

I then wrote my own data generator that loops through the array of arrays, and for each jagged array, outputs batches that are samplings from each group, and only from within that group.

import numpy as np

def data_generator(grouped_data, batch_size, preprocess, do_augment=True, loop_generator=True):
    """Yield (x, y) batches, each drawn from a single group of easily-confused cards.

    grouped_data: list of groups; each group is a list of (image, label) pairs.
    preprocess: user-supplied callable(images, do_augment=...) returning model-ready images.
    """
    while True:
        for group in grouped_data:
            # Go through each group in `batch_size` chunks.
            for start_inx in range(0, len(group), batch_size):
                # Build a batch exactly batch_size long, wrapping around small groups.
                batch = [group[inx % len(group)]
                         for inx in range(start_inx, start_inx + batch_size)]
                y = np.array([label for _, label in batch])
                # Can insert checks here to ensure a good batch, such as enough examples of
                # each label. With fewer than 2 unique labels there is no valid triplet, so
                # `continue`; you'll pick this group up again on the next pass.
                if len(np.unique(y)) < 2:
                    continue
                x = preprocess([image for image, _ in batch], do_augment=do_augment)
                yield x, y
        if not loop_generator:
            # Loop the training generator forever, but let the test generator end.
            break

This is rough pseudocode, but hopefully you get the idea. At the end of it, it outputs a set of small GPU-sized batches that probably have interesting triplets in them (not necessarily guaranteed to be "hard" triplets, but likely to at least be "interesting").

I then feed these batches into my model that uses tensorflow_addons.losses.TripletSemiHardLoss like normal, except set to a more reasonable batch size.

If I could feed the entire group into my model, then I would guarantee that the triplet miner could find the best triplets in each step, but that's one of the limitations of this approach.

In my model's train_step, I compute gradients with a GradientTape as normal, but I accumulate them instead of applying them right away. To simulate a batch_size of 3200, I only apply the accumulated gradients every 100 training steps (100 steps of 32 images each). Side note: I couldn't use a normal Python variable to track the current step; I had to learn how to use tf.Variable and tf.cond for that, because train_step is compiled into a TensorFlow graph that runs on the GPU.
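The accumulation idea can be illustrated without TensorFlow. This numpy sketch uses a toy least-squares model (an assumption for illustration, not the train_step described above) to show that summing per-micro-batch gradients and applying them once gives the same update as one large batch. That equivalence relies on the loss being a sum over independent examples; the triplets that TripletSemiHardLoss mines are still limited to each 32-image micro-batch.

```python
import numpy as np

def mse_grad(w, x, y):
    """Gradient of the summed squared error sum((x @ w - y) ** 2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y)

def accumulated_step(w, x, y, micro_batch, lr):
    """One update whose gradient is accumulated over micro-batches before being applied."""
    accum = np.zeros_like(w)
    for start in range(0, len(x), micro_batch):
        stop = start + micro_batch
        # Each micro-batch is small enough to fit in (GPU) memory on its own.
        accum += mse_grad(w, x[start:stop], y[start:stop])
    # Apply once, averaged over all examples; this scale factor acts like part of the learning rate.
    return w - lr * accum / len(x)

rng = np.random.default_rng(0)
x, y, w = rng.normal(size=(3200, 4)), rng.normal(size=3200), np.zeros(4)
# 100 micro-batches of 32 give the same update as one batch of 3200.
full_batch = w - 0.1 * mse_grad(w, x, y) / len(x)
print(np.allclose(accumulated_step(w, x, y, micro_batch=32, lr=0.1), full_batch))  # True
```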

Anyways, the net effect is that we've negated much of the benefit of stochastic gradient descent and moved closer to plain old-fashioned gradient descent, but we gain the noise tolerance of larger batch sizes. It slows down training, but we make steadier progress. Tuning how much one scales the accumulated gradients is another way to tweak the learning rate.

The net result was that I was able to get an actually converging training curve for my triplet loss on a consumer-sized GPU. Using a stock ImageNet-trained MobileNetv2 as the base, I was able to get reasonable results in about 3-4 days of training on a 3090 Ti. And it actually sorta works! Kinda! :) I don't have results worth sharing yet, because they're still not significantly better than pHash, but at least it's the first time I've trained something that's in the same ballpark, and I didn't have to use an A100 to do it!

Like I said earlier, I'm pretty sure I either need to work on my augmentation code or expand my dataset, and I'm currently focusing on the latter. No idea how long that will take, but once the new dataset is built out, it should take another week for training, and then... maybe... hopefully... I might have results worth sharing about. :)

This may or may not be useful to you -- but figured you'd appreciate hearing my approach. I actually started writing my own triplet loss mining and was feeling a bit overwhelmed, and settled on this halfway-compromising approach that still lets me use the convenience of TripletSemiHardLoss and effectively simulate the larger batch sizes.

And to think... all of this could have been avoided if I'd simply rented a GPU. :sigh:

Is it weird that the biggest thing holding me back from renting a GPU is not the money -- it's the ridiculous size of these data sets? Once I get datasets into the realm of hundreds-of-gigabytes, I just don't want to be dealing with ephemeral Colab instances, and certainly don't want to be re-uploading to it every time it accidentally times out and I lose my progress.

kbelenky commented 1 year ago

I'm in the same boat: all of this is a little silly, because in my day job I have access to full TPU pods that could brute-force the training in under a minute, but I just don't want to use work resources for personal projects.

I haven't figured out why, but Google's TPUs seem to be doing some kind of sidestep around the memory requirements of GradientTape. If you look at the specs of TPUs, they don't actually have that much RAM, but they are able to handle far larger batch sizes than an equivalently sized GPU.

The main issue I've found with using Google's public Colab and TPU training is that it's not well integrated into the rest of the cloud infrastructure. The TPUs can read from GCS buckets, but there are some serious caveats:

  1. They don't properly support permission delegation, so you have to set up your GCS bucket with read-permission granted to a single, global user account that's the same for everyone in the world running on a Colab TPU. So anyone running a TPU on colab could read your data (not that I really care, since for me it's just data I scraped from Scryfall)

  2. You can't control where the TPU instance is running. If you're not careful, you can end up with a lot of cross-region traffic (some of the TPUs get allocated in Singapore, and I'm hosting in US-east). There is a way to check the IP address of your TPU, do a geographic lookup, and then re-roll the dice by restarting your kernel, but it's tedious and stupid.

Both of those problems could be alleviated if I paid to run my own Colab kernels as compute instances, but part of the fun is trying to do this on the cheap.

As for expanded data, if you give me a few days, I can probably get you a couple hundred thousand images augmented with my pipeline. What resolution do you need them at?

HanClinto commented 1 year ago

That's very insightful about the Colab limitations, thank you!

It feels like it wouldn't be that hard for the Colab developers to make things a bit simpler -- I already pay for an upgraded Google Storage account -- I would LOVE to be able to just pickle my entire Colab instance to my Drive so that I can easily resume it if they time me out for idleness, with my results automagically saved from my last training session. I know I can theoretically write stuff to do that (and I'm sure people have made nice templates for exactly this), but it's all just... hassle. And like you said, I want to learn how to do this with my own hardware as well. :)
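Short of pickling the whole instance, the usual workaround is to persist training state after every epoch and reload it on startup. Below is a minimal, hypothetical sketch of that resume pattern using only the standard library; `save_checkpoint`, `load_checkpoint`, `train`, and the state layout are illustrative names, not part of any library. On Colab, `ckpt_dir` could point at a mounted Drive folder (after `from google.colab import drive; drive.mount("/content/drive")`), e.g. an assumed path like `/content/drive/MyDrive/open_sorts_ckpt`. The temp-file-then-rename trick is atomic on a local disk; on a mounted Drive that guarantee may not hold, so treat it as best effort.

```python
import os
import pickle
import tempfile


def save_checkpoint(state, ckpt_dir, name="state.pkl"):
    """Write `state` to ckpt_dir/name via a temp file, so a half-written file never replaces a good one."""
    os.makedirs(ckpt_dir, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=ckpt_dir, suffix=".tmp")
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
    final_path = os.path.join(ckpt_dir, name)
    os.replace(tmp_path, final_path)  # atomic rename on a local filesystem
    return final_path


def load_checkpoint(ckpt_dir, default, name="state.pkl"):
    """Return the most recently saved state, or `default` if nothing was saved yet."""
    path = os.path.join(ckpt_dir, name)
    if not os.path.exists(path):
        return default
    with open(path, "rb") as f:
        return pickle.load(f)


def train(ckpt_dir, total_epochs):
    """Hypothetical training loop that resumes wherever the last session stopped."""
    state = load_checkpoint(ckpt_dir, default={"epoch": 0, "history": []})
    while state["epoch"] < total_epochs:
        loss = 1.0 / (state["epoch"] + 1)  # stand-in for one real epoch of training
        state["history"].append(loss)
        state["epoch"] += 1
        save_checkpoint(state, ckpt_dir)  # persist after every epoch
    return state
```

Calling `train(ckpt_dir, 5)` after a session died at epoch 3 picks up at epoch 3 instead of starting over. For the model weights themselves, Keras's own checkpointing (or its `BackupAndRestore` callback) is a better fit than pickle.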

As for expanded data, if you give me a few days, I can probably get you a couple hundred thousand images augmented with my pipeline. What resolution do you need them at?

Oooh, that would be sweet! I'm guessing you're not at liberty to share your augmentation pipeline directly, but the output of that would be more than enough!

Currently I'm training with images that are 299x299, so that's all I would need for now. That said, if I'm trying to read set symbols (and possibly copyright dates, and the other small indicators that we use to distinguish one printing from another), I'm starting to wonder whether I'll need to increase this resolution. Regardless, this is what the pre-trained MobileNetV2 uses, and I'm not entirely clear on how changing the input resolution of pre-trained models works for transfer learning -- I don't know everything that goes on inside of Keras when I do that stuff.
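On the resolution question: a convolution kernel is slid over every spatial position, so the same learned weights apply to an input of any size; only the spatial size of the feature maps changes. A global pooling layer then collapses each feature map to one number, so the embedding length stays fixed. The numpy sketch below (toy random "weights", not MobileNetV2) illustrates that. For reference, Keras publishes ImageNet weights for MobileNetV2 at 96, 128, 160, 192 and 224 pixels, and for other sizes it loads the 224 weights with a warning; with `include_top=False` and `pooling="avg"` the convolutional base still runs at 299x299. Learned filters are tuned to features at a particular pixel scale, though, so larger inputs likely benefit from some fine-tuning.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_valid(image, kernel):
    """Single-channel 'valid' 2-D cross-correlation, the operation a Conv2D layer applies."""
    windows = sliding_window_view(image, kernel.shape)  # (H-kh+1, W-kw+1, kh, kw)
    return np.einsum("ijkl,kl->ij", windows, kernel)


def embed(image, kernels):
    """Conv -> ReLU -> global average pooling: one number per kernel, whatever the image size."""
    return np.array([np.maximum(conv2d_valid(image, k), 0.0).mean() for k in kernels])


rng = np.random.default_rng(0)
kernels = rng.standard_normal((4, 3, 3))  # stand-in for pre-trained weights (fixed shape)
vec_224 = embed(rng.random((224, 224)), kernels)
vec_299 = embed(rng.random((299, 299)), kernels)
print(vec_224.shape, vec_299.shape)  # (4,) (4,)
```

The Keras equivalent would be something like `tf.keras.applications.MobileNetV2(input_shape=(299, 299, 3), include_top=False, weights="imagenet", pooling="avg")`.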

So my current dataset format has the columns: filename, set, illustration_id

I'm grouping the cards together by illustration_id (so each batch contains cards with the same artwork), and then my y-value is the set code. So the dataset is saying: "This Japanese M20 and this German M20 are the same card, but this Japanese BFZ reprint with the same artwork is different."

This has the benefit that I don't need to train on every card either -- only cards in Scryfall whose illustration ID is shared by more than one printing -- so cards whose artwork has never been reprinted can be excluded from the dataset. If this isn't convenient for you to do, then no problem -- my group-creation code will auto-filter groups that don't have at least two y-labels.
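The grouping and filtering rule described here can be sketched with the standard library alone. This assumes a CSV with exactly the three columns listed (filename, set, illustration_id); `build_groups`, `min_labels`, and the sample rows are hypothetical, not HanClinto's actual group-creation code. Each returned group is one row of the jagged 2D array: every image shares the artwork, and the set code is the y-label.

```python
import csv
import io
from collections import defaultdict


def build_groups(rows, min_labels=2):
    """Group rows by illustration_id; keep groups spanning at least `min_labels` set codes."""
    groups = defaultdict(list)
    for row in rows:
        groups[row["illustration_id"]].append((row["filename"], row["set"]))
    return [
        members
        for members in groups.values()
        if len({set_code for _, set_code in members}) >= min_labels
    ]


# Hypothetical rows: art-1 was printed in M20 and BFZ; art-2 only ever appeared in XLN.
SAMPLE = """filename,set,illustration_id
m20_ja.jpg,m20,art-1
m20_de.jpg,m20,art-1
bfz_ja.jpg,bfz,art-1
xln_en.jpg,xln,art-2
xln_fr.jpg,xln,art-2
"""
groups = build_groups(csv.DictReader(io.StringIO(SAMPLE)))
```

Here art-2 is dropped because all of its images carry the same y-label, so it can't supply a negative for any anchor.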

HanClinto commented 1 year ago

Also, I was noodling more about your proposal for a similar hybrid approach. I don't think you're that far off from what I'm doing, and I think the ideas could be merged in a useful way.

In my case, I broke my giant dataset up once (at the beginning of the program) into a large jagged 2D array by using metadata about the cards (cards that share the same artwork). But in your case, you could just as easily do this cluster-and-grouping within the data generator loop itself.

At the top of the data generator loop, first run inference on the whole dataset using the latest version of the model. Cluster images together that have nearby vectors and put all of those images into a single group -- this will be your pool of negatives. Also add any other images that share the same label as images in that group, but don't have nearby vectors (this will be your pool from which the batch will find hard positives). Hopefully the clustering code can be tuned to have roughly 50-100 images per group or so, but larger groups probably aren't bad. Shuffle those groups and deal them out -- yielding batches of size N.
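The loop described above can be sketched as an endless numpy generator. This is a sketch under assumptions: `embed_fn` is a hypothetical callable that runs the current model over all images and returns an (n, d) array, and plain k-means stands in for "cluster images that have nearby vectors" (the proposal doesn't name a clustering method). The toy k-means computes all point-to-center distances at once, so for a real dataset a library clusterer or an ANN index would be needed. Group size is steered through `n_clusters`, roughly dataset size divided by the target group size.

```python
import numpy as np


def kmeans(x, k, rng, iters=10):
    """Tiny Lloyd's k-means; returns a cluster index for every row of x."""
    centers = x[rng.choice(len(x), size=k, replace=False)].copy()
    for _ in range(iters):
        d = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)  # (n, k) squared distances
        assign = d.argmin(axis=1)
        for c in range(k):
            members = x[assign == c]
            if len(members):
                centers[c] = members.mean(axis=0)
    return assign


def regrouping_batches(images, labels, embed_fn, n_clusters, batch_size, seed=0):
    """Endless generator: re-embed, re-cluster and re-group once per pass over the data."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    while True:
        # Top of the loop: run inference on the whole dataset with the latest weights.
        vecs = np.asarray(embed_fn(images), dtype=float)
        assign = kmeans(vecs, n_clusters, rng)
        groups = []
        for c in range(n_clusters):
            near = np.flatnonzero(assign == c)  # look-alikes: the pool of negatives
            if near.size == 0:
                continue
            # Same-label images that embedded somewhere else: the pool of hard positives.
            far = np.flatnonzero(np.isin(labels, labels[near]) & (assign != c))
            groups.append(np.concatenate([near, far]))
        # Shuffle the groups, then deal each one out in batches of at most batch_size.
        for g in rng.permutation(len(groups)):
            members = rng.permutation(groups[g])
            for start in range(0, len(members), batch_size):
                idx = members[start:start + batch_size]
                if idx.size >= 2:  # a single image can't form any pair
                    yield [images[i] for i in idx], labels[idx]
```

Each yielded batch is then handed to TripletSemiHardLoss as-is; the generator only decides which images share a batch, never which triplets get formed.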

The whole idea of my approach is that I only need to make small batches that are close to hard triplets... and then be lazy and let TripletSemiHardLoss actually take care of the hassle of finding and building the actual triplets themselves.
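To make the division of labor concrete, here is a numpy sketch of the in-batch selection that semi-hard mining performs, following the rule documented for TF Addons' `tfa.losses.TripletSemiHardLoss` as I understand it; it is not the library's code. For every anchor-positive pair, it picks the closest negative that is still farther away than the positive; if none exists it falls back to the farthest negative, and it averages `max(d_ap - d_an + margin, 0)` over all pairs. Euclidean distance and `margin=1.0` are assumptions here, and anchors with no negatives in the batch are simply skipped. The key point is that only triplets inside one batch are ever considered, which is why the batch contents matter so much.

```python
import numpy as np


def semi_hard_triplet_loss(labels, embeddings, margin=1.0):
    """Mean semi-hard triplet loss over all anchor-positive pairs within one batch."""
    labels = np.asarray(labels)
    emb = np.asarray(embeddings, dtype=float)
    dist = np.linalg.norm(emb[:, None, :] - emb[None, :, :], axis=-1)  # (n, n) Euclidean
    same = labels[:, None] == labels[None, :]
    terms = []
    for a in range(len(labels)):
        neg_d = dist[a][~same[a]]
        if neg_d.size == 0:
            continue  # simplification: skip anchors with no negatives in the batch
        for p in np.flatnonzero(same[a]):
            if p == a:
                continue
            d_ap = dist[a, p]
            farther = neg_d[neg_d > d_ap]
            # Semi-hard: closest negative still farther than the positive; else the farthest one.
            d_an = farther.min() if farther.size else neg_d.max()
            terms.append(max(d_ap - d_an + margin, 0.0))
    return float(np.mean(terms)) if terms else 0.0


loss = semi_hard_triplet_loss([0, 0, 1], [[0.0], [1.0], [1.5]])  # (0.5 + 1.5) / 2 = 1.0
```

In a randomly drawn batch most negatives are already far away and contribute zero loss; grouping look-alike cards together is what makes informative semi-hard negatives show up in the first place.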

When the data generator gets through all of the groups, it loops back to the top, and right below the while True: line is where we run inference on the entire dataset again, re-cluster, and re-group. Rinse, repeat.

Then we don't need to play games with applying weights separately within each chunk and also to the optimizer later. We go through our training steps like normal, accumulating the gradients computed with GradientTape at each step, and apply them every N steps (currently I'm applying the accumulated update 4x per epoch and it seems to work fine -- plenty of room for experimentation here, though). It's definitely slower than plain SGD, but at least I can effectively train with triplet loss. I see good results within just a few epochs (~4 hrs per epoch at my current dataset size), and after a few days it's basically trained to completion.
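The accumulate-then-apply idea can be sketched framework-free. The toy model below (linear regression with a mean-squared-error loss) and the names `grad_mse` and `GradientAccumulator` are hypothetical, not HanClinto's train_step. In a Keras `train_step`, the buffer and step counter would typically be non-trainable `tf.Variable`s so they persist across calls of the compiled function; numpy is used here only so the sketch runs anywhere.

```python
import numpy as np


def grad_mse(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(y)


class GradientAccumulator:
    """Sum gradients over `accum_steps` mini-batches, then apply their average as one update."""

    def __init__(self, w, lr, accum_steps):
        self.w = w.astype(float)  # copy, so the caller's array is untouched
        self.lr = lr
        self.accum_steps = accum_steps
        self.buffer = np.zeros_like(self.w)  # role played by tf.Variable accumulators in Keras
        self.step = 0

    def train_step(self, x, y):
        self.buffer += grad_mse(self.w, x, y)
        self.step += 1
        if self.step % self.accum_steps == 0:
            self.w -= self.lr * self.buffer / self.accum_steps  # the single optimizer update
            self.buffer[:] = 0.0
```

With equal-sized mini-batches and a mean-reduced loss, four accumulated steps produce exactly the update of one step on the combined batch. That equivalence does not carry over to the mining itself: TripletSemiHardLoss still only searches for triplets within each small batch, which is why the grouping step carries so much of the weight.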

I can share some pseudocode of what my train_step looks like if that would be helpful. It's all derived from that example code that I shared earlier, but I had to add some shenanigans using tf.Variable and whatnot to actually get the accumulation to trigger every N steps properly.