BBillot / SynthSeg

Contrast-agnostic segmentation of MRI scans
Apache License 2.0

Error running tutorial script (1-generation_visualisation.py) in newer versions of tensorflow #89

Closed (hvgazula closed 1 month ago)

hvgazula commented 1 month ago

Hi @BBillot

While the officially recommended TensorFlow version (according to requirements.txt) is 2.2, I am trying to run SynthSeg with a more recent version of TensorFlow (a constraint imposed by nobrainer), and it throws the following error. Admittedly, this is the case with almost every version from 2.3 to 2.15. Versions 2.16 and above throw a different error, which is a problem for another day.

Would you be kind enough to spare some time and help me fix this bug? I am happy to do any groundwork (environment setup, a Colab notebook if you prefer) to make it easy for you. I look forward to your thoughts.

Note: The following error was thrown when running SynthSeg/scripts/tutorials/1-generation_visualisation.py with TF 2.15.

Traceback (most recent call last):
  File "/net/vast-storage/scratch/vast/gablab/hgazula/SynthSeg/scripts/tutorials/1-generation_visualisation.py", line 28, in <module>
    im, lab = brain_generator.generate_brain()
  File "/om2/user/hgazula/SynthSeg/SynthSeg/brain_generator.py", line 324, in generate_brain
    (image, labels) = next(self.brain_generator)
  File "/om2/user/hgazula/SynthSeg/SynthSeg/brain_generator.py", line 319, in _build_brain_generator
    [image, labels] = self.labels_to_image_model.predict(model_inputs)
  File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/tmp/__autograph_generated_file9w0hokd_.py", line 15, in tf__predict_function
    retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
  File "/tmp/__autograph_generated_filec3y9yjkg.py", line 38, in tf__call
    ag__.if_stmt(ag__.not_(ag__.ld(self).add_batchsize), if_body, else_body, get_state, set_state, ('mask', 'self.min_res_tens', 'shape'), 3)
  File "/tmp/__autograph_generated_filec3y9yjkg.py", line 28, in else_body
    ag__.ld(self).min_res_tens = ag__.converted_call(ag__.ld(tf).tile, (ag__.converted_call(ag__.ld(tf).expand_dims, (ag__.ld(self).min_res_tens, 0), None, fscope), ag__.ld(tile_shape)), None, fscope)
ValueError: in user code:

    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/training.py", line 2440, in predict_function  *
        return step_function(self, iterator)
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/training.py", line 2425, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/training.py", line 2413, in run_step  **
        outputs = model.predict_step(data)
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/training.py", line 2381, in predict_step
        return self(x, training=False)
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/tmp/__autograph_generated_filec3y9yjkg.py", line 38, in tf__call
        ag__.if_stmt(ag__.not_(ag__.ld(self).add_batchsize), if_body, else_body, get_state, set_state, ('mask', 'self.min_res_tens', 'shape'), 3)
    File "/tmp/__autograph_generated_filec3y9yjkg.py", line 28, in else_body
        ag__.ld(self).min_res_tens = ag__.converted_call(ag__.ld(tf).tile, (ag__.converted_call(ag__.ld(tf).expand_dims, (ag__.ld(self).min_res_tens, 0), None, fscope), ag__.ld(tile_shape)), None, fscope)

    ValueError: Exception encountered when calling layer 'sample_resolution' (type SampleResolution).

    in user code:

        File "/om2/user/hgazula/SynthSeg/ext/lab2im/layers.py", line 608, in call  *
            self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape)

        ValueError: Shape must be rank 3 but is rank 2 for '{{node model/sample_resolution/Tile}} = Tile[T=DT_FLOAT, Tmultiples=DT_INT32](model/sample_resolution/ExpandDims, model/sample_resolution/concat)' with input shapes: [1,?,3], [2].

    Call arguments received by layer 'sample_resolution' (type SampleResolution):
      • inputs=tf.Tensor(shape=(None, 54, 1), dtype=float32)
      • kwargs={'training': 'False'}
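
A note on what the shapes in this error imply: the Tile input already has shape [1,?,3], so self.min_res_tens appears to have been overwritten with a batched (None, 3) tensor by an earlier run of call(...), and this run expanded it a second time. The sketch below models only the shape bookkeeping in plain Python (no TensorFlow). StatefulLayerModel, expand_dims_shape, and tile_shape are hypothetical names used for illustration, and the assumption that call(...) runs more than once on the same layer object is inferred from the traceback, not confirmed by the SynthSeg authors.

```python
def expand_dims_shape(shape, axis=0):
    """Shape after inserting a length-1 axis (models tf.expand_dims)."""
    shape = list(shape)
    shape.insert(axis, 1)
    return tuple(shape)


def tile_shape(shape, multiples):
    """Shape after tiling; like tf.tile, requires one multiple per input axis."""
    if len(shape) != len(multiples):
        raise ValueError(
            f"rank mismatch: input has rank {len(shape)}, "
            f"multiples has length {len(multiples)}"
        )
    # None stands for an unknown (dynamic) dimension such as the batch size.
    return tuple(
        None if d is None or m is None else d * m
        for d, m in zip(shape, multiples)
    )


class StatefulLayerModel:
    """Shape-only model of a layer that overwrites its own attribute in call()."""

    def __init__(self):
        self.min_res_shape = (3,)  # unbatched tensor created in build()

    def call(self):
        multiples = (None, 1)  # (batch, 1), batch size unknown while tracing
        # Mirrors: self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape)
        self.min_res_shape = tile_shape(expand_dims_shape(self.min_res_shape), multiples)
        return self.min_res_shape


layer = StatefulLayerModel()
print(layer.call())  # first run: (3,) -> (1, 3) -> tiled to (None, 3)
try:
    layer.call()     # second run: (None, 3) -> (1, None, 3), rank 3 vs 2 multiples
except ValueError as err:
    print("second call fails:", err)
```

The second run reproduces the [1,?,3] versus [2] mismatch reported above, which points at the attribute assignment inside call(...) rather than at tf.tile itself.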

hvgazula commented 1 month ago

While I managed to get past this step by modifying build(...) and call(...) in SampleResolution, replacing

if input_shape:
    self.add_batchsize = True

self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype='float32')

in build(...) with

self.min_res_tens = tf.convert_to_tensor(self.min_res, dtype='float32')

if input_shape:
    self.add_batchsize = True
    self.min_res_tens = tf.expand_dims(self.min_res_tens, 0)

and subsequently replacing self.min_res_tens = tf.tile(tf.expand_dims(self.min_res_tens, 0), tile_shape) in call(...) with self.min_res_tens = tf.tile(self.min_res_tens, tile_shape), it now throws a new error along the lines of:

Traceback (most recent call last):
  File "/net/vast-storage/scratch/vast/gablab/hgazula/SynthSeg/scripts/misc/synth_202408.py", line 291, in <module>
    [image, labels] = lab_to_im_model.predict(model_inputs)
  File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
TypeError: <tf.Tensor 'sample_resolution/Tile:0' shape=(None, 3) dtype=float32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'sample_resolution/Tile:0' shape=(None, 3) dtype=float32> was defined here:
    File "/net/vast-storage/scratch/vast/gablab/hgazula/SynthSeg/scripts/misc/synth_202408.py", line 247, in <module>
    File "/net/vast-storage/scratch/vast/gablab/hgazula/SynthSeg/scripts/misc/synth_202408.py", line 137, in labels_to_image_model
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/base_layer.py", line 1063, in __call__
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/base_layer.py", line 2593, in _functional_construction_call
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/base_layer.py", line 2439, in _keras_tensor_symbolic_call
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/engine/base_layer.py", line 2498, in _infer_output_signature
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler
    File "/om2/user/hgazula/SynthSeg/ext/lab2im/layers.py", line 715, in call
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1217, in if_stmt
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1270, in _py_if_stmt
    File "/om2/user/hgazula/SynthSeg/ext/lab2im/layers.py", line 733, in call
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/ops/gen_array_ops.py", line 12045, in tile
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 670, in _create_op_internal
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 2652, in _create_op_internal
    File "/om2/user/hgazula/venvs/ss2.15/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1160, in from_node_def

The tensor <tf.Tensor 'sample_resolution/Tile:0' shape=(None, 3) dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=sample_resolution_scratch_graph, id=47607497100160), which is out of scope.
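
Both errors are consistent with call(...) storing a graph tensor on self: the Tile output was created in sample_resolution_scratch_graph during model construction and is then picked up again when the layer is traced for predict. One common way to avoid this class of error, sketched here under assumptions, is to keep the tiled tensor in a local variable so that every trace rebuilds it from plain NumPy data. This is not necessarily the resolution that was eventually adopted. MinResTiler and its min_res argument are hypothetical and only mimic the relevant lines of SampleResolution.

```python
import numpy as np
import tensorflow as tf


class MinResTiler(tf.keras.layers.Layer):
    """Hypothetical layer: returns min_res repeated once per batch element."""

    def __init__(self, min_res, **kwargs):
        super().__init__(**kwargs)
        # Keep plain NumPy data on the layer, never a graph tensor.
        self.min_res = np.asarray(min_res, dtype="float32")

    def call(self, inputs):
        # Rebuild the tensor inside every trace so it belongs to the current graph.
        min_res = tf.convert_to_tensor(self.min_res, dtype="float32")
        batch = tf.shape(inputs)[0]  # dynamic batch size
        # Local result only: nothing is written back to self.
        return tf.tile(tf.expand_dims(min_res, 0), tf.stack([batch, 1]))


layer = MinResTiler([1.0, 1.0, 1.0])
first = layer(tf.zeros((4, 54, 1)))
second = layer(tf.zeros((2, 54, 1)))  # a second call sees no stale state
print(first.shape, second.shape)
```

Because the layer holds no tensors between calls, it does not matter how many times the framework traces call(...).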

I was wondering if you happen to have any ideas. I will also reach out to other Synth users to see if anyone's using the latest version of TF for their work.

hvgazula commented 1 month ago

Fixed. Thank you. See https://github.com/neuronets/nobrainer/issues/338#issuecomment-2282768104 for the resolution. FYI, the fix is backward compatible as well.