neuropoly / idea-projects

Ideas for cool projects

Learning multi-task segmentation by training on a dataset of nnUNet checkpoints #29

Open naga-karthik opened 2 months ago


For conventional DL-based computer vision tasks, we have a dataset of images and their corresponding labels, and we train a model with a loss function to output whatever we require (classes, segmentation maps, etc.). In the process of training and hyper-parameter optimization, we train hundreds of models, each with its own checkpoints (i.e., .pt files). This paper introduces the idea of training a diffusion model on neural network (NN) checkpoints. The authors first create a dataset of thousands of checkpoints by pre-training networks on an image classification task. Then, using these NN checkpoints as training data, they train a diffusion model to generate optimized network parameters. At test time, given a desired loss value and a set of randomly-initialized parameters, the diffusion model generates a set of optimized NN parameters in one shot, which essentially solves the downstream classification task.
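A minimal sketch of the training side, assuming a standard DDPM-style formulation (linear noise schedule, noise-prediction objective). Checkpoints are represented as dicts of NumPy arrays standing in for a PyTorch `state_dict`. The names `flatten_params`, `unflatten_params`, `noisy_training_example`, `denoising_loss` and the `denoiser` callable (conditioned on the starting parameters and the desired loss) are hypothetical, not the paper's code, and the paper's exact conditioning and parametrization may differ.

```python
import numpy as np


def flatten_params(state_dict):
    """Concatenate all parameter arrays (sorted by name) into one 1-D vector."""
    keys = sorted(state_dict)
    shapes = {k: np.shape(state_dict[k]) for k in keys}
    vec = np.concatenate([np.ravel(state_dict[k]) for k in keys])
    return vec, shapes


def unflatten_params(vec, shapes):
    """Inverse of flatten_params: split the vector back into named arrays."""
    out, offset = {}, 0
    for k in sorted(shapes):
        n = int(np.prod(shapes[k]))
        out[k] = vec[offset:offset + n].reshape(shapes[k])
        offset += n
    return out


def linear_alpha_bar(num_steps, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for a linear DDPM noise schedule (assumed)."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)


def noisy_training_example(theta_target, t, alpha_bar, rng):
    """Forward diffusion q(x_t | x_0): noise the target parameters, return the noise too."""
    eps = rng.standard_normal(theta_target.shape)
    x_t = np.sqrt(alpha_bar[t]) * theta_target + np.sqrt(1.0 - alpha_bar[t]) * eps
    return x_t, eps


def denoising_loss(denoiser, theta_start, theta_target, target_loss, alpha_bar, rng):
    """MSE between the noise predicted by a (hypothetical) conditional denoiser and the true noise."""
    t = int(rng.integers(len(alpha_bar)))
    x_t, eps = noisy_training_example(theta_target, t, alpha_bar, rng)
    eps_pred = denoiser(x_t, t, theta_start, target_loss)
    return float(np.mean((eps_pred - eps) ** 2))
```

In practice the flattened vectors for a full nnUNet are large, so the denoiser would need an architecture suited to long parameter vectors; the sketch only shows the data flow.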

Example: Say the task is spinal cord (SC) segmentation. Typically, we choose an optimizer (e.g., Adam) and train for 1000 epochs. At the end, we obtain a set of (optimized) model parameters, which we then use for inference. In contrast, with the paper's approach, we obtain these optimized parameters straight from the diffusion model (which has been trained on NN checkpoints stored during the pre-training phase of SC segmentation). So, in one step (i.e., one NN update), we could have an SC segmentation model.
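Test-time generation can be sketched as a reverse diffusion loop that starts from noise and is conditioned on randomly-initialized parameters and a target loss. This assumes deterministic DDIM sampling (eta = 0) with a noise-predicting `denoiser`; both the sampler choice and the name `ddim_sample` are assumptions for illustration. The denoiser is still called once per diffusion step, but no gradient step on the segmentation loss is taken.

```python
import numpy as np


def ddim_sample(denoiser, theta_init, target_loss, alpha_bar, rng):
    """Deterministic DDIM sampling (eta = 0) of a flattened parameter vector.

    `denoiser(x_t, t, theta_init, target_loss)` is a hypothetical trained network
    that predicts the noise contained in x_t, conditioned on the randomly
    initialized parameters and the desired loss.
    """
    if len(alpha_bar) == 0:
        raise ValueError("alpha_bar must contain at least one step")
    x = rng.standard_normal(theta_init.shape)  # start from pure noise
    for t in reversed(range(len(alpha_bar))):
        eps = denoiser(x, t, theta_init, target_loss)
        # Current estimate of the clean (optimized) parameters.
        x0_hat = (x - np.sqrt(1.0 - alpha_bar[t]) * eps) / np.sqrt(alpha_bar[t])
        if t == 0:
            return x0_hat
        # Step to the previous noise level along the deterministic DDIM path.
        x = np.sqrt(alpha_bar[t - 1]) * x0_hat + np.sqrt(1.0 - alpha_bar[t - 1]) * eps
```

With an "oracle" denoiser that knows the target parameters, this loop recovers them exactly, which is a handy sanity check before plugging in a trained network.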

How to start: Because we already have nnUNet models for many of our (binary) segmentation tasks, we could pick 2-3 easy tasks and generate, say, 500-1000 checkpoints for each of them (the pre-training phase for storing these checkpoints is described in the paper). The tasks could be SC segmentation, canal segmentation and/or lesion segmentation. Then, using the paper's open-source code, we could train the diffusion model on the nnUNet checkpoints. The idea is that, because the checkpoints come from SC, lesion and canal segmentation, the diffusion model might find a single set of parameters that solves all these tasks at once.
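One possible way to organize the checkpoint dataset for the multi-task setting, loosely following the idea of pairing an earlier checkpoint (the input) with a later one from the same training run (the target). `CheckpointRecord`, its fields, and `build_training_pairs` are hypothetical; how checkpoints are actually sampled and paired is described in the paper and may differ.

```python
from dataclasses import dataclass
from itertools import combinations


@dataclass(frozen=True)
class CheckpointRecord:
    task: str        # hypothetical task tag, e.g. "sc", "canal", "lesion"
    run_id: str      # one nnUNet training run (fold / seed / hyper-parameter set)
    epoch: int       # epoch at which the checkpoint was saved
    path: str        # where the checkpoint file was written
    val_loss: float  # validation loss measured when the checkpoint was saved


def build_training_pairs(records):
    """Pair every checkpoint with every later checkpoint of the same task and run."""
    by_run = {}
    for r in records:
        by_run.setdefault((r.task, r.run_id), []).append(r)
    pairs = []
    for run in by_run.values():
        run.sort(key=lambda r: r.epoch)
        # combinations() keeps list order, so `start` always precedes `future`.
        for start, future in combinations(run, 2):
            pairs.append((start, future))
    return pairs
```

Keeping a `task` field lets a single diffusion model be trained on SC, canal and lesion runs together, while pairs never mix checkpoints from different runs.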

Important: One limitation of the above paper is that the approach merely acts as a weight-space interpolation method, because the diffusion model is bounded by the best checkpoint in its training set (i.e., it cannot generate NN parameters that reach a lower loss or a higher Dice than the best values in the training set). Apart from this major limitation, the concept of training on a dataset of checkpoints seems really cool!
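To make "weight-space interpolation" concrete, here is the naive baseline it refers to: linearly blending two checkpoints and scanning the mixing coefficient. This is not the paper's method; `interpolate_checkpoints`, `best_on_path` and the `evaluate` callable (e.g., validation Dice) are hypothetical. Every candidate lies on the segment between the two input checkpoints, so the search never leaves the region they span.

```python
import numpy as np


def interpolate_checkpoints(sd_a, sd_b, alpha):
    """Linear weight-space interpolation, key by key: (1 - alpha) * A + alpha * B."""
    if sd_a.keys() != sd_b.keys():
        raise ValueError("checkpoints must share the same architecture")
    return {k: (1.0 - alpha) * sd_a[k] + alpha * sd_b[k] for k in sd_a}


def best_on_path(sd_a, sd_b, evaluate, num_points=11):
    """Scan alpha in [0, 1] and keep the interpolant with the highest score."""
    best_alpha, best_score = None, -np.inf
    for alpha in np.linspace(0.0, 1.0, num_points):
        score = evaluate(interpolate_checkpoints(sd_a, sd_b, alpha))
        if score > best_score:
            best_alpha, best_score = float(alpha), score
    return best_alpha, best_score
```

Naively interpolating independently trained networks often degrades performance because of permutation symmetries between their hidden units, so this is a baseline to beat rather than a recipe.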

Why pursue this? While we obtain good segmentation performance with conventional NN training, I think this performance is still limited by the (user-defined) choice of hyperparameters. And we still have a separate model for each task. If the diffusion model is indeed only doing a fancy weight-space interpolation, maybe it can interpolate between our task-specific checkpoints and generate a set of parameters that unifies these tasks, i.e., a single nnUNet that solves all of them?