shkarupa-alex / tfswin

Keras (TensorFlow v2) reimplementation of Swin Transformer V1 and V2 models
MIT License

Colab TPU compatibility #2

Closed alelotti96 closed 2 years ago

alelotti96 commented 2 years ago

Hi and thank you for this amazing work. I was wondering if the models are compatible with Colab TPU, since training fails with the error "Socket closed".

shkarupa-alex commented 2 years ago

I haven't tested on TPU. But if you have some trouble and can provide a Colab notebook with a demo, I can try to help.

alelotti96 commented 2 years ago

Actually, I was able to run the model with keras.fit, but I still face the error in my custom training loop (which works with another model in TF 2.4.1). Is there a way to import these models in TF 2.4.1?

shkarupa-alex commented 2 years ago

The current implementation should be fully portable to earlier TF2 versions.

There may be 2 differences: 1) Namespace paths (e.g. some layers may be experimental in earlier versions; Keras moves them around sometimes) 2) Before TF 2.7, keras and tf.keras may have different implementations, so it may be necessary to swap all paths from keras to tf.keras (depending on which you used)

alelotti96 commented 2 years ago

Ok, thank you for the support, I'm going to try right now and let you know. One last question: since I need to process grayscale images, is it possible to reduce the number of channels to 1 even when using pretrained models? This would significantly reduce memory consumption.

shkarupa-alex commented 2 years ago
  1. This will not reduce memory and computation much. To do that you would need to reduce the 4-D conv kernel tensor along the channel dimension from 3 to 1, and that is just 1 layer (most of the memory and computation is in the Swin blocks: MLP, attention)
  2. The easiest way to use grayscale images is to repeat them along the channel axis to get a 3-channel image
  3. Another way is to sum the patch embedding conv kernel along the channel dim (with keep_dims=True). Try to find how to do that for ResNet50; for the Swin models in this repo the algorithm will be the same
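Options 2 and 3 can be sketched with NumPy. The shapes here are assumptions for illustration (4x4 patches, a 128-dim embedding, HWIO kernel layout), not values taken from the repo:

```python
import numpy as np

# Option 2: repeat a grayscale image along the channel axis
gray = np.random.rand(1, 384, 384, 1).astype("float32")
rgb = np.repeat(gray, 3, axis=-1)                # shape (1, 384, 384, 3)

# Option 3: collapse a pretrained patch-embedding kernel to 1 input channel.
# Kernel layout (patch_h, patch_w, in_channels, embed_dim) is an assumption.
kernel = np.random.randn(4, 4, 3, 128).astype("float32")
gray_kernel = kernel.sum(axis=2, keepdims=True)  # shape (4, 4, 1, 128)
```

With option 3 the convolution of a single-channel image against the summed kernel matches what the pretrained layer would produce on a 3-channel image whose channels are all equal.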
alelotti96 commented 2 years ago

Ok thank you, for the moment I will stick to point 2. I was able to import the model in TF 2.4.1 by modifying import statements and removing the call to "validate_activations" in model.py (it should not be a problem as I am not loading the head of the network).

In TF 2.4.1 I get a more meaningful error:

(0) Invalid argument: {{function_node __inference_train_step_190757}} Compilation failure: Input to reshape is a tensor with 756 values, but the requested shape has 288
  [[{{node while/body/_1/while/model_1/swin_base_384/layers.0/Repeat/Reshape}}]]
  TPU compilation failed
  [[tpu_compile_succeeded_assert/_15792548884583901316/_6]]
  [[while/LoopCond/_1050/72]]

This happens at the beginning of the second epoch. I was wondering if this could be related to the dataset batching, but I used the exact same code for loading images and labels a few days ago to train another model, so it seems strange. My images are 384x384x3 and I'm using a batch size of 48. The error does not arise with keras.fit... but I need a custom training loop. Do you have any suggestions with respect to this error?

alelotti96 commented 2 years ago

I forgot to mention that I'm using "SwinTransformerBase384".

shkarupa-alex commented 2 years ago

Did you use the model for classification, or as a feature extractor for semantic segmentation etc.?

Could you try to use a fixed input shape? E.g.

from keras import layers, models
from tfswin import SwinTransformerTiny224, preprocess_input

inputs = layers.Input(shape=(384, 384, 3), dtype='uint8')  # fix shape here
outputs = layers.Lambda(preprocess_input)(inputs)
outputs = SwinTransformerTiny224()(outputs)

...

model = models.Model(inputs=inputs, outputs=outputs)
alelotti96 commented 2 years ago

Sorry for the delay. Finally, I was able to train with a custom input size in TF 2.8.0 (in 2.4.1 I was getting errors like the one above even with a fixed input shape). The problem was that the size of my dataset was not fully defined: I added "drop_remainder" to batch and fixed the image sizes (now I don't know why it was working before with my other model). Thank you very much for your help. One last question: how does the model handle multi-scale inference? I see I can leave the input shape as None and feed it images of different sizes.
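A minimal pure-Python sketch of why drop_remainder matters here (the `batch` helper is hypothetical, just mimicking `tf.data.Dataset.batch`): without it, the final partial batch has a different static shape, which trips the fixed shapes baked in at TPU (XLA) compile time — consistent with the "756 values vs 288" reshape error above appearing at an epoch boundary.

```python
def batch(items, batch_size, drop_remainder=False):
    """Mimic tf.data.Dataset.batch: split `items` into consecutive batches.

    With drop_remainder=True every batch has exactly `batch_size` elements,
    giving the static batch dimension that TPU compilation requires.
    """
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()  # drop the partial batch that changes the static shape
    return batches

sizes = [len(b) for b in batch(list(range(100)), 48)]
# sizes == [48, 48, 4]: the last batch differs in shape
fixed = [len(b) for b in batch(list(range(100)), 48, drop_remainder=True)]
# fixed == [48, 48]: every batch has the same static shape
```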

shkarupa-alex commented 2 years ago

Regarding variable spatial shapes, it should work well (open an issue otherwise). And it should work well if your sizes are >= pretrain_size * 2.

I have some doubts about what will happen when one of the input sizes in train/eval is in the range [pretrain_size; 2 * pretrain_size). In the original code for classification, the last Swin block in the last basic block never does a shift. But in the segmentation part of the original repository it does. So I wrote a clause to mimic both of them, but have not tested it: https://github.com/shkarupa-alex/tfswin/blob/ba9f5c8bb4848bb07da1758eb3b22c2d86df8607/tfswin/swin.py#L59

alelotti96 commented 2 years ago

Ok, now I'm retraining the SwinBase384 model on 448x448 px images. I'm using it as a feature extractor, basically. Is there a risk that this resolution will not give good results because it's not >= pretrain_size * 2?

shkarupa-alex commented 2 years ago

I think it will be ok. Input to window attention is padded to be divisible by window size.
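That padding rule can be sketched as follows. The 4x4 patch size and window size 12 are assumptions taken from the standard Swin-B/384 configuration, not from this repo's code:

```python
import math

def pad_to_window(size, window):
    # Round a spatial dimension up to the next multiple of the window size,
    # as Swin's window attention requires evenly divisible feature maps.
    return math.ceil(size / window) * window

feat = 448 // 4                   # 112: feature-map side after 4x4 patch embedding
padded = pad_to_window(feat, 12)  # padded to 120, the next multiple of 12
```

So a 448 px input produces a 112-wide feature map that is padded to 120 before window partitioning, and attention proceeds normally.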