Boom5426 opened 2 months ago
What parameter should I pass as extra_tokens?
`extra_tokens["channels"]` should contain the channel indices for each sample in the batch and should have shape `batch_size x n_channels`.
For example, in the ImageNet dataset, we return a dictionary containing the channels of each sample, which is collated with the PyTorch `default_collate` function. `default_collate` collates `Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]`, resulting in `extra_tokens['channels']` of shape `batch_size x n_channels`.
This is also discussed in https://github.com/insitro/ChannelViT/issues/3#issuecomment-2027716674
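The collation step described above can be sketched with plain PyTorch. This is a minimal illustration, not the repository's dataset code: the helper `sample_with_channels` and the `"image"` key are hypothetical, and only the `"channels"` key and `default_collate` come from the answer. Each sample carries a 1-D tensor of channel indices, and `default_collate` stacks those per key into a `(batch_size, n_channels)` tensor.

```python
import torch
from torch.utils.data import default_collate


def sample_with_channels(n_channels: int, size: int = 224) -> dict:
    """Hypothetical per-sample record: an image plus its channel indices."""
    return {
        "image": torch.randn(n_channels, size, size),
        # One integer index per image channel, in the same order as dim 0.
        "channels": torch.arange(n_channels),
    }


# default_collate maps a list of dicts to a dict of stacked tensors.
batch = default_collate([sample_with_channels(3) for _ in range(5)])
images = batch["image"]                         # shape (5, 3, 224, 224)
extra_tokens = {"channels": batch["channels"]}  # shape (5, 3)
```

If your samples carry different channel subsets, the per-sample `"channels"` tensors must still have equal length for the default stacking to work.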
If I just want to use it in another field, what should I do? For example, this fails:
```python
import torch

model = torch.hub.load('insitro/ChannelViT', 'cpjump_cellpaint_channelvit_small_p8_with_hcs_supervised', pretrained=True)
model.eval()

images = torch.randn(5, 3, 224, 224)
out = model(images)  # no extra_tokens passed, so the lookup of "channels" fails
```
KeyError: 'channels'
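A sketch of how the call in the question could supply the missing key, following the maintainer's description of `extra_tokens["channels"]`. Assumptions, loudly: the model's forward accepts the dict through an `extra_tokens` keyword (the name comes from this thread; the exact signature is not verified here), the helpers `make_extra_tokens` and `run_pretrained` are hypothetical, and the indices `[0, 1, 2]` are placeholders. Whatever indices you pass must be ones the checkpoint's channel embedding was trained with; this checkpoint was trained on Cell Painting data, so for images from another field the indices are only meaningful if you map your channels onto those or fine-tune.

```python
import torch


def make_extra_tokens(batch_size: int, channel_ids: list[int]) -> dict[str, torch.Tensor]:
    """Build extra_tokens with a 'channels' entry of shape (batch_size, n_channels).

    Assumes every image in the batch has the same channels, listed in the
    same order as they appear along dim 1 of the image tensor.
    """
    channels = torch.tensor(channel_ids, dtype=torch.long)  # (n_channels,)
    # Repeat the same index row for every sample -> (batch_size, n_channels).
    return {"channels": channels.unsqueeze(0).expand(batch_size, -1).clone()}


def run_pretrained(images: torch.Tensor, channel_ids: list[int]) -> torch.Tensor:
    """Load the hub model from the question and run it with explicit channel indices."""
    if images.shape[1] != len(channel_ids):
        raise ValueError("need exactly one channel index per input channel")
    model = torch.hub.load(
        "insitro/ChannelViT",
        "cpjump_cellpaint_channelvit_small_p8_with_hcs_supervised",
        pretrained=True,
    )
    model.eval()
    extra_tokens = make_extra_tokens(images.shape[0], channel_ids)
    with torch.no_grad():
        # Assumption: forward takes the dict via an `extra_tokens` keyword.
        return model(images, extra_tokens=extra_tokens)
```

Usage would then be `out = run_pretrained(torch.randn(5, 3, 224, 224), [0, 1, 2])`, with the placeholder indices replaced by the ones that correspond to your channels.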