feevos / resuneta

mxnet source code for the resuneta semantic segmentation models

The code to train the net #3

Closed meule closed 4 years ago

meule commented 4 years ago

Thank you so much for the repo! It's amazing work. Unfortunately, I can't train the multitasking nets. Could you please share the code to train the net?

Thanks in advance

feevos commented 4 years ago

Hi @meule

The code to train the network is standard, as in all mxnet applications. You can find a crash course on mxnet, and examples, here. In particular, for training a network you can see this.

One important thing is to create a loss that is suitable for the list of predictions.

The following works by feeding it a list of predictions `_prediction` (all in the same context) and an `nd.array` of labels `_label`. Note that for this loss to work you need to feed it an appropriate `_label`, and that depends on how you define your dataset, which is a custom operation. Some people may prefer to pass multiple label arguments instead; I simply stack all requested outputs into the same `_label`.

from resuneta.nn.loss.loss import * 

class CustomLoss(object):
    def __init__(self,NClasses=6):
        self.tnmt = Tanimoto_wth_dual()
        self.skip = NClasses 

    def loss(self,_prediction,_label):
        pred_segm  = _prediction[0]
        pred_bound = _prediction[1]
        pred_dists = _prediction[2]

        # HSV colorspace prediction
        pred_color = _prediction[3]

        # Split _label into the individual task labels, so a different loss
        # can be applied to each type of output. _label packs, along the
        # channel axis:
        #   - first NClasses channels: segmentation
        #   - second NClasses channels: boundary
        #   - third NClasses channels: distance transform
        #   - last three channels: the original image in HSV color space
        label_segm  = _label[:,:self.skip,:,:] # segmentation 
        label_bound = _label[:,self.skip:2*self.skip,:,:] # boundaries 
        label_dists = _label[:,2*self.skip:-3,:,:]  # distance transform 

        # color image in HSV format - remember to convert the image to HSV first!
        label_color = _label[:,-3:,:,:] # color in HSV 

        # Apply the SAME Tanimoto-with-complement loss to each task
        loss_segm  = 1.0 - self.tnmt(pred_segm,  label_segm)
        loss_bound = 1.0 - self.tnmt(pred_bound, label_bound)
        loss_dists = 1.0 - self.tnmt(pred_dists, label_dists)
        loss_color = 1.0 - self.tnmt(pred_color, label_color)

        # Divide by 4.0 to keep the output in the range [0,1]
        return (loss_segm + loss_bound + loss_dists + loss_color)/4.0

The training routine, once you've initialized your network and trainer, is "simple" (it gets much more involved once you add checkpointing, monitoring operations and distributed training):

# Define network, dataset, data generator
import mxnet as mx
from mxnet import gluon, autograd
from resuneta.models.resunet_d7_causal_mtskcolor_ddist import *
# modify according to your preferences
Nfilters_init = 32
NClasses = 6
net = ResUNet_d7(Nfilters_init,NClasses)
net.initialize(ctx=mx.gpu()) # parameters must live on the same context as the data
net.hybridize()

# trainer/optimizer - some optimizer of your choice; Adam is recommended
trainer = gluon.Trainer(net.collect_params(), 'adam') # add appropriate parameters

YourLoss = CustomLoss()

# define your custom dataset
dataset = ...# 
datagen = gluon.data.DataLoader(dataset, batch_size=YourBatchSize, shuffle=True) # see https://beta.mxnet.io/api/gluon/_autogen/mxnet.gluon.data.DataLoader.html

for epoch in range(epochs): # Train for as many epochs you want 
    for img, mask in datagen: # assumes a single gpu 
        img  = img.as_in_context(mx.gpu())
        mask = mask.as_in_context(mx.gpu())
        with autograd.record():
            ListOfPredictions = net(img)
            loss = YourLoss.loss(ListOfPredictions, mask)
        loss.backward()
        trainer.step(YourBatchSize)
        # Add here any kind of monitoring you want
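For completeness, the idea behind `Tanimoto_wth_dual` (the Tanimoto coefficient averaged with its value on the complemented inputs, as described in the ResUNet-a paper) can be sketched in plain numpy. This is just an illustration of the formula, not the repo's mxnet implementation:

```python
import numpy as np

def tanimoto(p, l, smooth=1e-5):
    # Tanimoto coefficient T(p,l) = sum(p*l) / (sum(p^2 + l^2) - sum(p*l)),
    # reduced over the spatial axes
    tp  = np.sum(p * l, axis=(-2, -1))
    den = np.sum(p * p + l * l, axis=(-2, -1)) - tp
    return (tp + smooth) / (den + smooth)

def tanimoto_with_dual(p, l):
    # Average of the coefficient on the inputs and on their complements
    return 0.5 * (tanimoto(p, l) + tanimoto(1.0 - p, 1.0 - l))

# A perfect prediction gives similarity 1, hence loss 1 - 1 = 0
l = np.zeros((2, 4, 4), dtype=np.float32)
l[:, :2, :] = 1.0
print(float(np.mean(1.0 - tanimoto_with_dual(l, l))))  # 0.0
```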

For distributed training with mxnet there are a lot of options, depending on the cluster you have at your disposal. You can find a starting point here (I highly recommend going with the horovod approach).

Hope the above helps.