uhlmanngroup / bioimage_embed


N colour channel autoencoding #37

Closed ctr26 closed 5 months ago

ctr26 commented 10 months ago

What:

Autoencoding images with any number of colour channels using models that only accept images with 3 colour channels

Why

A general method for using any backbone is valuable because it lets us use pretrained weights and avoid modifying models

How

The current implementation only works with 1 colour channel. For resnet50 you need 3, so we repeat the single channel along the C dim to make the tensors match what the resnet expects. 3 colour channels is generally the common input for backbones, so it's worthwhile trying to fit 1C into 3C
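A minimal sketch of that 1C -> 3C trick, assuming the input is a (b, 1, y, x) greyscale batch (tensor names here are illustrative, not from the repo):

```python
import torch

x = torch.randn(8, 1, 64, 64)  # (b, 1, y, x) greyscale batch

# Repeat the single channel along the C dim so the tensor matches the
# 3-channel input that resnet50-style backbones expect.
x3 = x.repeat(1, 3, 1, 1)      # (b, 3, y, x)

assert x3.shape == (8, 3, 64, 64)
# All three "colours" carry the same greyscale data.
assert torch.equal(x3[:, 0], x3[:, 1]) and torch.equal(x3[:, 1], x3[:, 2])
```

`Tensor.expand` would do the same without copying memory, at the cost of the channels sharing storage.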

For an arbitrary number of colour channels that's a bit harder.

In the case of, say, 5 colour channels, the trick is to fold the additional colour channels into the batch dimension, which is generally flexible in pytorch

i.e.

the input tensor is -> (b, c, y, x); you reshape it to -> (b*c, 1, y, x); you repeat along c as above -> (b*c, 3, y, x)

This tensor will then go through resnet50, all good
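The reshape chain above can be sketched as follows (shapes only; actually feeding the result through resnet50 is omitted, and the sizes are made up for illustration):

```python
import torch

b, c, y, x = 4, 5, 32, 32              # 5 colour channels, as in the example
imgs = torch.randn(b, c, y, x)         # (b, c, y, x)

# Fold the channel dim into the batch dim...
folded = imgs.reshape(b * c, 1, y, x)  # (b*c, 1, y, x)
# ...then repeat to 3 channels so each slice is valid resnet50 input.
rgbish = folded.repeat(1, 3, 1, 1)     # (b*c, 3, y, x)

assert rgbish.shape == (b * c, 3, y, x)
# Channel 2 of image 0 ended up at batch index 2, as expected.
assert torch.equal(rgbish[2, 0], imgs[0, 2])
```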

The trick then is to add a loss term to encourage channels from the same image to have the same latent representation. To do this I've implemented euclidean distance and kl divergence losses

The catch being that you need to access your latent representation -> (b*c, z), add the colour dim back by unsqueezing -> (b*c, 1, z), then reshape -> (b, c, z)

and then do whatever operation you like to reduce the error along c
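As a sketch of that reshape-and-reduce step, here is one illustrative reduction along c (penalising deviation from the per-image mean latent; this is just an example choice, not necessarily the loss in the repo):

```python
import torch

b, c, zdim = 4, 5, 16
z = torch.randn(b * c, zdim)   # latent from the backbone: (b*c, z)

# Add the colour dim back, then reshape so channels of one image line up.
z = z.unsqueeze(1)             # (b*c, 1, z)
z = z.reshape(b, c, zdim)      # (b, c, z)

# Reduce along c: push each channel latent toward the per-image mean,
# encouraging a shared representation across channels.
loss = (z - z.mean(dim=1, keepdim=True)).pow(2).mean()

assert z.shape == (b, c, zdim)
assert loss.item() >= 0
```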

For instance one implementation here is making a distance matrix

so you take your reshaped tensor -> (b, c, z), add a dim -> (b, 1, c, z), transpose dims 1 and 2 -> (b, c, 1, z), and feed those into pytorch cdist, i.e. -> cdist((b, 1, c, z), (b, c, 1, z)) -> a (c, c) distance matrix per image (you have to permute the dims a bit to get cdist to work)
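A sketch of that distance matrix: in current PyTorch, `torch.cdist` already broadcasts over a leading batch dim, so the (b, c, z) tensor can be fed in directly without the extra unsqueeze/transpose dance (the mean-of-distances loss is just an illustrative reduction):

```python
import torch

b, c, zdim = 4, 5, 16
z = torch.randn(b, c, zdim)    # reshaped latents: (b, c, z)

# cdist treats the leading dim as a batch dim, giving a per-image (c, c)
# matrix of pairwise euclidean distances between channel latents.
dist = torch.cdist(z, z)       # (b, c, c)

# Minimising the mean pairwise distance pulls the channels of each image
# toward the same latent representation.
loss = dist.mean()

assert dist.shape == (b, c, c)
# The diagonal (each channel against itself) is zero.
assert torch.allclose(
    torch.diagonal(dist, dim1=1, dim2=2), torch.zeros(b, c), atol=1e-5
)
```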

KL div is also useful for the variational models but I've not played around with that as much
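For the variational case, one plausible shape for such a term (an assumption on my part, not the repo's implementation) is to treat each channel's encoder output as a diagonal Gaussian and use the closed-form KL divergence between channels of the same image:

```python
import torch

b, c, zdim = 4, 5, 16
mu = torch.randn(b, c, zdim)       # per-channel posterior means
logvar = torch.randn(b, c, zdim)   # per-channel posterior log-variances


def kl_diag_gauss(mu1, logvar1, mu2, logvar2):
    # KL( N(mu1, var1) || N(mu2, var2) ) for diagonal Gaussians,
    # summed over the latent dim.
    var1, var2 = logvar1.exp(), logvar2.exp()
    return 0.5 * (logvar2 - logvar1 + (var1 + (mu1 - mu2) ** 2) / var2 - 1).sum(-1)


# e.g. KL of every channel's posterior against channel 0 of the same image
kl = kl_diag_gauss(mu, logvar, mu[:, :1], logvar[:, :1])  # (b, c)

assert kl.shape == (b, c)
assert torch.allclose(kl[:, 0], torch.zeros(b), atol=1e-5)  # KL(p || p) = 0
assert (kl >= -1e-5).all()                                  # KL is non-negative
```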