rs-station / careless

Merge X-ray diffraction data with Wilson's priors, variational inference, and metadata
MIT License

adding double wilson flags leads to an error in version 0.4.1 #150

Closed: hkwang closed this issue 7 months ago

hkwang commented 8 months ago

I get the following error when running a double-wilson careless job. The same script ran on careless 0.3.8 in the past but now fails on careless 0.4.1. It also runs fine when I remove the double-wilson flags:

  --double-wilson-parents=None,0,0,0,0 \
  --double-wilson-r=0,0.993,0.993,0.993,0.993 \
Careless version 0.4.1
Training:   0%|                                                                                                            | 0/9000 [00:10<?, ?it/s]
Traceback (most recent call last):
  File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/bin/careless", line 8, in <module>
    sys.exit(main())
  File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/careless.py", line 9, in main
    run_careless(parser)
  File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/careless.py", line 53, in run_careless
    history = model.train_model(
  File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/models/merging/variational.py", line 250, in train_model
    _history = train_step((self, data))
  File "/n/home01/hwang6/.local/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/n/home01/hwang6/.local/lib/python3.10/site-packages/tensorflow/python/eager/execute.py", line 52, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'variational_merging_model/GatherV2_1' defined at (most recent call last):
    File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/bin/careless", line 8, in <module>
      sys.exit(main())
    File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/careless.py", line 9, in main
      run_careless(parser)
    File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/careless.py", line 53, in run_careless
      history = model.train_model(
    File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/models/merging/variational.py", line 250, in train_model
      _history = train_step((self, data))
    File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/models/merging/variational.py", line 230, in train_step
      history = model.train_step_with_gradient_norm((data,))
    File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/models/merging/variational.py", line 195, in train_step_with_gradient_norm
      y_pred = self(x, training=True)
    File "/n/home01/hwang6/.local/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/n/home01/hwang6/.local/lib/python3.10/site-packages/keras/engine/training.py", line 558, in __call__
      return super().__call__(*args, **kwargs)
    File "/n/home01/hwang6/.local/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/n/home01/hwang6/.local/lib/python3.10/site-packages/keras/engine/base_layer.py", line 1145, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/n/home01/hwang6/.local/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/models/merging/variational.py", line 173, in call
      if self.kl_weight is None:
    File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/models/merging/variational.py", line 174, in call
      self.add_kl_div(self.surrogate_posterior, self.prior, z_f, name='F KLDiv', reduction='sum')
    File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/models/merging/variational.py", line 129, in add_kl_div
      kl_div = posterior.log_prob(samples) - prior.log_prob(samples)
    File "/n/home01/hwang6/mambaforge/envs/tf2.12_cuda11/lib/python3.10/site-packages/careless/models/priors/wilson.py", line 139, in log_prob
      z_parent = tf.gather(z, self.reflids, axis=-1)
Node: 'variational_merging_model/GatherV2_1'
indices[121201] = -1 is not in [0, 161578)
         [[{{node variational_merging_model/GatherV2_1}}]] [Op:__inference_train_step_12814]
Job 15651338:careless finished on holy7c24201.rc.fas.harvard.edu in 0 minutes.
kmdalton commented 8 months ago

@hkwang, DM me with the location of the input files and the careless script.

kmdalton commented 7 months ago

This is a bug related to the behavior of tf.gather when indices are out of bounds. It affects models trained on CPU using harmonic deconvolution and the double-wilson prior. In the log_prob calculation, the double-wilson (DW) prior uses tf.gather to look up samples from the "parent" of each node, and it caches the indices for this lookup in an array, self.reflids. If a node has no parent, or if a particular reflection is observed in the child but not in the parent, the corresponding entry in this array is -1. tf.gather expects non-negative, zero-based indices, so -1 is technically not a valid index. However, tf.gather behaves differently on CPU and GPU. On GPU, gathering with index -1 simply returns 0, which is the desired outcome anyway. On CPU, TensorFlow validates the indices and raises an error, which crashes the job.
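To make the role of -1 concrete, here is a minimal sketch of how a parent-lookup index array of this kind can be built from Miller indices. `build_parent_lookup` and the (h, k, l) row layout are hypothetical, chosen for illustration; this is not careless's actual construction of `self.reflids`.

```python
import numpy as np

def build_parent_lookup(child_hkl, parent_hkl):
    """Hypothetical helper: for each child reflection, return the row index
    of the same Miller index in the parent, or -1 if the parent lacks it.
    Passing parent_hkl=None models a node with no parent (all -1)."""
    if parent_hkl is None:
        # A root node has no parent, so every lookup is invalid.
        return np.full(len(child_hkl), -1, dtype=np.int64)
    # Map each parent Miller index (h, k, l) to its row position.
    parent_pos = {tuple(hkl): i for i, hkl in enumerate(parent_hkl)}
    # -1 marks reflections observed in the child but not in the parent.
    return np.array(
        [parent_pos.get(tuple(hkl), -1) for hkl in child_hkl],
        dtype=np.int64,
    )

parent = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
child = np.array([[0, 1, 0], [2, 0, 0], [1, 0, 0]])
print(build_parent_lookup(child, parent))  # [ 1 -1  0]
print(build_parent_lookup(child, None))    # [-1 -1 -1]
```

Any array built this way contains -1 entries whenever the child has reflections the parent lacks, which is exactly the kind of index that tf.gather rejects on CPU.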

kmdalton commented 7 months ago

From the TensorFlow docs for tf.gather:

Caution: On CPU, if an out of bound index is found, an error is raised. On GPU, if an out of bound index is found, a 0 is stored in the corresponding output value.
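One device-independent workaround is sketched below, assuming the desired result is the GPU behavior quoted above (0 for an out-of-bound index): mask the invalid indices, gather with a clamped copy so the gather itself cannot fail, then zero the masked positions. It is written in NumPy so it runs without TensorFlow; `gather_or_zero` is a hypothetical name, and this is not necessarily how careless fixed the bug. In TensorFlow the same pattern could be expressed with `tf.where` and `tf.gather`.

```python
import numpy as np

def gather_or_zero(values, indices):
    """Hypothetical helper: gather values[..., indices] along the last axis,
    writing 0 wherever an index is out of bounds (e.g. -1)."""
    values = np.asarray(values)
    indices = np.asarray(indices)
    n = values.shape[-1]
    if n == 0:
        # Nothing to gather from: every index is out of bounds.
        return np.zeros(values.shape[:-1] + indices.shape, dtype=values.dtype)
    valid = (indices >= 0) & (indices < n)
    # Clamp invalid entries to a legal index so the gather cannot fail.
    safe = np.where(valid, indices, 0)
    gathered = np.take(values, safe, axis=-1)
    # Replace the placeholder results at invalid positions with zeros.
    return np.where(valid, gathered, 0)

z = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # e.g. 2 samples x 3 reflections
idx = np.array([2, -1, 0])
print(gather_or_zero(z, idx))      # [[3. 0. 1.]
                                   #  [6. 0. 4.]]
print(np.take(z, idx, axis=-1))    # [[3. 3. 1.]
                                   #  [6. 6. 4.]]
```

The explicit mask matters in NumPy as well: `np.take` treats -1 as "last element" rather than raising an error, so a naive port would silently read the wrong reflection instead of returning 0.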