neuronets / nobrainer

A framework for developing neural network models for 3D image processing.

SynthSeg Brain Generation fails when randomise_res is set to True #338

Closed hvgazula closed 3 months ago

hvgazula commented 5 months ago

What were you trying to do? Generate a sample brain with SynthSeg (using default parameters).

What did you expect to happen? A sample is generated without raising any error.

What actually happened?

Traceback (most recent call last):
  File "/net/vast-storage/scratch/vast/gablab/hgazula/SynthSeg/scripts/tutorials/1-generation_visualisation.py", line 28, in <module>
    im, lab = brain_generator.generate_brain()
  File "/om2/user/hgazula/SynthSeg/SynthSeg/brain_generator.py", line 324, in generate_brain
    (image, labels) = next(self.brain_generator)
  File "/om2/user/hgazula/SynthSeg/SynthSeg/brain_generator.py", line 319, in _build_brain_generator
    [image, labels] = self.labels_to_image_model.predict(model_inputs)
  File "/om/user/hgazula/venvs/nobrainer_satra/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/tmp/__autograph_generated_filescynv096.py", line 15, in tf__predict_function
    retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
  File "/tmp/__autograph_generated_filedv9fe6xu.py", line 38, in tf__call
    ag__.if_stmt(ag__.not_(ag__.ld(self).add_batchsize), if_body, else_body, get_state, set_state, ('mask', 'self.min_res_tens', 'shape'), 3)
  File "/tmp/__autograph_generated_filedv9fe6xu.py", line 28, in else_body
    ag__.ld(self).min_res_tens = ag__.converted_call(ag__.ld(tf).tile, (ag__.converted_call(ag__.ld(tf).expand_dims, (ag__.ld(self).min_res_tens, 0), None, fscope), ag__.ld(tile_shape)), None, fscope)
ValueError: in user code:

    File "/om/user/hgazula/venvs/nobrainer_satra/lib/python3.10/site-packages/keras/src/engine/training.py", line 2416, in predict_function  *
        return step_function(self, iterator)
    File "/om/user/hgazula/venvs/nobrainer_satra/lib/python3.10/site-packages/keras/src/engine/training.py", line 2401, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/om/user/hgazula/venvs/nobrainer_satra/lib/python3.10/site-packages/keras/src/engine/training.py", line 2389, in run_step  **
        outputs = model.predict_step(data)
    File "/om/user/hgazula/venvs/nobrainer_satra/lib/python3.10/site-packages/keras/src/engine/training.py", line 2357, in predict_step
        return self(x, training=False)
    File "/om/user/hgazula/venvs/nobrainer_satra/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/tmp/__autograph_generated_filedv9fe6xu.py", line 38, in tf__call
        ag__.if_stmt(ag__.not_(ag__.ld(self).add_batchsize), if_body, else_body, get_state, set_state, ('mask', 'self.min_res_tens', 'shape'), 3)
    File "/tmp/__autograph_generated_filedv9fe6xu.py", line 28, in else_body
        ag__.ld(self).min_res_tens = ag__.converted_call(ag__.ld(tf).tile, (ag__.converted_call(ag__.ld(tf).expand_dims, (ag__.ld(self).min_res_tens, 0), None, fscope), ag__.ld(tile_shape)), None, fscope)

    ValueError: Exception encountered when calling layer 'sample_resolution' (type SampleResolution).

    in user code:

        File "/om2/user/hgazula/SynthSeg/ext/lab2im/layers.py", line 608, in call  *
            self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape)

        ValueError: Shape must be rank 3 but is rank 2 for '{{node model/sample_resolution/Tile}} = Tile[T=DT_FLOAT, Tmultiples=DT_INT32](model/sample_resolution/ExpandDims, model/sample_resolution/concat)' with input shapes: [1,?,3], [2].

    Call arguments received by layer 'sample_resolution' (type SampleResolution):
      • inputs=tf.Tensor(shape=(None, 113, 1), dtype=float32)
      • kwargs={'training': 'False'}

Can you replicate the behavior? If yes, how? See github.com/nobrainer_training_scripts/1.2.0/scripts/train/synthseg.py

hvgazula commented 3 months ago

@spikedoanz I'm gathering some info here to decide on a plan of action. Can you please remind me why you chose the SynthSeg in nobrainer (which uses TF 2.15) rather than the original SynthSeg (TF 2.2)? My guess is that this error is an artifact of the newer TF version, which will take some time to figure out.

CC: @sergeyplis

hvgazula commented 3 months ago

Also, it is partially (if not fully 🤷‍♂️ ) my fault for (unintentionally) not informing you that SynthSeg in nobrainer fails with randomise_res=True.

spikedoanz commented 3 months ago

We decided to use the nobrainer version of SynthSeg because it's faster than the original SynthSeg implementation by orders of magnitude: on an A40, the original SynthSeg boots in 20 minutes, while the nobrainer version boots in just 7 seconds, and once it starts generating, the nobrainer version is twice as fast.

For now, though, we'll try to move back to the original SynthSeg wherever we need randomise_res. Thanks so much for the help, Harsha!

hvgazula commented 3 months ago

Okay. That 20 min vs 7 sec gap is enough to take this up seriously. Please use the old one as a stopgap while I look into this.

hvgazula commented 3 months ago

In SampleResolution.call(...), I replaced self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape) with self.min_res_tens_tiled = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape), and replaced every subsequent use of self.min_res_tens with self.min_res_tens_tiled.

Note: In hindsight, it wasn't a good idea to reuse/overwrite a node in the graph; creating a new node is preferable.
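
The note above can be illustrated with a small sketch. It is not the lab2im code: it uses NumPy instead of TensorFlow, and `strict_tile`, `OverwritingLayer`, and `FixedLayer` are hypothetical names invented here. It assumes that `call` runs more than once on the same layer instance, which is one plausible reading of the traceback, where the input to Tile is already [1,?,3] while the multiples have length 2. Under that assumption, overwriting the attribute adds one dimension per call, while writing to a separate attribute leaves the original untouched.

```python
import numpy as np


def strict_tile(x, multiples):
    """Stand-in for tf.tile: require exactly one multiple per input axis."""
    if len(multiples) != x.ndim:
        raise ValueError(
            f"multiples has length {len(multiples)} but input has rank {x.ndim}"
        )
    return np.tile(x, multiples)


class OverwritingLayer:
    """Hypothetical layer that reassigns its own attribute inside call()."""

    def __init__(self, min_res):
        # Per-axis resolution, shape (3,)
        self.min_res_tens = np.asarray(min_res, dtype=np.float32)

    def call(self, batch_size):
        tile_shape = [batch_size, 1]
        # Problem pattern: the tiled (rank-2) result replaces the rank-1
        # attribute, so a second call expands it to rank 3 and tiling fails.
        self.min_res_tens = strict_tile(
            np.expand_dims(self.min_res_tens, 0), tile_shape
        )
        return self.min_res_tens


class FixedLayer(OverwritingLayer):
    """Same layer, but the tiled result goes into a separate attribute."""

    def call(self, batch_size):
        tile_shape = [batch_size, 1]
        # self.min_res_tens keeps its original shape (3,) across calls.
        self.min_res_tens_tiled = strict_tile(
            np.expand_dims(self.min_res_tens, 0), tile_shape
        )
        return self.min_res_tens_tiled


if __name__ == "__main__":
    bad = OverwritingLayer([1.0, 1.0, 1.0])
    print(bad.call(2).shape)  # first call succeeds
    try:
        bad.call(2)  # second call sees the already-tiled attribute
    except ValueError as err:
        print("second call failed:", err)

    good = FixedLayer([1.0, 1.0, 1.0])
    print(good.call(2).shape, good.call(4).shape)
```

The first call to the overwriting layer works, which would explain why the problem only appears once the layer is invoked again, for example during predict. Whether repeated invocation is the actual trigger in nobrainer's SynthSeg has not been confirmed in this thread.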

hvgazula commented 3 months ago

@spikedoanz I "think" I fixed this issue (without extensive testing). I'd appreciate it if you could test it and let me know if any issues arise.

CC: @sergeyplis