insitro / ChannelViT

Channel Vision Transformers: An Image Is Worth C x 16 x 16 Words
https://arxiv.org/abs/2309.16108

KeyError: 'channels' #10

Closed zhulinchng closed 1 month ago

zhulinchng commented 1 month ago

I'm trying to customize and use the backbones channelvits and hcs_channelvits, but there seems to be no reference for the 'channels' key used in the forward pass of the PatchEmbedPerChannel class. What is the value of extra_tokens["channels"] supposed to look like?

    def forward(self, x, extra_tokens={}):
        # assume all images in the same batch has the same input channels
        cur_channels = extra_tokens["channels"][0]

https://github.com/insitro/ChannelViT/blob/1077a103e30cf20b6223b64935666962d0e2e836/channelvit/backbone/channel_vit.py#L56

    def forward(self, x, extra_tokens={}):
        # # assume all images in the same batch has the same input channels
        # cur_channels = extra_tokens["channels"][0]
        # embedding lookup
        cur_channel_embed = self.channel_embed(
            extra_tokens["channels"]
        )  # B, Cin, embed_dim=Cout

https://github.com/insitro/ChannelViT/blob/1077a103e30cf20b6223b64935666962d0e2e836/channelvit/backbone/hcs_channel_vit.py#L61

srinivasans-insitro commented 1 month ago

@zhulinchng extra_tokens["channels"] should contain the indices of channels included in the batch and should be of shape batch_size x n_channels.

The following thread has additional details and links to examples for setting it correctly: https://github.com/insitro/ChannelViT/issues/3#issuecomment-2027716674
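
For illustration, here is a minimal sketch of a tensor with the shape described above (batch_size x n_channels). It assumes PyTorch. The sizes, the chosen channel indices, and the nn.Embedding stand-in for channel_embed are hypothetical and not taken from the repository, so the actual layers there may differ.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
batch_size, height, width = 4, 32, 32
total_channels = 8    # number of channels the model knows about (assumed)
present = [0, 2, 5]   # indices of the channels included in this batch (assumed)
embed_dim = 16

# Image batch: B x Cin x H x W, with Cin = len(present).
x = torch.randn(batch_size, len(present), height, width)

# B x Cin integer tensor; each row lists the channel indices of one image.
channels = torch.tensor(present, dtype=torch.long).repeat(batch_size, 1)
extra_tokens = {"channels": channels}

# Access as in the quoted channel_vit.py snippet: the first row is taken
# to describe every image in the batch.
cur_channels = extra_tokens["channels"][0]  # shape: (Cin,)

# Access as in the quoted hcs_channel_vit.py snippet: a lookup over the
# full tensor. nn.Embedding is a stand-in here, not the repository's layer.
channel_embed = nn.Embedding(total_channels, embed_dim)
cur_channel_embed = channel_embed(extra_tokens["channels"])  # B, Cin, embed_dim
```

The dictionary would then be passed alongside the images, e.g. forward(x, extra_tokens=extra_tokens), matching the signature quoted in the question. Every row is identical in this sketch because the channel_vit.py snippet assumes all images in a batch share the same input channels.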