robmarkcole opened 10 months ago
Because the model is big, I decided to use dropout to counter overfitting. I tested rates of 0.1, 0.15, 0.20 and 0.25 and settled on 0.15. If you want to optimize this further, it is better to use Dropout2D instead. Conventional dropout randomly masks individual pixels of the feature maps, each with a probability equal to the drop rate, which can break spatial relationships. Dropout2D instead masks whole feature maps, which is better suited to CNN models where spatial relationships are crucial.
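To make the difference concrete, here is a minimal pure-Python sketch (not code from the repo) contrasting the two masking schemes on a feature map stored as nested lists of shape channels × height × width. In PyTorch the two behaviours correspond to `torch.nn.Dropout` and `torch.nn.Dropout2d`; the function names below are hypothetical. Both scale surviving values by 1/(1 − p), as inverted dropout does during training.

```python
import random

# Hypothetical helpers for illustration only. A feature map is a list of
# channels, each channel a list of rows (C x H x W).


def element_dropout(fmap, p, rng):
    """Conventional dropout: zero each element independently with probability p."""
    scale = 1.0 / (1.0 - p)  # inverted-dropout scaling keeps the expected value unchanged
    return [[[0.0 if rng.random() < p else v * scale for v in row] for row in ch]
            for ch in fmap]


def channel_dropout(fmap, p, rng):
    """Dropout2d-style: zero an entire channel (feature map) with probability p."""
    scale = 1.0 / (1.0 - p)
    out = []
    for ch in fmap:
        keep = rng.random() >= p  # one random draw per channel, not per element
        out.append([[v * scale if keep else 0.0 for v in row] for row in ch])
    return out


rng = random.Random(0)
fmap = [[[1.0] * 4 for _ in range(4)] for _ in range(3)]  # 3 channels of 4x4 ones
print(element_dropout(fmap, 0.15, rng)[0])  # zeros scattered inside a channel
print(channel_dropout(fmap, 0.15, rng))     # each channel is all zeros or all scaled
```

With element-wise dropout, neighbouring pixels inside one feature map are dropped independently; with channel dropout, a feature map either survives intact or disappears, which is the property the comment above relies on.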
The Sharpness-Aware Minimization (SAM) optimizer helped a lot in stabilizing the loss curve. Another alternative, which I have not tested, might be AdamW, but it requires tuning the weight decay parameter.
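For reference, SAM replaces a plain gradient step with two gradient evaluations: it first perturbs the weights by ε = ρ·g/‖g‖ towards higher loss, then applies the gradient measured at that perturbed point to the original weights. SAM is not part of `torch.optim`, so the sketch below is a from-scratch, pure-Python illustration on a toy quadratic using plain SGD as the base step; `sam_step`, `lr` and `rho` are hypothetical names and values, not the repo's configuration. (AdamW, by contrast, ships as `torch.optim.AdamW`, whose `weight_decay` argument is the parameter that would need tuning.)

```python
import math


def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One Sharpness-Aware Minimization update on a list of weights (SGD base step)."""
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # avoid division by zero
    # 1) move to the approximate worst-case point within an L2 ball of radius rho
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    # 2) take the gradient there, but apply it to the original weights
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


def quadratic_grad(w):
    """Gradient of the toy loss f(w) = sum(w_i ** 2)."""
    return [2.0 * wi for wi in w]


w = [1.0, -2.0]
for _ in range(100):
    w = sam_step(w, quadratic_grad)
print(w)  # ends up close to the minimum at the origin
```

The extra forward/backward pass per step roughly doubles the cost of each update, which is the usual price paid for SAM's preference for flatter minima.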
For the learning rate policy, I tested "StepLR" and "MultiStepLR" (both with the default settings in the repo) alongside PolynomialLR, and preferred PolynomialLR. Unfortunately, I did not get a chance to try more policies or to experiment extensively with the configuration of the ones I tested.
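The three policies differ only in how they map the epoch (or iteration) counter to a learning rate. As a sketch, here are their closed forms, written to mirror how PyTorch's `StepLR`, `MultiStepLR` and `PolynomialLR` schedulers compute them; the argument values in the usage lines are hypothetical, since the repo's defaults are not quoted in this thread.

```python
from bisect import bisect_right


def step_lr(base_lr, epoch, step_size, gamma=0.1):
    """StepLR: multiply by gamma every `step_size` epochs."""
    return base_lr * gamma ** (epoch // step_size)


def multistep_lr(base_lr, epoch, milestones, gamma=0.1):
    """MultiStepLR: multiply by gamma each time a milestone epoch is reached."""
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


def polynomial_lr(base_lr, epoch, total_iters, power=1.0):
    """PolynomialLR: decay smoothly to zero over `total_iters` epochs."""
    return base_lr * (1.0 - min(epoch, total_iters) / total_iters) ** power


# Hypothetical settings, for comparison only
for epoch in (0, 25, 50, 75, 100):
    print(epoch,
          step_lr(1e-3, epoch, step_size=30),
          multistep_lr(1e-3, epoch, milestones=[40, 80]),
          polynomial_lr(1e-3, epoch, total_iters=100, power=1.0))
```

The step-based policies hold the rate constant and then drop it abruptly, while the polynomial policy lowers it a little every epoch, which is one plausible reason it can give a smoother loss curve.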
For the loss function, I only tested "CrossEntropy" besides "TverskyFocalLoss", and the latter performed much better. TFL has two important parameters: alpha, which controls the trade-off between precision and recall (the higher the value, the more heavily the loss weighs false positives), and gamma, which down-weights easy-to-classify samples and focuses more on harder ones. The chosen value of 0.9 was based on previous experience with land cover mapping using Landsat and, in this context, was only tested against the default value of 1.33.
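As a rough illustration of how the two parameters enter the loss, here is a minimal pure-Python sketch of a soft Tversky focal loss. It assumes the convention described above, where alpha weights false positives and beta = 1 − alpha weights false negatives, and it applies gamma as the exponent 1/gamma on (1 − TI), which is one common formulation. Implementations differ on both points, so this is not necessarily what the repo's `TverskyFocalLoss` does; all function names are hypothetical.

```python
def tversky_index(probs, targets, alpha=0.7, smooth=1e-6):
    """Soft Tversky index for one class; probs and targets are flat per-pixel lists in [0, 1]."""
    beta = 1.0 - alpha  # assumption: beta tied to alpha
    tp = sum(p * t for p, t in zip(probs, targets))
    fp = sum(p * (1.0 - t) for p, t in zip(probs, targets))  # predicted positive, actually negative
    fn = sum((1.0 - p) * t for p, t in zip(probs, targets))  # predicted negative, actually positive
    return (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)


def focal_tversky_loss(probs, targets, alpha=0.7, gamma=4.0 / 3.0):
    """(1 - TI) ** (1 / gamma); gamma = 4/3 (about 1.33) as the default."""
    return (1.0 - tversky_index(probs, targets, alpha)) ** (1.0 / gamma)


def multiclass_focal_tversky_loss(probs_per_class, targets_per_class,
                                  alpha=0.7, gamma=4.0 / 3.0):
    """Average the per-class losses (one probability map and one-hot target map per class)."""
    losses = [focal_tversky_loss(p, t, alpha, gamma)
              for p, t in zip(probs_per_class, targets_per_class)]
    return sum(losses) / len(losses)


# A prediction with false positives: raising alpha penalises them more
probs, targets = [0.9, 0.8, 0.1], [1.0, 0.0, 0.0]
print(focal_tversky_loss(probs, targets, alpha=0.3))
print(focal_tversky_loss(probs, targets, alpha=0.9))
```

Unlike cross-entropy, which scores each pixel independently, this loss is computed from class-level TP/FP/FN totals, so rare classes are not swamped by the majority class.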
There is also the option of adding a simple additive co-attention module on the skip connections of the Unet model. I tested it only once, during the early stages of the training experiments when we were testing a 17-class segmentation, and decided not to use it because the model was not performing well on some of the classes. After we changed the dataset to the current version, I did not get a chance to try it again, as it would require re-optimizing the other hyperparameters as well.
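For readers unfamiliar with the idea, additive attention on a skip connection (as popularised by Attention U-Net) computes a per-pixel coefficient from the encoder's skip features x and a gating signal g from the decoder, α = σ(ψᵀ ReLU(W_x x + W_g g) + b), and multiplies the skip features by α before they are concatenated in the decoder. The pure-Python sketch below shows that computation on per-pixel feature vectors; in a real Unet, W_x, W_g and ψ would be 1×1 convolutions and g would be resampled to the skip resolution. Whether the co-attention variant mentioned above matches this formulation is an assumption, and all names here are hypothetical.

```python
import math


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def additive_attention_gate(skip, gate, w_x, w_g, psi, bias=0.0):
    """Gate skip-connection features pixel by pixel.

    skip, gate: lists of per-pixel feature vectors (same number of pixels).
    Returns the re-weighted skip features and the attention coefficients.
    """
    gated, coeffs = [], []
    for x, g in zip(skip, gate):
        # additive combination of the two projections, followed by ReLU
        hidden = [max(0.0, a + b) for a, b in zip(matvec(w_x, x), matvec(w_g, g))]
        score = sum(p * h for p, h in zip(psi, hidden)) + bias
        alpha = 1.0 / (1.0 + math.exp(-score))  # sigmoid -> coefficient in (0, 1)
        coeffs.append(alpha)
        gated.append([alpha * xi for xi in x])
    return gated, coeffs


# Two pixels with 2-channel skip features and a 2-channel gating signal
identity = [[1.0, 0.0], [0.0, 1.0]]
skip = [[1.0, 0.5], [0.2, 0.1]]
gate = [[1.0, 0.0], [-1.0, 0.0]]
gated, coeffs = additive_attention_gate(skip, gate, identity, identity, psi=[1.0, 1.0])
print(coeffs)  # the first pixel, where the gate reinforces the skip features, gets the larger coefficient
```

Because the gate adds learnable weights to every skip connection, it changes the optimisation landscape, which fits the remark that trying it again would mean re-tuning the other hyperparameters.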
@samKhallaghi many thanks for the details, lots of avenues to explore
From the following config, I assume a fair bit of experimentation was performed to arrive at these parameters. Are you able to shed light on the experiments run? I am seeking to compare Prithvi and Unet, both when typical defaults are used and when optimised. Many thanks