huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXt, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0
31.69k stars · 4.71k forks

[BUG] Transform not properly working with batching and Grayscale Images in ViT #2219

Closed asusdisciple closed 2 months ago

asusdisciple commented 3 months ago

Describe the bug The out-of-the-box image transforms for TinyViT do not work with grayscale images, because they expect 2- or 3-channel inputs (see code below). It would also be nice if the transforms supported batched input directly, since that would improve interoperability with other workflows and frameworks (Lightning, for example).

To Reproduce Steps to reproduce the behavior:

import timm
import torch
from torchvision.transforms import v2

model = timm.create_model('tiny_vit_21m_512.dist_in22k_ft_in1k', pretrained=True, in_chans=1, num_classes=4)
data_config = timm.data.resolve_model_data_config(model)
train_transform = timm.data.create_transform(**data_config, is_training=True)
train_transform.transforms.insert(0, v2.ToPILImage())
x = torch.randn(4, 1, 512, 512)
x = train_transform(x)  # fails here: the pipeline expects a single 2/3-channel image, not a 4D batch
model(x)

Expected behavior The transform could automatically detect a 4D tensor and handle batching. For grayscale images, the single channel could simply be copied to all 3 channels before the transform runs, as a low-effort implementation. There is also the dependency on PIL images; maybe it makes sense to drop that in favor of torch.Tensor.
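The channel-copy idea can be sketched outside timm. A minimal, hypothetical workaround (the helper name is mine, not timm's, and it assumes the wrapped transform accepts a per-sample 3-channel tensor):

```python
import torch

def transform_grayscale_batch(transform, batch):
    """Hypothetical helper: adapt a per-image, 3-channel transform to a
    grayscale batch by copying the single channel and looping over samples.

    batch: (N, 1, H, W) tensor of grayscale images.
    """
    rgb = batch.repeat(1, 3, 1, 1)  # copy the single channel to all 3 channels
    return torch.stack([transform(img) for img in rgb])
```

With the train_transform from the snippet above this sidesteps both issues, at the cost of a Python-level loop over the batch.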


rwightman commented 3 months ago

the request here isn't clear