Closed alelotti96 closed 2 years ago
I did not test it on TPU. But if you have some troubles and can provide a Colab notebook with a demo, I can try to help.
Actually, I was able to run the model with keras.fit, but I still face the error in my custom training loop (which works with another model in TF 2.4.1). Is there a way to import these models in TF 2.4.1?
The current implementation should be fully portable to earlier TF2 versions.
There may be two differences: 1) Namespace paths (e.g. some layers may be experimental in earlier versions; Keras moves them from time to time). 2) Before TF 2.7, keras and tf.keras may have different implementations, so you may need to swap all imports from keras to tf.keras (depending on what you used).
Ok, thank you for the support, I'm going to try right now and let you know. One last question: since I need to process grayscale images, is it possible to reduce the number of channels to 1 even when using pretrained models? This would significantly reduce memory consumption.
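The thread does not settle the grayscale question, but a common workaround (an assumption here, not something the maintainer confirms) is to keep the 3-channel pretrained stem and replicate the single grayscale channel, e.g. with NumPy (or tf.image.grayscale_to_rgb inside the pipeline):

```python
import numpy as np

def gray_to_rgb(batch):
    """Replicate a single grayscale channel into 3 identical channels.

    batch: array of shape (N, H, W, 1); returns (N, H, W, 3) so a model
    pretrained on RGB inputs can consume grayscale images unchanged.
    """
    assert batch.shape[-1] == 1, "expected a single-channel input"
    return np.repeat(batch, 3, axis=-1)

images = np.zeros((2, 384, 384, 1), dtype=np.uint8)
print(gray_to_rgb(images).shape)  # (2, 384, 384, 3)
```

Truly reducing the model to 1 input channel would require surgery on the pretrained patch-embedding weights, so replication is the cheaper option even though it does not save input memory.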
Ok, thank you; for the moment I will stick to point 2. I was able to import the model in TF 2.4.1 by modifying the import statements and removing the call to "validate_activations" in model.py (this should not be a problem, as I am not loading the head of the network).
In TF 2.4.1 I get a more meaningful error:
(0) Invalid argument: {{function_node __inference_train_step_190757}} Compilation failure: Input to reshape is a tensor with 756 values, but the requested shape has 288 [[{{node while/body/_1/while/model_1/swin_base_384/layers.0/Repeat/Reshape}}]] TPU compilation failed [[tpu_compile_succeeded_assert/_15792548884583901316/_6]] [[while/LoopCond/_1050/72]]
This happens at the beginning of the second epoch. I was wondering if this could be related to the dataset batching, but I used the exact same code for loading images and labels a few days ago to train another model, so it seems strange. My images are 384x384x3 and I'm using a batch size of 48. The error does not arise with keras.fit, but I need a custom training loop. Do you have any suggestions with respect to this error?
I forgot to mention that I'm using "SwinTransformerBase384".
Did you use the model for classification, or as a feature extractor for semantic segmentation etc.?
Could you try to use a fixed input shape? E.g.:
from keras import layers, models
from tfswin import SwinTransformerTiny224, preprocess_input
inputs = layers.Input(shape=(384, 384, 3), dtype='uint8') # fix shape here
outputs = layers.Lambda(preprocess_input)(inputs)
outputs = SwinTransformerTiny224()(outputs)
...
model = models.Model(inputs=inputs, outputs=outputs)
Sorry for the delay. Finally, I was able to train with a custom input size in TF 2.8.0 (in 2.4.1 I was getting errors like the one above even with a fixed input shape). The problem was that the size of my dataset was not fully defined: I added "drop_remainder" to batch and fixed the image sizes (I now don't know why it was working before with my other model). Thank you very much for your help. One last question: how does the model handle multi-scale inference? I see I can leave the input shape as None and feed it images of different sizes.
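For context, a minimal sketch of why a partial final batch breaks a TPU-compiled graph (the sample count below is illustrative, only the batch size of 48 comes from the thread): without drop_remainder, the last batch has a different shape than the one the program was compiled for.

```python
import math

def batch_counts(n_samples, batch_size):
    """Compare batching with and without dropping the remainder."""
    full = n_samples // batch_size              # drop_remainder=True: all batches equal
    padded = math.ceil(n_samples / batch_size)  # drop_remainder=False: extra partial batch
    last = n_samples % batch_size or batch_size # size of the final batch
    return full, padded, last

# e.g. 1000 images with batch size 48
full, padded, last = batch_counts(1000, 48)
print(full, padded, last)  # 20 21 40
```

With tf.data this corresponds to dataset.batch(48, drop_remainder=True): the mismatched 40-sample batch is simply dropped, so every compiled shape stays static.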
As for variable spatial shapes, it should work well (open an issue otherwise). And it should work well if your sizes are >= pretrain_size * 2.
I have some doubts about what will happen when one of the input sizes in train/eval is in the range [pretrain_size; 2 * pretrain_size). In the original code for classification, the last Swin block in the last basic block never does a shift. But in the segmentation part of the original repository it does. So I wrote a clause to mimic both of them, but haven't tested it: https://github.com/shkarupa-alex/tfswin/blob/ba9f5c8bb4848bb07da1758eb3b22c2d86df8607/tfswin/swin.py#L59
Ok, now I'm retraining the SwinBase384 model on 448x448 px images; I'm basically using it as a feature extractor. Is there a risk that this resolution will not give good results because it's not >= pretrain_size * 2?
I think it will be ok. The input to window attention is padded to be divisible by the window size.
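A small sketch of that padding arithmetic (the window size of 12 for the 384 models and the stage strides of 4/8/16/32 are assumptions based on the standard Swin configuration, not stated in the thread):

```python
def pad_to_window(size, window_size):
    """Pad a spatial size up to the next multiple of the window size."""
    pad = (window_size - size % window_size) % window_size
    return size + pad

# For 448x448 inputs the per-stage feature maps are 112, 56, 28 and 14;
# with window_size=12 each is padded up before window attention.
for size in (112, 56, 28, 14):
    print(size, '->', pad_to_window(size, 12))  # 120, 60, 36, 24
```

So non-divisible feature maps are handled transparently; the padding just adds a few wasted tokens per stage.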
Hi, and thank you for this amazing work. I was wondering if the models are compatible with Colab TPU, since training fails with the error "Socket closed".