steermomo closed this issue 5 years ago.
@steermomo thanks for catching this. You should check the definition of `color_delta_unet_model()` in `networks.py`: the latest version has the model outputs in the order `outputs=[transformed_out, color_delta, x_seg]`, so you might need to update line 576 in `segmenter_model.py` instead. I'll fix this in the main branch.
@steermomo Hi, I think you need to remove the `[::-1]` when defining `color_aug_model`, like this:
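```python
from keras.models import Model

# Keep all of the original inputs (including the 4th one) so every output stays connected.
self.color_aug_model = Model(
    inputs=self.color_aug_model.inputs,
    outputs=[
        self.color_aug_model.outputs[0],  # transformed_out
        self.color_aug_model.get_layer('add_color_delta').output,  # output of the 'add_color_delta' layer
        self.color_aug_model.outputs[2],  # x_seg
    ],
    name='color_model_wrapper')
```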
This is because `color_aug_model` needs the 4th input (`flow`) to output the augmented image, which is transformed both spatially and in appearance.
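To see why dropping the 4th input (`flow`) breaks the wrapper, here is a toy, pure-Python sketch of the connectivity check a functional model performs when it is built: every requested output must be computable from the inputs supplied. The graph below is a hypothetical simplification of `color_delta_unet_model` (the dependencies of `x_seg` in particular are assumed), `required_inputs` and `check_connected` are illustrative helpers, and the error text only loosely imitates Keras's `Graph disconnected` message.

```python
# Hypothetical, simplified dependency graph: node -> nodes it is computed from.
# Leaves (empty lists) are model inputs. This is NOT the real color_delta_unet_model.
GRAPH = {
    "source_X": [],
    "X_colortgt_src": [],
    "source_contours": [],
    "flow": [],
    "color_delta": ["source_X", "X_colortgt_src", "source_contours"],
    "add_color_delta": ["source_X", "color_delta"],
    # transformed_out is warped by the flow, as in SpatialTransformer(...)([transformed_out, flow])
    "transformed_out": ["add_color_delta", "flow"],
    "x_seg": ["source_contours", "flow"],  # assumed: contours warped by the same flow
}


def required_inputs(graph, node):
    """Return the set of leaf (input) nodes that `node` depends on."""
    needed, stack, seen = set(), [node], set()
    while stack:
        current = stack.pop()
        if current in seen:
            continue
        seen.add(current)
        parents = graph[current]
        if not parents:
            needed.add(current)
        stack.extend(parents)
    return needed


def check_connected(graph, inputs, outputs):
    """Raise if any output needs an input that the wrapper does not supply."""
    for out in outputs:
        missing = sorted(required_inputs(graph, out) - set(inputs))
        if missing:
            raise ValueError(f"Graph disconnected: {out!r} needs missing input(s) {missing}")


OUTPUTS = ["transformed_out", "add_color_delta", "x_seg"]
ALL_INPUTS = ["source_X", "X_colortgt_src", "source_contours", "flow"]

check_connected(GRAPH, ALL_INPUTS, OUTPUTS)  # passes: every output is reachable
```

Calling `check_connected(GRAPH, ALL_INPUTS[:3], OUTPUTS)`, i.e. without `flow`, raises for `transformed_out`, which mirrors the crash described in the issue: the color delta alone is still reachable, but the warped image is not.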
And keep this line in `segmenter_model.py` as it is:

```python
colored_vol, color_delta, _ = self.color_aug_model.predict([source_X, X_colortgt_src, source_contours, flow])
```
This works because `color_aug_model`, as built by `color_delta_unet_model` in `networks.py`, outputs `[transformed_out, color_delta, x_seg]` in this order.
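The `predict` unpacking in `segmenter_model.py` is positional, so it silently depends on that output order. A minimal sketch of a name-keyed alternative is shown below; `OUTPUT_NAMES`, `name_outputs`, and the placeholder predictions are hypothetical, with strings standing in for the arrays `predict` would return.

```python
# Output order of color_aug_model as stated in this thread (networks.py).
OUTPUT_NAMES = ("transformed_out", "color_delta", "x_seg")


def name_outputs(values, names=OUTPUT_NAMES):
    """Pair each predicted array with its output name, failing loudly on a count mismatch."""
    values = list(values)
    if len(values) != len(names):
        raise ValueError(f"expected {len(names)} outputs, got {len(values)}")
    return dict(zip(names, values))


# Hypothetical stand-in for color_aug_model.predict([...]).
preds = ["warped+recolored volume", "delta", "segmentation"]
named = name_outputs(preds)
colored_vol, color_delta = named["transformed_out"], named["color_delta"]
```

Keying by name keeps the lookup readable, but the names still have to be listed in the same order the model defines its outputs; the mismatch check only catches a wrong count, not a wrong order.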
Hi, thank you for sharing the code. I ran … to get the spatial transform models, edited `main.py`, then ran … to get the appearance/color transform model. I edited the code at lines 265-276 and then trained the segmentation network, but got a `Graph disconnected` error.

After reading the code, I think the problem may be at line 417 of `segmenter_model.py`. There the input `flow_srctotgt` is dropped, but `outputs[0]`, which is `transformed_out`, is still connected to `flow_srctotgt` through `transformed_out = SpatialTransformer(indexing='xy')([transformed_out, flow_srctotgt])` at line 292 of `networks.py`. So the program crashed.

Based on line 576 in `segmenter_model.py`, line 417 in `segmenter_model.py` should be changed to be consistent with it. After changing the code there, I can train the segmentation network.

Is this the correct way to fix the problem? Thanks!

#### My environment is: