Clay-foundation / model

The Clay Foundation Model (in development)
https://clay-foundation.github.io/model/
Apache License 2.0

Add a linear layer to squeeze all patch embeddings into a single image embedding? #107

Closed brunosan closed 3 months ago

brunosan commented 8 months ago

The current bottleneck of the Unet architecture is the set of patch embeddings, one for each section of the image. When we create the embedding of the image, we use an average of all patch embeddings.

However, this approach yields a very lossy and smoothed version of the image semantics, especially when trying to reconstruct an image from its embedding (location and time). Moreover, each patch embedding does not include the semantics of the patch itself; rather, it is trained to capture the self-attention-weighted semantics of all other available patches within the image. This also makes the collection of patch embeddings highly redundant.

Can we introduce one more feedforward layer to aggregate all patch embeddings, location, and time data into a single, image-wide embedding (one on the encoder to go down from patch semantics to image semantics, and one on the decoder to expand from image semantics back to patch semantics)?
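To make the proposal concrete, here is a minimal NumPy sketch comparing the current averaging with a linear squeeze/expand pair. All sizes, the `meta` location/time vector, and the names `w_squeeze` and `w_expand` are assumptions for illustration, not the model's actual code; the weights are random where a real model would learn them.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 16 patches per image, 8-dim patch embeddings,
# a 4-value location/time encoding, and a 32-dim image embedding.
num_patches, patch_dim, meta_dim, image_dim = 16, 8, 4, 32
patches = rng.normal(size=(num_patches, patch_dim))
meta = np.array([0.5, -0.3, 0.1, 0.9])  # made-up location/time encoding

# Current approach: average all patch embeddings.
mean_embedding = patches.mean(axis=0)  # shape (patch_dim,)

# Proposed approach: one linear layer on the encoder side (squeeze) and one
# on the decoder side (expand). Random weights stand in for learned ones.
in_dim = num_patches * patch_dim + meta_dim
w_squeeze = rng.normal(size=(in_dim, image_dim)) * 0.1
w_expand = rng.normal(size=(image_dim, num_patches * patch_dim)) * 0.1

def squeeze(patch_array):
    """Flatten patches, append location/time, project to one image embedding."""
    return np.concatenate([patch_array.reshape(-1), meta]) @ w_squeeze

image_embedding = squeeze(patches)  # shape (image_dim,)
reconstructed = (image_embedding @ w_expand).reshape(num_patches, patch_dim)

# The mean ignores patch order, so it cannot say where in the image the
# semantics are; the linear squeeze sees each patch position separately.
shuffled = patches[::-1]
mean_is_position_blind = np.allclose(shuffled.mean(axis=0), mean_embedding)
linear_is_position_aware = not np.allclose(squeeze(shuffled), image_embedding)
```

The sketch only shows shapes and the position-sensitivity argument; whether a learned squeeze actually preserves enough information for reconstruction is the open question raised above.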

This embedding would encapsulate the entire semantic context of the image at a specific location and time. It would also allow us to reconstruct the image from the image embedding. I suspect it will also need to convey where in the image the semantics are, so the decoder can place those differences within the image correctly. This is also highly desirable for downstream tasks that need to locate inter-image semantics.

yellowcap commented 7 months ago

We can consider storing the raw encoder output alongside the average embeddings. With the raw encoder output, any kind of re-combination, including using a linear layer to create the "best" combination, can be performed.
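A rough sketch of what that could look like, assuming NumPy arrays and an `.npz` container (the storage format, array names, and shapes are illustrative assumptions, not the project's actual pipeline):

```python
import io
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical raw encoder output: 16 patches x 8 dimensions.
raw = rng.normal(size=(16, 8)).astype(np.float32)
average = raw.mean(axis=0)

# Store both arrays together. A BytesIO buffer stands in for a file on disk.
buffer = io.BytesIO()
np.savez(buffer, raw=raw, average=average)
buffer.seek(0)

stored = np.load(buffer)

# Later, any re-combination can start from the raw output, for example
# max pooling or a weighted sum over patches (uniform weights here, which
# simply reproduce the stored average; learned weights would differ).
max_pooled = stored["raw"].max(axis=0)
weights = np.full(16, 1 / 16, dtype=np.float32)
weighted = weights @ stored["raw"]
```

The trade-off is storage: the raw output here is 16 times larger than the average.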

brunosan commented 7 months ago

Noting here something I learned today: for each self-attention patch, the 13 layers are grouped into 6 groups, and we create one embedding per group.

I still think it would make sense to roll all layer groups into a single embedding at the self-attention patch, and use that as the end of the encoder.
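In terms of shapes, the idea could be sketched as follows. Only the count of 6 groups comes from the comment above; the number of patches, the embedding size, and the `w_roll` projection are assumptions, and the weights are random rather than trained.

```python
import numpy as np

rng = np.random.default_rng(0)

# 6 band groups per patch (from the discussion); other sizes are made up.
num_patches, num_groups, dim = 16, 6, 8
group_embeddings = rng.normal(size=(num_patches, num_groups, dim))

# Option 1: average over the group axis, one vector per patch.
averaged = group_embeddings.mean(axis=1)  # shape (num_patches, dim)

# Option 2: concatenate the groups of each patch and apply a linear
# projection, so the mix of groups could be learned instead of fixed.
w_roll = rng.normal(size=(num_groups * dim, dim)) / np.sqrt(num_groups * dim)
rolled = group_embeddings.reshape(num_patches, -1) @ w_roll  # (num_patches, dim)
```

Both options end the encoder with one embedding per patch; they differ in whether the combination of groups is fixed or learned.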

brunosan commented 5 months ago

I keep running into the need to do this. @yellowcap and I had a great conversation where we decided NOT to do this for now.

The reason is that these band groups currently provide a wider capacity to learn features, grouped by band (e.g. optical features, DEM features, SAR features). This is much richer than just one vector. We don't know if squeezing it further loses quality, and answering that question seems lower priority right now (versus the cost of extra space for the embeddings).

If we do need to squeeze to one vector, we can always make a "decoder" that squeezes these into one vector while keeping most of the information (loss function TBD). We can also take the embeddings instead of the encoder.

It is true that doing it this way creates the need to average the patch-level embeddings, or it creates 16x sizes of patch-level embedding, which sometimes we might need, like in #168.

brunosan commented 4 months ago

Bumping this up again.

Especially as we move away from fixed bands and band groups and focus on e.g. similarity search at the patch level, it seems critical that we do not create embeddings at the neck of the Unet per patch AND per band.

1) How would that even work when the input data can have different bands?

2) When we average across bands to create the patch embedding, we are forcing a brute-force reduction of all independent bands into a single vector. Maybe "houses" is a semantic on some dimensions in one band but on other dimensions in another, and taking the average doesn't even make sense semantically.
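A toy example of the second point, with made-up 4-dimensional vectors (the group names and values are purely illustrative): if each band group encodes a concept on different dimensions, averaging can map two different concepts to the same vector.

```python
import numpy as np

# Hypothetical embeddings of one patch from two band groups.
# "houses" lights up dim 0 in optical but dim 3 in SAR; "water" is the reverse.
houses = {"optical": np.array([1.0, 0.0, 0.0, 0.0]),
          "sar": np.array([0.0, 0.0, 0.0, 1.0])}
water = {"optical": np.array([0.0, 0.0, 0.0, 1.0]),
         "sar": np.array([1.0, 0.0, 0.0, 0.0])}

def average_groups(groups):
    """Brute-force reduction: element-wise mean across band groups."""
    return np.mean([groups["optical"], groups["sar"]], axis=0)

def concat_groups(groups):
    """Keep each group's dimensions separate by concatenating them."""
    return np.concatenate([groups["optical"], groups["sar"]])

# Averaging collapses both concepts onto [0.5, 0.0, 0.0, 0.5].
averages_collide = np.array_equal(average_groups(houses), average_groups(water))

# Concatenation (or any reduction that treats groups separately) keeps them apart.
concats_differ = not np.array_equal(concat_groups(houses), concat_groups(water))
```

Real embeddings will not collide this cleanly, but the same mechanism blurs distinctions whenever groups do not share a common set of axes.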

I propose (again) that we now reduce the neck of the Unet to one embedding per patch.

Cc @yellowcap @srmsoumya

brunosan commented 4 months ago

Talking with @yellowcap, it seems we are already merging all bands into a single patch embedding.

@srmsoumya to confirm and close here.

yellowcap commented 3 months ago

Closing as out of date; feel free to re-open if appropriate. We have a class token in v1, which represents a learned way to compress the band embeddings into a single embedding. So that addresses the issue.
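For readers unfamiliar with the idea, here is a generic single-head sketch of how a class token can pool a set of token embeddings through self-attention. It is not Clay v1's actual implementation; the sizes, weight matrices, and names are assumptions, and the random values stand in for learned parameters.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)

# Hypothetical sizes: 16 token embeddings (patches or band groups), 8 dims.
num_tokens, dim = 16, 8
tokens = rng.normal(size=(num_tokens, dim))
cls_token = rng.normal(size=(1, dim))  # a learned parameter in a real model

# Random query/key/value projections standing in for trained weights.
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) / np.sqrt(dim) for _ in range(3))

# Prepend the class token, then run one self-attention step over the sequence.
sequence = np.concatenate([cls_token, tokens], axis=0)  # (1 + num_tokens, dim)
q, k, v = sequence @ w_q, sequence @ w_k, sequence @ w_v
attention = softmax(q @ k.T / np.sqrt(dim))  # each row sums to 1
output = attention @ v

# The output at the class-token position is a learned, weighted summary of
# all tokens, in contrast to the fixed uniform weights of a plain average.
image_embedding = output[0]
cls_weights = attention[0]
```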