Hi @meule
The code to train the network is standard, as in all mxnet applications. You can find a crash course on mxnet, and examples, here. In particular, for training a network you can see this.
One important thing is to create a loss that is suitable for the list of predictions.
Here is something that works: feed it a list of predictions `_prediction` (all in the same context) and an `nd.array` of labels `_label`. Note that in order for this loss to work you need to feed it an appropriate `_label`, and that depends on how you define your dataset, which is a custom operation. Some people may decide to feed multiple arguments as input; I just stack all requested outputs into the same `_label`.
```python
from resuneta.nn.loss.loss import *

class CustomLoss(object):
    def __init__(self, NClasses=6):
        self.tnmt = Tanimoto_wth_dual()
        self.skip = NClasses

    def loss(self, _prediction, _label):
        pred_segm  = _prediction[0]
        pred_bound = _prediction[1]
        pred_dists = _prediction[2]
        # HSV colorspace prediction
        pred_color = _prediction[3]

        # Here I split _label into the different labels, to apply a different
        # loss function on each type of output. Note that I pack the following
        # information into _label:
        #   first NClasses channels: segmentation
        #   second set: boundary
        #   third set: distance transform
        #   last three channels: the original image in HSV color space
        label_segm  = _label[:, :self.skip, :, :]             # segmentation
        label_bound = _label[:, self.skip:2*self.skip, :, :]  # boundaries
        label_dists = _label[:, 2*self.skip:-3, :, :]         # distance transform
        # color image - HSV format - need to transform to HSV!!
        label_color = _label[:, -3:, :, :]                    # color in HSV

        # Loss for each task, all using the SAME Tanimoto with complement
        loss_segm  = 1.0 - self.tnmt(pred_segm, label_segm)
        loss_bound = 1.0 - self.tnmt(pred_bound, label_bound)
        loss_dists = 1.0 - self.tnmt(pred_dists, label_dists)
        loss_color = 1.0 - self.tnmt(pred_color, label_color)

        # Divide by 4.0 to keep output in range [0, 1]
        return (loss_segm + loss_bound + loss_dists + loss_color) / 4.0
```
The training routine, once you've initialized your network and trainer, is "simple" (it gets much more involved if you want to add checkpointing, monitoring operations, and distributed training):
```python
import mxnet as mx
from mxnet import autograd, gluon

# Define network, dataset, data generator
from resuneta.models.resunet_d7_causal_mtskcolor_ddist import *

# modify according to your preferences
Nfilters_init = 32
NClasses = 6
net = ResUNet_d7(Nfilters_init, NClasses)
net.initialize()
net.hybridize()

# trainer/optimizer: some optimizer of your choice, I recommend Adam.
trainer = gluon.Trainer(net.collect_params(), 'adam')  # add appropriate parameters.
YourLoss = CustomLoss()

# define your custom dataset
dataset = ...  #
# see https://beta.mxnet.io/api/gluon/_autogen/mxnet.gluon.data.DataLoader.html
datagen = gluon.data.DataLoader(dataset, batch_size=YourBatchSize, shuffle=True)

for epoch in range(epochs):  # train for as many epochs as you want
    for img, mask in datagen:  # assumes a single gpu
        img = img.as_in_context(mx.gpu())
        mask = mask.as_in_context(mx.gpu())
        with autograd.record():
            ListOfPredictions = net(img)
            loss = YourLoss.loss(ListOfPredictions, mask)
        loss.backward()
        trainer.step(YourBatchSize)
        # Add here any kind of monitoring you want
```
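For completeness, the Tanimoto-with-complement measure behind `Tanimoto_wth_dual` can be sketched in NumPy. This is my own simplified reading of the idea (no class weighting, single depth), not the repo's implementation: the Tanimoto coefficient is computed on the predictions and on their complements, and the two are averaged.

```python
import numpy as np

def tanimoto(p, l, smooth=1e-5):
    """Tanimoto coefficient per (batch, channel), reducing over spatial dims."""
    num = np.sum(p * l, axis=(2, 3))
    den = np.sum(p * p + l * l, axis=(2, 3)) - num
    return (num + smooth) / (den + smooth)

def tanimoto_with_dual(p, l):
    # Average of the coefficient on (p, l) and on the complements (1-p, 1-l)
    return 0.5 * (tanimoto(p, l) + tanimoto(1.0 - p, 1.0 - l))

# A perfect prediction gives a coefficient of ~1, hence a loss of ~0
p = np.ones((1, 1, 4, 4))
l = np.ones((1, 1, 4, 4))
score = tanimoto_with_dual(p, l)
print(1.0 - score[0, 0])  # ~0.0
```

In the `CustomLoss` above, each task's loss is `1.0 - self.tnmt(pred, label)`, i.e. one minus this averaged coefficient.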
For distributed training with mxnet there are a lot of options, depending on the cluster you have at your disposal. You can find a starting point here (I highly recommend going with the horovod approach).
Hope the above helps.
Thank you so much for the repo! It's amazing work. Unfortunately, I can't train the multitasking nets. Could you please share the code to train the net?
Thanks in advance