insitro / ChannelViT

Channel Vision Transformers: An Image Is Worth C x 16 x 16 Words
https://arxiv.org/abs/2309.16108

KeyError: 'channels' #12

Open Boom5426 opened 2 months ago

Boom5426 commented 2 months ago

If I just want to use it in another field, what should I do?

import torch
model = torch.hub.load('insitro/ChannelViT', 'cpjump_cellpaint_channelvit_small_p8_with_hcs_supervised', pretrained=True)
model.eval()
images = torch.randn(5, 3, 224, 224)
out = model(images)

KeyError: 'channels'

Boom5426 commented 2 months ago

What should I pass as the extra_tokens parameter?

srinivasans-insitro commented 3 weeks ago

extra_tokens["channels"] should contain the channel indices of each sample in the batch, with shape batch_size x n_channels.
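A minimal sketch of building that dictionary by hand for the 3-channel batch in the question. The helper `make_channel_tokens` is hypothetical, and mapping the three input channels to model channel indices 0, 1, 2 is only an assumption: pick indices that match how the pretrained model's channels were defined. The final model call is left commented out because the exact `forward()` keyword is assumed here, not confirmed; check the signature in the repo.

```python
from typing import Dict, Sequence

import torch


def make_channel_tokens(batch_size: int, channel_ids: Sequence[int]) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: repeat the same channel index list for every sample."""
    # Shape (1, n_channels) -> (batch_size, n_channels); long dtype for embedding lookups.
    channels = torch.tensor(list(channel_ids), dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    return {"channels": channels}


images = torch.randn(5, 3, 224, 224)
# Assumption: treat the 3 input channels as model channels 0, 1, 2.
extra_tokens = make_channel_tokens(images.shape[0], [0, 1, 2])
print(extra_tokens["channels"].shape)  # torch.Size([5, 3])

# Assumed call; verify the keyword name against the model's forward():
# out = model(images, extra_tokens=extra_tokens)
```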

For example, our ImageNet dataset returns, for each sample, a dictionary containing that sample's channels, and batches are collated with PyTorch's default_collate function. default_collate maps Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])], so extra_tokens['channels'] ends up with shape batch_size x n_channels.
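To illustrate the collation behaviour described above, here is a small sketch with a hypothetical toy dataset (`ToyMultiChannelDataset` is not part of ChannelViT; random images and fixed channel ids stand in for real data). Each sample returns `(image, {"channels": ...})`, and `DataLoader`'s default `default_collate` stacks the per-sample 1-D channel tensors into a `batch_size x n_channels` tensor.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyMultiChannelDataset(Dataset):
    """Hypothetical dataset: each sample is (image, {"channels": channel indices})."""

    def __init__(self, n_samples, channel_ids, size=32):
        self.n_samples = n_samples
        self.channel_ids = list(channel_ids)
        self.size = size

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        image = torch.randn(len(self.channel_ids), self.size, self.size)
        # 1-D tensor of length n_channels; default_collate stacks these along dim 0.
        return image, {"channels": torch.tensor(self.channel_ids)}


dataset = ToyMultiChannelDataset(n_samples=10, channel_ids=[0, 1, 2])
loader = DataLoader(dataset, batch_size=4)  # collate_fn defaults to default_collate

images, extra_tokens = next(iter(loader))
print(images.shape)                    # torch.Size([4, 3, 32, 32])
print(extra_tokens["channels"].shape)  # torch.Size([4, 3])
```

The dictionary keys survive collation unchanged, which is why the batch can be passed straight through as extra_tokens.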

This is also discussed in https://github.com/insitro/ChannelViT/issues/3#issuecomment-2027716674