juglab / n2v

This is the implementation of Noise2Void training.

M1 Mac: Graph execution error on 3D data #133

Closed thawn closed 1 year ago

thawn commented 1 year ago

Note: I moved this issue over from napari-n2v, because I realized that this also occurs in n2v without napari.

I am trying to train an n2v model on 3D data, following your 3D example notebook.

If I leave the data as 2D (SXY), it works fine (but the result is suboptimal).

However, when I enable 3D, I get an error: InvalidArgumentError: Graph execution error.

The key bit in the debug message seems to be:

INVALID_ARGUMENT:  input must be 4-dimensional[16,16,64,64,32]
 [[{{node model/batch_normalization/FusedBatchNormV3}}]]

I am on an M1 Max Mac. Versions of the key packages:

- Python 3.9.15
- n2v 0.3.2
- tensorflow 2.10
- tensorflow-metal 0.6

P.S.: A colleague just confirmed that it works fine on his Windows notebook, so it seems to be a Mac-specific issue.
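For reference, this is roughly how the training is set up (a minimal sketch along the lines of the 3D example notebook; the data directory and train/validation split are placeholders, and the config values correspond to what shows up in the log below):

```python
from n2v.models import N2VConfig, N2V
from n2v.internals.N2V_DataGenerator import N2V_DataGenerator

# Load the 3D stacks and cut them into (Z, Y, X) = (16, 64, 64) patches,
# as in the 3D example notebook. "data/" is a placeholder path.
datagen = N2V_DataGenerator()
imgs = datagen.load_imgs_from_directory(directory="data/", dims='ZYX')
patches = datagen.generate_patches_from_list(imgs, shape=(16, 64, 64))

X = patches[:-5]      # training patches
X_val = patches[-5:]  # a handful of validation patches (placeholder split)

config = N2VConfig(X, unet_kern_size=3,
                   train_steps_per_epoch=200, train_epochs=30, train_loss='mse',
                   batch_norm=True, train_batch_size=16, n2v_perc_pix=0.198,
                   n2v_patch_shape=(16, 64, 64),
                   n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5)

model = N2V(config, 'n2v_3D', basedir='models')
history = model.train(X, X_val)  # this is where the graph execution error is raised
```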

Click for full debug log (abridged):

2023-01-05 18:47:08.790790: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2023-01-05 18:47:09.818930: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.

InvalidArgumentError: Graph execution error:
Detected at node 'model/batch_normalization/FusedBatchNormV3' defined at (most recent call last):
  File ".../napari_n2v/utils/training_worker.py", line 308, in train
    model.train(X_patches, X_val_patches)
  File ".../n2v/models/n2v_standard.py", line 265, in train
    history = self.keras_model.fit(iter(training_data), validation_data=(validation_X, validation_Y), ...)
  File ".../keras/engine/training.py", line 993, in train_step
    y_pred = self(x, training=True)
  File ".../keras/layers/normalization/batch_normalization.py", line 850, in call
    outputs = self._fused_batch_norm(inputs, training=training)
  File ".../keras/layers/normalization/batch_normalization.py", line 634, in _fused_batch_norm_training
    return tf.compat.v1.nn.fused_batch_norm(...)

Shapes: X_patches (3195, 16, 64, 64, 1), X_val_patches (5, 16, 64, 64, 1)
Model: N2V(n2v_3D): ZYXC -> ZYXC, directory /Users/korten/.napari/N2V/models/n2v_3D
Config: N2VConfig(n_dim=3, axes='ZYXC', n_channel_in=1, n_channel_out=1, unet_n_depth=2,
  unet_kern_size=3, unet_n_first=32, unet_last_activation='linear', train_loss='mse',
  train_epochs=30, train_steps_per_epoch=200, train_learning_rate=0.0004, train_batch_size=16,
  batch_norm=True, n2v_perc_pix=0.198, n2v_patch_shape=[16, 64, 64],
  n2v_manipulator='uniform_withCP', n2v_neighborhood_radius=5, single_net_per_channel=True)

2 root error(s) found.
  (0) INVALID_ARGUMENT: input must be 4-dimensional[16,16,64,64,32]
      [[{{node model/batch_normalization/FusedBatchNormV3}}]]
      [[gradient_tape/model/batch_normalization_4/FusedBatchNormGradV3/_142]]
  (1) INVALID_ARGUMENT: input must be 4-dimensional[16,16,64,64,32]
      [[{{node model/batch_normalization/FusedBatchNormV3}}]]
0 successful operations. 0 derived errors ignored. [Op:__inference_train_function_6391]
thawn commented 1 year ago

n2v.ipynb.zip

Here is a minimal example notebook.

Note that the shapes in this notebook are a little different, because I was following your 3D example.

jdeschamps commented 1 year ago

Thanks for reporting this!

I have no experience with macOS, so I will look around to see if someone has a machine we could use to debug this. I think we might be able to have a look next week.

thawn commented 1 year ago

Let me know if I can help with the debugging.

veegalinova commented 1 year ago

Hi, this looks like an M1-specific issue. You can try downgrading to tensorflow 2.9 and tensorflow-metal 0.5.0, or to any other combination of versions listed on this page.
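For completeness, a quick way to verify which versions are actually active in the environment after the downgrade (a small sketch; importlib.metadata simply reads the installed package metadata):

```python
import importlib.metadata as md
import tensorflow as tf

# Print the versions relevant to this issue; the suggested combination
# above is tensorflow 2.9 with tensorflow-metal 0.5.0.
print("tensorflow      :", tf.__version__)
print("tensorflow-metal:", md.version("tensorflow-metal"))
print("n2v             :", md.version("n2v"))
```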