xamyzhao / brainstorm

Implementation of "Data augmentation using learned transforms for one-shot medical image segmentation"
MIT License

Graph disconnected when training segmentation network #8

Closed steermomo closed 5 years ago

steermomo commented 5 years ago

Hi, thank you for sharing the code. I ran

python main.py trans --gpu 0 --data mri-100unlabeled --model flow-fwd
python main.py trans --gpu 0 --data mri-100unlabeled --model flow-bck

to get the spatial transform models, edited main.py, and then ran

python main.py trans --gpu 0 --data mri-100unlabeled --model color-unet

to get the appearance/color transform model. I edited the code at lines 265-276 and then trained the segmentation network with

python main.py fss --gpu 0 --data mri-100unlabeled --aug_tm

but got a Graph disconnected error.

After reading the code, I think the problem may be in segmenter_model.py, line 417:

self.color_aug_model = Model(
                        inputs=self.color_aug_model.inputs[:-1],
                        outputs=[
                            self.color_aug_model.outputs[0],
                            self.color_aug_model.get_layer('add_color_delta').output,
                            self.color_aug_model.outputs[2]],
                        name='color_model_wrapper')

Here the input flow_srctotgt is dropped, but outputs[0], which is transformed_out, is still connected to flow_srctotgt via transformed_out = SpatialTransformer(indexing='xy')([transformed_out, flow_srctotgt]) at line 292 of networks.py, so the program crashes.
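The same failure is easy to reproduce in isolation. This toy sketch (hypothetical layer names, not the repo's actual network) builds a functional model whose output still depends on an input that gets dropped from the wrapper Model's input list, just like inputs[:-1] does at line 417:

```python
# Toy reproduction of the "Graph disconnected" error: the output tensor
# 'warped' depends on BOTH inputs, so dropping 'flow_in' from the wrapper
# Model's input list leaves the graph disconnected. Names are made up.
from tensorflow.keras.layers import Add, Input
from tensorflow.keras.models import Model

vol_in = Input(shape=(4,), name='vol')        # stands in for the source volume
flow_in = Input(shape=(4,), name='flow')      # stands in for flow_srctotgt
warped = Add(name='warp')([vol_in, flow_in])  # computed from both inputs

raised = False
try:
    # Keeping an output that uses flow_in while dropping flow_in
    # from the inputs raises ValueError at model construction time.
    Model(inputs=[vol_in], outputs=[warped])
except ValueError:
    raised = True
print(raised)
```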

According to line 576 of segmenter_model.py,

color_delta, colored_vol, _ = self.color_aug_model.predict([source_X, X_colortgt_src, source_contours])

line 417 of segmenter_model.py should be:

self.color_aug_model = Model(
                        inputs=self.color_aug_model.inputs[:-1],
                        outputs=[
                            self.color_aug_model.outputs[1], # edit 0 to 1
                            self.color_aug_model.get_layer('add_color_delta').output,
                            self.color_aug_model.outputs[2]],
                        name='color_model_wrapper')

After changing the code here, I can train the segmentation network.

So, is this the correct way to fix the problem? Thanks! :thumbsup:

#### My environment is:

OS: Ubuntu 18.10 x86_64
Python version: Python 3.6.7 :: Anaconda, Inc.
tensorflow-gpu version: 1.12.0
Keras version: 2.2.4
xamyzhao commented 5 years ago

@steermomo thanks for catching this. You should check the definition of color_delta_unet_model() in networks.py -- the latest version has the model outputs in the order outputs=[transformed_out, color_delta, x_seg], so it's possible that you might need to update line 576 in segmenter_model.py instead. I'll fix this in the main branch.

zhengkang86 commented 4 years ago

@steermomo Hi, I think you need to remove the [:-1] when defining the color_aug_model, like this:

self.color_aug_model = Model(
                        inputs=self.color_aug_model.inputs,
                        outputs=[
                            self.color_aug_model.outputs[0],
                            self.color_aug_model.get_layer('add_color_delta').output,
                            self.color_aug_model.outputs[2]],
                        name='color_model_wrapper')

This is because color_aug_model needs the 4th input (the flow) to produce the augmented image, which is transformed both spatially and in appearance.

And keep this line in the segmenter_model.py as it is:

colored_vol, color_delta, _ = self.color_aug_model.predict([source_X, X_colortgt_src, source_contours, flow])

Because the model built by color_delta_unet_model in networks.py outputs [transformed_out, color_delta, x_seg] in that order.
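A quick way to double-check the unpacking order instead of hardcoding indices is to inspect the model's output names. A toy sketch (hypothetical layer names; in the repo you could inspect self.color_aug_model the same way):

```python
# Sanity-check which position holds which output before unpacking
# predict()'s results. The Lambda layers below are stand-ins whose
# names mirror the outputs of color_delta_unet_model.
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras.models import Model

x = Input(shape=(4,), name='in')
transformed_out = Lambda(lambda t: t + 1.0, name='transformed_out')(x)
color_delta = Lambda(lambda t: t * 2.0, name='color_delta')(x)
x_seg = Lambda(lambda t: t - 1.0, name='x_seg')(x)

m = Model(inputs=x, outputs=[transformed_out, color_delta, x_seg])
# The order of output_names tells you how to unpack predict()'s results.
print(m.output_names)
```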