Clay-foundation / model

The Clay Foundation Model (in development)
https://clay-foundation.github.io/model/
Apache License 2.0

Learning to find similar semantics from embeddings #186

Closed: brunosan closed this issue 1 month ago

brunosan commented 3 months ago

The main practical use case of Clay as of now, and the center of the upcoming app, is the ability to find similar features. Think: 1) Click on a pool, 2) find more potential examples, 3) confirm/reject candidates, 4) iterate until you are happy.

The current chip size (512 pixels, or ~5120 m in Sentinel) is much larger than most semantics, and so is even the patch size (32 pixels, or ~320 m), so the corresponding embeddings will incorporate the many semantics present on the chip/patch. This mix of semantics makes similarity search (e.g. cosine) and other tools of limited use, since they look at all dimensions at once.

I believe we need a way to both: 1) pick the subset, or function, of the dimensions that best represents the requested sample of examples, and 2) locate them in the image.

This might take the shape of a "decoder" that is either plugged into the encoder or, better, takes embeddings as input. Ideally, this decoder is agnostic of the label, or location, and needs no training at inference time (so that the app can use it easily).

cc @yellowcap, @geohacker and @srmsoumya for ideas.

MaceGrim commented 3 months ago

Let me lay out my thought process and what I've tried. We're in the wild west here and I think it's good to lay out all assumptions and decisions. I'm specifically exploring aquaculture in Bali for this exercise.

Given:

  1. Data points X (patch level) with:
     - Clay embeddings
     - Geometry
     - Aquaculture (binary value, yes if aquaculture is present)
  2. Threshold T
  3. Strategy St to surface new points, which should use at least one of:
     - Similarities between embeddings (cosine probably makes the most sense; easily provided by Qdrant and similar vector search engines)
     - A model to classify points as either positive or negative

Perform the following procedure:

  1. Select Sp starting positive points (sampled in our simulation, chosen by users in the app)
  2. Select Sn starting negative points (again, sampled in the simulation, but chosen by users in the app)
  3. Initialize recall = |Sp| / (total count of positive points); in general, recall is (count of labeled positive points) / (total count of positive points)
  4. While recall < T (if we want at least 95% of the points, for example, T would be 0.95):
     a) Use St(Sp, Sn, X, maybe some model) to find new points for the user to label
     b) The user provides labels for the surfaced points (the simulation can easily provide the correct labels for new points; it might be valuable to add noise to the labels for testing)
     c) Positive points are added to Sp and negative points are added to Sn
  5. Return the total number Tn of iterations required to reach the threshold

There will be a distribution over Tn, so we'll run this procedure multiple times for each Strategy. We should be able to see which strategies are able to surface the most relevant results the fastest with this setup, and we can provide metrics on how fast. For starters, I've implemented a few strategies in Qdrant, and I'll make this code available on a benchmarking branch very soon.
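A minimal sketch of this loop in plain NumPy, before the benchmarking branch lands (`run_simulation` is a hypothetical name, `strategy` is any callable implementing St, and the `labels` array plays the role of the simulation oracle):

```python
import numpy as np

def run_simulation(embeddings, labels, strategy, threshold=0.95,
                   n_start_pos=5, n_start_neg=5, rng=None):
    """Simulate the labeling loop; return Tn, the number of iterations
    needed to reach the recall threshold T. `labels` acts as the oracle."""
    rng = rng or np.random.default_rng()
    pos_idx = np.flatnonzero(labels == 1)
    neg_idx = np.flatnonzero(labels == 0)

    # Steps 1-2: sample the starting positives Sp and negatives Sn.
    sp = set(rng.choice(pos_idx, n_start_pos, replace=False))
    sn = set(rng.choice(neg_idx, n_start_neg, replace=False))

    recall = len(sp) / len(pos_idx)  # Step 3.
    iterations = 0
    while recall < threshold:  # Step 4 (assumes the strategy always
        iterations += 1        # surfaces at least one new point).
        candidates = strategy(sp, sn, embeddings)  # 4a: surface points.
        for i in candidates:                       # 4b: oracle labels them.
            (sp if labels[i] == 1 else sn).add(i)  # 4c.
        recall = len(sp) / len(pos_idx)
    return iterations  # Step 5: Tn.
```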

My current Strategies are:

  1. RandomSamplePositiveExamples: surface the 25 closest unlabeled points to a random positive point in Sp
  2. CreateRepresentativePoint: take the average embedding of all points in Sp and find the 25 closest unlabeled points to this new, representative point (see the sketch below)
  3. KNN: find the 25 nearest points for each point in Sp, then sort by distance to the closest point and surface the top 25. (NOTE: I am NOT using negative examples in this setup yet.)
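For illustration, a possible NumPy version of the CreateRepresentativePoint strategy, in place of the Qdrant calls (`representative_point_strategy` is a hypothetical helper; it plugs into the `run_simulation` sketch above as the `strategy` argument):

```python
import numpy as np

def representative_point_strategy(sp, sn, embeddings, k=25):
    """Average the embeddings of Sp and return the k closest unlabeled
    points to that representative point, by cosine similarity."""
    labeled = sp | sn
    rep = embeddings[list(sp)].mean(axis=0)  # the representative point
    rep /= np.linalg.norm(rep)
    sims = embeddings @ rep / np.linalg.norm(embeddings, axis=1)
    order = np.argsort(-sims)                # most similar first
    return [i for i in order if i not in labeled][:k]
```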

In a (currently buggy?) version, these strategies perform as indicated in the following chart.

The Y axis is recall. The X axis is the number of iterations that have elapsed. The darkly colored lines are the averages of the strategies at that iteration (I should probably flip this to the Y axis). The light lines are actual results from a single simulation.

[Chart: clay_sim_search_strategies_initial]

The chart suggests that the "Representative" strategy is currently the best. I'm also adding at least 10 labels per iteration to test some things and to get a reasonable speed through each simulation. This should be removed later.

Things are a little loose right now, but I'm sharing in the interest of moving a little faster. Any comments or questions are absolutely welcome!

brunosan commented 3 months ago

Thanks for the note. I need to think a bit more about what you write (can you put the code on a branch?), but it makes sense.

This is what I'm trying, but there are no real results to show yet.

  1. "Common similar": Given "P" positive locations, find the "C" closest by cosine similarity for each location. Then retain the most common positive ones. This gives the most common closest samples.

    The rational is to assume that embeddings are polysemantic, so the similarity could be to our intended feature, or another one. By retaining the common closest, we filter out to the intended semantic.

  2. "Common contrastive": Tweak the above by adding N negative locations, finding the "C" closest by similarity search of these negatives, and remove the common occurrences from the positive candidates on 1.

    The rationale is that by selecting negatives, we instruct the model that the semantics there are not intended, so the closest tiles are likely to also have these unintended semantics, and we remove them. This should work best with user-defined negative cases, where humans understand that the feature presented might look similar to aquaculture, but is not, refining the useful semantics.

  3. "Polysemantic Pruning". Take the positve and negative samples given and train a random forest. Then calculate feature important for each dimension of the embedding and remove the x% least important. Then do steps 1. or 2. with the pruned embeddings.

    The rationale is that RF importance will tell you which dimensions (semantics) are relevant and which are not. By removing the least important, we restrict the cosine similarity to dimensions relevant to the feature we want.

For each approach I want to explore the combined effect of more samples vs. more pruning.
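A minimal sketch of the pruning step, assuming scikit-learn's RandomForestClassifier and already-extracted positive/negative embedding arrays (`prune_dimensions` and its arguments are hypothetical names):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

def prune_dimensions(pos_emb, neg_emb, prune_frac=0.5):
    """Train an RF on positives vs. negatives, then keep only the most
    important embedding dimensions, dropping the prune_frac least important."""
    X = np.vstack([pos_emb, neg_emb])
    y = np.concatenate([np.ones(len(pos_emb)), np.zeros(len(neg_emb))])
    rf = RandomForestClassifier(n_estimators=200, random_state=0).fit(X, y)
    n_keep = int(X.shape[1] * (1 - prune_frac))
    return np.argsort(-rf.feature_importances_)[:n_keep]  # kept dim indices
```

Steps 1 or 2 would then compute cosine similarity on `embeddings[:, keep]` instead of the full vectors.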

Some early tests:

Hypothesis: Pruning up to 50% might make sense.

[Chart: accuracy vs. pruning steps as more positive examples are added]

This tests the pruning strategy plus RF. The X axis is steps of removing dimensions based on least importance as we get more positive examples. The Y axis is accuracy. Therefore this plot shows the combined effect of more samples and more pruning. I should test retaining the same amount of pruning, or no pruning while increasing samples, but that's not done yet.

Hypothesis: Pruning removes irrelevant dimensions

[Chart: RF feature importance per embedding dimension]

X axis is dimension index (sorted by feature importance). Y axis is feature importance for RF. Clearly more than half of the dimensions have no importance for the RF. In this sample, the red dashed line indicates how much we remove.

I'm not getting sensible results, so I'm debugging upstream.

Some of the positive aquaculture locations might not be accurate?

[Screenshot 2024-03-22 at 11:02:43]

Both green BBOXes are labeled as aquaculture in the set. We should check that aquaculture assets are indeed visible in the data used for inference.

I'm using Python Folium with Esri.Satellite tiles to quickly visualize the BBOXes and labels. Of course these images are NOT the same as the ones used for inference with Clay or the ones used to create the labels. If the labels are not visible in the Clay dataset, it makes sense that our approaches don't work well. I'm working to visualize exactly what Clay used for inference.

MaceGrim commented 3 months ago

Part of the issue is which chips I'm including as positives. I think the original dataset includes chips with ANY overlap of the aquaculture polygons, but we should probably require at least some % of chip coverage. Let me change that and send it back out.
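A possible version of that check with shapely (`is_positive_chip` and the 5% threshold are illustrative, not the dataset's actual rule):

```python
from shapely.geometry import box
from shapely.ops import unary_union

def is_positive_chip(chip_bounds, aquaculture_polygons, min_coverage=0.05):
    """Label a chip positive only if aquaculture covers at least
    min_coverage of its area, instead of ANY overlap."""
    chip = box(*chip_bounds)  # (minx, miny, maxx, maxy)
    # union first so overlapping polygons are not double counted
    covered = chip.intersection(unary_union(aquaculture_polygons)).area
    return covered / chip.area >= min_coverage
```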

yellowcap commented 3 months ago

Did you try euclidean similarity as well? Not sure which is better. Making similarity searches using multiple dimensions and vector arithmetic is an unsolved problem in my opinion. It is still a lot of trial and error, and the search queries for combining vectors are often done on the average / sum / subtraction of vectors. But that could potentially become more sophisticated.

Maybe @leothomas can chime in here too.

brunosan commented 3 months ago

Did you try euclidean similarity as well?

Indeed, trial and error beats theory. I understand euclidean similarity as the "distance between the tips of the arrows". It would hence also suffer from confusion by irrelevant dimensions.

search queries for combining vectors are often done on the average / sum / subtraction of vectors.

I think averaging embeddings doesn't work for EO embeddings. It works for word tokens, as they are mostly monosemantic, so the arithmetic highlights the traits of a singular semantic (e.g. king - man + woman: "king" seems logical to depend on the semantics of country, man, royal, ..., so when you subtract "man" and add "woman", it seems logical to arrive close to "queen").

In remote sensing, however, we create a token per image. If one chip has a green field and a blue pool, the embedding needs to contain all of that. Now imagine it's a lone house in a field: the average embedding of the area will greatly dilute the blue pool. In fact, in our v0.1 we will greatly dilute that pool in the RGB average, since it will only appear in the blue band.

Polysemantic pruning is then the method to remove the green-field semantics from the embedding, so that cosine or euclidean similarity indeed measures distance along the dimensions that capture what we care about.
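A toy illustration of that dilution, with made-up 4-dimensional "embeddings" where each dimension stands for one semantic:

```python
import numpy as np

# Toy embeddings: dims = [green field, blue pool, road, roof].
chip  = np.array([0.9, 0.4, 0.0, 0.1])  # field dominates, pool present
query = np.array([0.0, 1.0, 0.0, 0.0])  # the pool semantic we care about

def cos(a, b):
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

print(cos(chip, query))              # ~0.40: the field drowns out the pool
keep = [1]                           # prune to the pool-relevant dimension(s)
print(cos(chip[keep], query[keep]))  # 1.0: pool semantic isolated
```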

lauracchen commented 3 months ago

https://arxiv.org/abs/2003.04151 "Embedding Propagation: Smoother Manifold for Few-Shot Classification"

They claim embedding propagation can help with issues of "distribution shift", where the training data isn't distributed similarly to the test set. Not sure if this could be helpful or how easily it could be applied to geospatial data, but if it works, perhaps it's an alternative to having to continue model training on a specific region like Bali, for instance?

brunosan commented 3 months ago

Tagging here a similar issue with quarries and not getting expected semantic similarity. https://github.com/Clay-foundation/model/discussions/140#discussioncomment-9015252

brunosan commented 3 months ago

Posting here some general notes as we explore this issue:

brunosan commented 3 months ago

Basically, after A LOT of fancy exploring, the most effective approach is simply cosine similarity.

More details:

yellowcap commented 1 month ago

Closing as out of date, feel free to re-open if appropriate.