google-research / vision_transformer

Apache License 2.0

Problem with kmnist dataset #273

Status: Open. bukson opened this issue 1 year ago.

bukson commented 1 year ago

Hello

I am trying to use the pretrained B_16 model on the tfds kmnist dataset (which, like MNIST, consists of 28x28 greyscale images).

The problem is that I get this error:

Initializer expected to generate shape (16, 16, 3, 768) but got shape (16, 16, 1, 768) 

This is probably because the images have only 1 color channel instead of 3.
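For context, the shape in the error reads as (patch height, patch width, input channels, hidden size): B_16 embeds 16x16 patches into 768-dimensional tokens, so the pretrained kernel is tied to 3 input channels. The numpy sketch below illustrates this under stated assumptions: `patch_embed` is a hypothetical helper (not the repo's code) that writes the stride-16 patch convolution as patch extraction plus a matmul, the kernel is random rather than the real checkpoint, and 224x224 is just an example input size.

```python
import numpy as np

PATCH = 16    # B_16 splits images into 16x16 patches
HIDDEN = 768  # B_16 hidden size


def patch_embed(images, kernel):
    """Hypothetical helper: a stride-16, 16x16 convolution (no bias)
    written as non-overlapping patch extraction followed by a matmul.

    images: (batch, H, W, C); kernel: (PATCH, PATCH, C, HIDDEN)
    returns: (batch, num_patches, HIDDEN)
    """
    b, h, w, c = images.shape
    kh, kw, kc, hidden = kernel.shape
    if kc != c:
        raise ValueError(f"kernel expects {kc} input channels, got {c}")
    gh, gw = h // kh, w // kw
    # Cut the image into a (gh, gw) grid of (kh, kw, c) patches.
    x = images[:, :gh * kh, :gw * kw, :].reshape(b, gh, kh, gw, kw, c)
    x = x.transpose(0, 1, 3, 2, 4, 5).reshape(b, gh * gw, kh * kw * c)
    # Project each flattened patch with the flattened kernel.
    return x @ kernel.reshape(kh * kw * kc, hidden)


rng = np.random.default_rng(0)
# Stand-in for the pretrained RGB kernel (random values, real shape).
rgb_kernel = rng.normal(size=(PATCH, PATCH, 3, HIDDEN))
gray = rng.random(size=(2, 224, 224, 1))

try:
    patch_embed(gray, rgb_kernel)
except ValueError as e:
    print(e)  # kernel expects 3 input channels, got 1

# Repeating the single channel makes the input match the 3-channel kernel.
tokens = patch_embed(np.repeat(gray, 3, axis=-1), rgb_kernel)
print(tokens.shape)  # (2, 196, 768)
```

A 224x224 input gives a 14x14 grid, hence 196 tokens per image.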

I had no problem running the pretrained model on a custom color dataset. Is this method only available for 3-channel datasets, or are MNIST-like datasets also supported?

andsteing commented 1 year ago

I would simply repeat the single greyscale channel 3 times in the input pipeline, here:

https://github.com/google-research/vision_transformer/blob/297866ab49341257e6f657d7f1068164c8eaf338/vit_jax/input_pipeline.py#L195-L216

Something like this:

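import tensorflow as tf
import tensorflow_datasets as tfds

# MNIST-style datasets (this works the same for 'kmnist') yield
# greyscale images of shape (28, 28, 1).
ds = tfds.load('mnist', split='train')
# Repeat the single channel 3 times so the images match the 3-channel
# input expected by the pretrained patch-embedding kernel (16, 16, 3, 768).
ds = ds.map(lambda d: {
    'label': d['label'],
    'image': tf.repeat(d['image'], 3, axis=2),
})
ds = ds.batch(2)
b = next(iter(ds))
assert b['image'].shape.as_list() == [2, 28, 28, 3]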
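One way to see why repeating the channel is a reasonable choice: because the three copies are identical, convolving them with the pretrained RGB kernel gives exactly the same result as a single-channel convolution whose kernel is the RGB kernel summed over its channel axis. So the pretrained weights are used unchanged and the model config stays 3-channel. The numpy check below is a sketch; `conv_patches` is a hypothetical helper, and the hidden size of 8 (instead of 768) and the 32x32 input are assumptions chosen only to keep it small.

```python
import numpy as np


def conv_patches(images, kernel):
    """Hypothetical helper: stride-k, kxk convolution (no bias) over
    non-overlapping patches. Assumes H and W are multiples of k.

    images: (batch, H, W, C); kernel: (k, k, C, D)
    returns: (batch, H // k, W // k, D)
    """
    b, h, w, c = images.shape
    k = kernel.shape[0]
    # Axes: batch, grid row, row in patch, grid col, col in patch, channel.
    x = images.reshape(b, h // k, k, w // k, k, c)
    # Sum over rows and columns within each patch and over channels.
    return np.einsum('bpiqjc,ijcd->bpqd', x, kernel)


rng = np.random.default_rng(0)
kernel = rng.normal(size=(16, 16, 3, 8))  # small stand-in for (16, 16, 3, 768)
gray = rng.random(size=(2, 32, 32, 1))

# Option 1: repeat the grey channel and keep the RGB kernel.
out_repeat = conv_patches(np.repeat(gray, 3, axis=-1), kernel)
# Option 2: keep 1 channel and sum the kernel over its RGB axis.
out_summed = conv_patches(gray, kernel.sum(axis=2, keepdims=True))

print(np.allclose(out_repeat, out_summed))  # True
```

In the tf.data pipeline, `tf.image.grayscale_to_rgb(d['image'])` should produce the same 3-channel tensor as `tf.repeat(d['image'], 3, axis=2)`, if you prefer a more descriptive name.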