What:
Autoencoding images with any number of colour channels using models that only accept images with 3 colour channels
Why
A general method for using any backbone is valuable because it means we can use pretrained weights and avoid modifying models
How
The current implementation only works with 1 colour channel.
For resnet50 you need 3, so we repeat the single channel along the C dim to make the tensor the correct shape for the resnet.
3 colour channels is generally the common input for backbones, so it's worthwhile fitting 1C into 3C.
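For the 1-channel case that repeat is a one-liner; a minimal sketch (shapes illustrative):

```python
import torch

x = torch.randn(8, 1, 224, 224)  # (b, 1, y, x) single-channel batch
x3 = x.repeat(1, 3, 1, 1)        # (b, 3, y, x), what resnet50 expects
```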
For any number of colour channels that's a bit harder.
In the case of 5 colour channels the trick is to put the additional colour channels in the batch dimension, which is generally flexible in pytorch,
i.e.
tensor in is -> (b,c,y,x)
you reshape it to -> (b*c,1,y,x)
you repeat along c as above -> (b*c,3,y,x)
This tensor will then go through resnet50, all good
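A minimal sketch of those reshapes, assuming a 5-channel batch and a recent torchvision (resnet50 stands in for any 3-channel backbone):

```python
import torch
from torchvision.models import resnet50

b, c, y, x = 8, 5, 224, 224
imgs = torch.randn(b, c, y, x)       # (b, c, y, x), 5-channel input

flat = imgs.reshape(b * c, 1, y, x)  # fold channels into the batch dim
flat = flat.repeat(1, 3, 1, 1)       # (b*c, 3, y, x), valid resnet50 input

backbone = resnet50(weights=None)    # any 3-channel backbone works here
feats = backbone(flat)               # (b*c, 1000) with the default head
```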
The trick then is to add a loss term to encourage channels from the same image to have the same latent representation.
To do this I've implemented euclidean distance and KL divergence losses.
To do that you need to access your latent representation -> (b*c,z),
add the colour dim back by unsqueezing -> (b*c,1,z),
reshape -> (b,c,z),
and then do whatever operation you like to reduce the error along c
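A minimal sketch of that recovery step, with one possible (illustrative, not necessarily this repo's) euclidean-style reduction along c: pull each channel's latent towards the per-image mean.

```python
import torch

b, c, zdim = 8, 5, 128
latents = torch.randn(b * c, zdim)    # (b*c, z) out of the encoder

z = latents.unsqueeze(1)              # (b*c, 1, z), colour dim added back
z = z.reshape(b, c, zdim)             # (b, c, z)

# one possible reduction along c: mean squared deviation of each
# channel's latent from the per-image mean latent
mean_z = z.mean(dim=1, keepdim=True)  # (b, 1, z)
loss = ((z - mean_z) ** 2).mean()
```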
For instance, one implementation here builds a distance matrix:
you take your reshaped tensor -> (b,c,z)
add a dim -> (b,1,c,z)
transpose dims 1 and 2 -> (b,c,1,z)
and feed those into pytorch cdist, i.e. cdist((b,1,c,z),(b,c,1,z)) -> a (c,c) distance matrix per image (you have to permute the dims a bit to get cdist to work)
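In practice torch.cdist already treats leading dims as batch dims, so assuming the (b,c,z) latent above, the unsqueeze/transpose shuffle can collapse into a single call; a minimal sketch:

```python
import torch

b, c, zdim = 8, 5, 128
z = torch.randn(b, c, zdim)  # reshaped latents, (b, c, z)

# leading dims are treated as batch dims, so this yields one
# (c, c) euclidean distance matrix per image -> (b, c, c)
dists = torch.cdist(z, z)

# the diagonal is zero by construction, so the mean directly
# penalises disagreement between channels of the same image
loss = dists.mean()
```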
KL divergence is also useful for the variational models, but I've not played around with that as much.
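That part is unexplored here, but for a VAE-style encoder that emits per-channel (mu, logvar), a hedged sketch of a pairwise channel-consistency term could use the closed-form KL between diagonal Gaussians (all names illustrative):

```python
import torch

b, c, zdim = 8, 5, 128
mu = torch.randn(b, c, zdim)      # per-channel posterior means
logvar = torch.randn(b, c, zdim)  # per-channel posterior log-variances

mu0, lv0 = mu.unsqueeze(2), logvar.unsqueeze(2)  # (b, c, 1, z)
mu1, lv1 = mu.unsqueeze(1), logvar.unsqueeze(1)  # (b, 1, c, z)

# closed-form KL between diagonal Gaussians, pairwise over channels
kl = 0.5 * (lv1 - lv0 + (lv0.exp() + (mu0 - mu1) ** 2) / lv1.exp() - 1)
kl = kl.sum(-1)                                # (b, c, c)
loss = 0.5 * (kl + kl.transpose(1, 2)).mean()  # symmetrised
```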