rishigami / Swin-Transformer-TF

TensorFlow implementation of Swin Transformer model.
Apache License 2.0
198 stars, 46 forks

Added 3D support for SwinTransformerModel, e.g. for medical imaging tasks #19
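Swin Transformer starts by splitting the input into non-overlapping patches and flattening each patch into a token. A 3D port has to do this over volumes instead of images. The following is a minimal NumPy sketch of that step, assuming a single channels-last (D, H, W, C) volume with no batch axis. `patch_partition_3d` is a hypothetical helper, not code from this PR. In practice the linear embedding that follows is usually fused with the partition as a convolution whose kernel size and stride both equal the patch size.

```python
import numpy as np


def patch_partition_3d(volume, patch_size):
    """Split a (D, H, W, C) volume into non-overlapping 3D patches and flatten
    each patch into one token row (a 3D analogue of Swin's patch partition)."""
    d, h, w, c = volume.shape
    pd, ph, pw = patch_size
    if d % pd or h % ph or w % pw:
        raise ValueError("each spatial size must be divisible by the patch size")
    # Split every spatial axis into (grid index, offset inside patch):
    # (D/pd, pd, H/ph, ph, W/pw, pw, C)
    x = volume.reshape(d // pd, pd, h // ph, ph, w // pw, pw, c)
    # Move the three grid axes to the front: (D/pd, H/ph, W/pw, pd, ph, pw, C)
    x = x.transpose(0, 2, 4, 1, 3, 5, 6)
    # One row per patch, each holding pd * ph * pw * C voxel values
    return x.reshape(-1, pd * ph * pw * c)


tokens = patch_partition_3d(np.zeros((112, 112, 112, 1), dtype=np.float32), (4, 4, 4))
print(tokens.shape)  # (21952, 64)
```

For the volume size used in this PR, 4x4x4 patches turn one single-channel 112x112x112 volume into 28 * 28 * 28 = 21,952 tokens of 64 values each.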

Closed: MohamadZeina closed this 1 year ago

MohamadZeina commented 1 year ago

Tested and working, e.g.:

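import numpy as np
import tensorflow as tf

import swin_transformer_nd  # N-dimensional Swin module from this PR (import path assumed)

IMAGE_SIZE = [112, 112, 112]  # spatial size of each 3D volume
NUM_CLASSES = 10

model_3d = tf.keras.Sequential([
  # 4x4x4-voxel patches and three stages (depths has three entries)
  swin_transformer_nd.SwinTransformerModel(img_size=IMAGE_SIZE, patch_size=(4, 4, 4), depths=[2, 2, 6]),
  tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])
model_3d.compile(optimizer=tf.keras.optimizers.Adam(), loss="categorical_crossentropy")

# Dummy single-channel volume and a one-hot label for a smoke test.
# An all-zero label would make categorical cross-entropy exactly 0 (no gradient).
x = np.zeros([1, *IMAGE_SIZE, 1], dtype=np.float32)
y = tf.one_hot([0], NUM_CLASSES)

for i in range(100):
    model_3d.fit(x, y)
    print("Trained on a batch")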
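Swin's patch merging halves the token grid between stages, and self-attention runs inside fixed-size windows, so the input size, patch size and number of stages have to fit together. This pure-Python sketch assumes the standard Swin schedule (2x merging between stages) and a window size of 7, the Swin-T default, with the window shrunk to the grid size once the grid is smaller than it. `stage_grids` is a hypothetical helper, not part of this PR. Under these assumptions a 112x112x112 volume with 4x4x4 patches gives 28, 14 and 7 tokens per axis over three stages, and a fourth stage would need to merge an odd grid of 7. That may be why the example passes only three entries in `depths`.

```python
def stage_grids(img_size, patch_size, num_stages, window_size=7):
    """Return the token grid of each stage of a Swin-style hierarchy.

    Assumptions: the first stage sees img_size // patch_size tokens per axis,
    each later stage halves every axis via 2x patch merging, and the window is
    shrunk to the smallest grid axis when the grid is smaller than window_size.
    """
    grid = [s // p for s, p in zip(img_size, patch_size)]
    grids = []
    for stage in range(num_stages):
        if stage > 0:
            # Patch merging groups 2 neighbours per axis, so every axis must be even
            if any(g % 2 for g in grid):
                raise ValueError(f"stage {stage}: cannot merge odd grid {grid}")
            grid = [g // 2 for g in grid]
        # Non-overlapping windows must tile the grid exactly
        window = min(window_size, min(grid))
        if any(g % window for g in grid):
            raise ValueError(f"stage {stage}: grid {grid} not divisible by window {window}")
        grids.append(tuple(grid))
    return grids


print(stage_grids([112, 112, 112], (4, 4, 4), num_stages=3))
# [(28, 28, 28), (14, 14, 14), (7, 7, 7)]
```

Calling it with `num_stages=4` for the same volume raises a `ValueError`, while the usual 2D setup of a 224x224 image with 4x4 patches fits four stages (56, 28, 14, 7).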