neel-dey / AnyStar

[WACV 2024] AnyStar: Domain randomized universal star-convex 3D instance segmentation
https://arxiv.org/abs/2307.07044
MIT License

Fine-tuning #1


postnubilaphoebus commented 3 months ago

Hello, thank you for providing this repository as well as the model weights. I have some confocal data that AnyStar works reasonably well on out of the box, but I need further improvement, which I hope to achieve through fine-tuning. The official StarDist website claims that fine-tuning is possible, but offers no specific advice on how to do it. Do you have any intuition about how many of the final layers to unfreeze, or what learning rate to use?

neel-dey commented 3 months ago

Hi @postnubilaphoebus (cool handle, btw),

That's great that it works reasonably well out of the box! Before finetuning, I would do a quick inference-time grid search over n_tiles and scale in the predict_instances function to squeeze more performance out of the base model; a sketch is below. The base model works best when the shapes of interest in the prediction tile are sized similarly to the 64^3 crops seen during training (as in Figs. 2 and 5A of the paper) and when the volume is rescaled to be roughly isotropic in 3D.
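For concreteness, here is a minimal sketch of that grid search. The weight directory/name (`models`/`anystar-mix`), `raw_volume`, and `score_fn` (e.g. instance matching against a small annotated crop) are placeholder assumptions, not confirmed names from this repo:

```python
from itertools import product
from csbdeep.utils import normalize
from stardist.models import StarDist3D

# Load the released weights (directory and model name are assumptions here).
model = StarDist3D(None, name="anystar-mix", basedir="models")
img = normalize(raw_volume, 1, 99.8)  # raw_volume: your 3D numpy array (z, y, x)

best = (-1, None, None)
for n_tiles, scale in product([(1, 1, 1), (1, 2, 2), (2, 2, 2)],
                              [0.5, 0.75, 1.0, 1.5, 2.0]):
    labels, _ = model.predict_instances(img, n_tiles=n_tiles, scale=scale)
    s = score_fn(labels)  # placeholder: your quality metric
    if s > best[0]:
        best = (s, n_tiles, scale)
print("best (score, n_tiles, scale):", best)
```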

For finetuning: in biomedical segmentation (as opposed to natural image classification), finetuning all the layers and using aggressive data augmentation usually works best in my experience. I would not focus on how many layers to unfreeze, but instead on the augmentations and the finetuning duration (n_epochs and steps per epoch). I would use a validation set to determine training duration, or at least a small held-out image crop if your finetuning set is limited. For the learning rate, I would just start with the default 2e-4 and see what happens (this is very dataset-specific and hard to predict). A rough sketch of such a setup follows.
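A hedged sketch of what that finetuning setup could look like. The weight paths, the crop arrays (`X_trn`, `Y_trn`, `X_val`, `Y_val`), and the epoch/step counts are placeholders, and it assumes single-channel (z, y, x) crops:

```python
import numpy as np
from stardist.models import StarDist3D

def augmenter(x, y):
    # Random flips along each axis (applied jointly to image and labels),
    # plus multiplicative/additive intensity jitter on the image only.
    for ax in range(x.ndim):
        if np.random.rand() < 0.5:
            x, y = np.flip(x, axis=ax), np.flip(y, axis=ax)
    x = x * np.random.uniform(0.7, 1.3) + np.random.uniform(-0.1, 0.1)
    return x.copy(), y.copy()

model = StarDist3D(None, name="anystar-mix", basedir="models")
model.config.train_learning_rate = 2e-4  # the default suggested above

# All layers stay trainable; duration is a placeholder to tune on validation.
model.train(X_trn, Y_trn, validation_data=(X_val, Y_val),
            augmenter=augmenter, epochs=50, steps_per_epoch=100)
```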

Good luck! Let me know how it goes :)

postnubilaphoebus commented 2 months ago

Hi @neel-dey (Thanks, you're the first one to notice :D),

Can you elaborate on why n_tiles and scale matter for training AnyStar? When I asked a similar question about StarDist, I was told they do not really affect performance: https://github.com/stardist/stardist-napari/issues/24.

Interesting that biomedical segmentation requires unfreezing all layers. So far, I have tried with two 32³ image crops and got the validation loss to go down, but unfortunately the results aren't great at all upon visual inspection (it's almost as if training ruined it). So I will follow your advice: do aggressive image augmentation, annotate images of the right size (64³), and do a random search over hyperparameters. My images also aren't isotropic (2.3 times worse resolution in z), so some nearest-neighbour upscaling will help.

Lastly, how should I handle train_steps_per_epoch? The StarDist code is not entirely clear to me (I come from more of a PyTorch background). It seems that length = epochs * train_steps_per_epoch is passed to some training function, so I don't understand why two variables are needed here.

Thanks again for your swift answer and looking forward to your reply.

neel-dey commented 2 months ago

> Can you elaborate on why n_tiles and scale matter for training AnyStar? When I asked a similar question about StarDist, I was told they do not really affect performance

Not during training; they matter for inference on new datasets, where the target instances need to be scaled similarly to the range seen in the training set. I can't comment on the repo you referenced as this is the first time I've seen it, but in general, DL networks tend to overfit to the image grid size they were trained on (64^3 in the case of the released weights), even with fully convolutional nets. But again, try a bunch of different values at inference with the pretrained weights to see what happens :) The best results I got were when I used scale to make the data roughly isotropic and set n_tiles such that each tile had a similar density of instances to those in Figs. 2 and 5A.

> Interesting that biomedical segmentation requires unfreezing all layers.

It doesn't require it per se; it's just my experience that unfreezing everything and using intense augmentation generally yields better final performance. But other alternatives could be better! There's no canonical, dataset-agnostic answer to this for transfer learning.

> So I will follow your advice: do aggressive image augmentation, annotate images of the right size (64³), and do a random search over hyperparameters. My images also aren't isotropic (2.3 times worse resolution in z), so some nearest-neighbour upscaling will help.

All sounds good. Resampling to isotropic resolution prior to finetuning would be beneficial; I would use linear interpolation for the images and nearest-neighbour for the label maps, e.g. as in the sketch below.
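A tiny sketch of that resampling step, assuming (z, y, x) axis order with z being the 2.3× coarser axis:

```python
from scipy import ndimage

zoom_factors = (2.3, 1, 1)  # upsample z to make voxels roughly isotropic
img_iso = ndimage.zoom(img, zoom_factors, order=1)  # linear for intensities
lbl_iso = ndimage.zoom(lbl, zoom_factors, order=0)  # nearest for label maps
```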

> Lastly, how should I handle train_steps_per_epoch? The StarDist code is not entirely clear to me (I come from more of a PyTorch background). It seems that length = epochs * train_steps_per_epoch is passed to some training function, so I don't understand why two variables are needed here.

Sorry, this isn't clear in the code. Since we're training on potentially infinite synthetic data, the idea of an "epoch" (one pass over a training dataset) doesn't apply here. So yes, length = epochs * train_steps_per_epoch, and train_steps_per_epoch is just the number of steps taken before calculating the validation loss, updating the scheduler (if using one), etc.
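In other words (a small illustration using StarDist's Config3D field names; the counts are arbitrary):

```python
from stardist.models import Config3D

conf = Config3D(axes="ZYX", train_epochs=100, train_steps_per_epoch=50)
total_gradient_steps = conf.train_epochs * conf.train_steps_per_epoch  # 5000
# Validation loss (and any callbacks) are computed every train_steps_per_epoch
# optimizer updates, i.e. at each "epoch" boundary.
```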

postnubilaphoebus commented 2 months ago

Thank you, that clarifies things. I will get back to you when I have more interesting results. In the meantime, do you have any advice on speeding up the NMS? It becomes incredibly slow when the NMS threshold is reduced (I know this problem originates in StarDist and may be hard to solve, but a multi-threaded solution or the like is really needed here).

neel-dey commented 2 months ago

It's been a while, but if I remember correctly, NMS with StarDist was indeed multi-threaded in my experiments. This may be an issue with your installation of CSBDeep/StarDist, or perhaps a platform restriction. I unfortunately can't help diagnose that one, sorry.
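One hedged thing worth trying regardless: NMS runtime scales with the number of candidates that survive prob_thresh, so raising prob_thresh (rather than only lowering nms_thresh) often keeps it tractable. `model` and `img` are as in the earlier sketches; the threshold values are arbitrary placeholders:

```python
labels, details = model.predict_instances(
    img,
    prob_thresh=0.6,  # fewer candidates entering NMS -> faster
    nms_thresh=0.3,
)
```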

Hope the finetuning goes well!

postnubilaphoebus commented 2 months ago

Hello again. I finetuned AnyStar on six 64³ images, with two images in the validation set. Unfortunately, the performance is not quite as good as I had hoped, even after implementing your suggestions and playing around with both the training and inference parameters. In your experience, how many more images would I need? The StarDist website suggests 5-10 images, but maybe it is different for more challenging cases. It takes me about one day to annotate an image, so I need to weigh the pros and cons here. Roughly, my images contain about 200-300 nuclei per 64³ volume. If I am already at the limit, I am afraid I will have to move on to another method.

neel-dey commented 2 months ago

> the performance is not quite as good as I had hoped

Do you mean that it's better or worse after finetuning than the public base AnyStar model, and with what finetuning configuration (number of steps, etc.)? It would be best if you could post your finetuning code to a git repo so I can take a look.

> It takes me about one day to annotate an image, so I need to weigh the pros and cons here.

As in, one day to correct the predictions from AS on your volume, or one day to annotate from scratch? If it's the latter: your original message mentioned that AS was already working reasonably well, so just correcting its predictions will be much faster.

> Roughly, my images contain about 200-300 nuclei per 64³ volume.

Can you post a link to (or a screenshot of) these images, or email a volume and its corresponding AS outputs to dey@csail.mit.edu? I can take a quick look now to see if there's any obvious change to be made, and do a deeper dive if necessary after the NeurIPS deadline next week.