CosmoStat / wf-psf

Data-driven wavefront-based PSF modelling framework.
MIT License
19 stars 9 forks source link

Error in "predict_opd" function with Random Positions #137

Closed nadamoukaddem closed 1 month ago

nadamoukaddem commented 1 month ago

I am getting an error from the `predict_opd` function, which predicts the OPD at given positions. It works fine at the positions of the training and test stars, but it fails when I call it at an arbitrary (random) position, even though the function should be able to predict the OPD at any position. I am running it in a Jupyter Notebook on a Slurm node with a TensorFlow 2.9.1 (Python 3.10) kernel. The error is the following:

`2024-07-16 14:42:16.383127: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at strided_slice_op.cc:102 : INVALID_ARGUMENT: slice index 0 of dimension 0 out of bounds.


InvalidArgumentError Traceback (most recent call last) Input In [10], in <cell line: 7>() 1 ###### input_0 = tf.convert_to_tensor(np.array([[93.2313 , 247.74956]]),dtype=tf.float32) 2 # psf_model_after.predict_opd(input_0) 3 #psf_model_before.predict_zernikes(tf.convert_to_tensor(np.array([[595.4554 , 376.15588 ]]),dtype=tf.float32)) 4 #psf_model_before.predict_zernikes(tf.convert_to_tensor(np.array([[596.4554 , 1 ]]),dtype=tf.float32)) 6 input_0 = tf.convert_to_tensor(np.array([[773.59656 , 261.25967 ]]), dtype=tf.float32) ----> 7 psf_model_before.predict_opd(input_0)

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/wf_psf/psf_models/psf_model_physical_polychromatic.py:508, in TFPhysicalPolychromaticField.predict_opd(self, input_positions) 494 """Predict the OPD at some positions. 495 496 Parameters (...) 505 506 """ 507 # Predict zernikes from parametric model and physical layer --> 508 zks_coeffs = self.predict_zernikes(input_positions) 509 # Propagate to obtain the OPD 510 param_opd_maps = self.tf_zernike_OPD(zks_coeffs)

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/wf_psf/psf_models/psf_model_physical_polychromatic.py:581, in TFPhysicalPolychromaticField.predict_zernikes(self, input_positions) 578 zernike_params = self.tf_poly_Z_field(input_positions) 580 # Calculate the prediction from the physical layer --> 581 physical_layer_prediction = self.tf_physical_layer.predict(input_positions) 583 # Pad and sum the Zernike coefficients 584 padded_zernike_params, padded_physical_layer_prediction = self.pad_zernikes( 585 zernike_params, physical_layer_prediction 586 )

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/wf_psf/psf_models/tf_layers.py:968, in TFPhysicalLayer.call(self, positions) 965 return tf.where(tf.equal(self.obs_pos, idx_pos))[0, 0] 967 # Calculate the indices of the input batch --> 968 indices = tf.map_fn(calc_index, positions, fn_output_signature=tf.int64) 969 # Recover the prior zernikes from the batch indexes 970 batch_zks = tf.gather(self.zks_prior, indices=indices, axis=0, batch_dims=0)

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py:629, in deprecated_arg_values.<locals>.deprecated_wrapper.<locals>.new_func(*args, **kwargs) 622 _PRINTED_WARNING[(func, arg_name)] = True 623 logging.warning( 624 'From %s: calling %s (from %s) with %s=%s is deprecated and ' 625 'will be removed %s.\nInstructions for updating:\n%s', 626 _call_location(), decorator_utils.get_qualified_name(func), 627 func.__module__, arg_name, arg_value, 'in a future version' 628 if date is None else ('after %s' % date), instructions) --> 629 return func(*args, **kwargs)

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py:561, in deprecated_args.<locals>.deprecated_wrapper.<locals>.new_func(*args, **kwargs) 553 _PRINTED_WARNING[(func, arg_name)] = True 554 logging.warning( 555 'From %s: calling %s (from %s) with %s is deprecated and will ' 556 'be removed %s.\nInstructions for updating:\n%s', (...) 559 'in a future version' if date is None else ('after %s' % date), 560 instructions) --> 561 return func(*args, **kwargs)

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py:637, in map_fn_v2(fn, elems, dtype, parallel_iterations, back_prop, swap_memory, infer_shape, name, fn_output_signature) 635 if fn_output_signature is None: 636 fn_output_signature = dtype --> 637 return map_fn( 638 fn=fn, 639 elems=elems, 640 fn_output_signature=fn_output_signature, 641 parallel_iterations=parallel_iterations, 642 back_prop=back_prop, 643 swap_memory=swap_memory, 644 infer_shape=infer_shape, 645 name=name)

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py:561, in deprecated_args.<locals>.deprecated_wrapper.<locals>.new_func(*args, **kwargs) 553 _PRINTED_WARNING[(func, arg_name)] = True 554 logging.warning( 555 'From %s: calling %s (from %s) with %s is deprecated and will ' 556 'be removed %s.\nInstructions for updating:\n%s', (...) 559 'in a future version' if date is None else ('after %s' % date), 560 instructions) --> 561 return func(*args, **kwargs)

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py:495, in map_fn(fn, elems, dtype, parallel_iterations, back_prop, swap_memory, infer_shape, name, fn_output_signature) 490 tas = [ 491 ta.write(i, value) for (ta, value) in zip(tas, result_value_batchable) 492 ] 493 return (i + 1, tas) --> 495 _, r_a = control_flow_ops.while_loop( 496 lambda i, _: i < n, 497 compute, (i, result_batchable_ta), 498 parallel_iterations=parallel_iterations, 499 back_prop=back_prop, 500 swap_memory=swap_memory, 501 maximum_iterations=n) 502 result_batchable = [r.stack() for r in r_a] 504 # Update each output tensor w/ static shape info about the outer dimension.

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/tensorflow/python/ops/control_flow_ops.py:2754, in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure) 2751 loop_var_structure = nest.map_structure(type_spec.type_spec_from_value, 2752 list(loop_vars)) 2753 while cond(*loop_vars): -> 2754 loop_vars = body(*loop_vars) 2755 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)): 2756 packed = True

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/tensorflow/python/ops/control_flow_ops.py:2745, in while_loop.<locals>.<lambda>(i, lv) 2742 loop_vars = (counter, loop_vars) 2743 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 2744 math_ops.logical_and(i < maximum_iterations, orig_cond(*lv))) -> 2745 body = lambda i, lv: (i + 1, orig_body(*lv)) 2746 try_to_pack = False 2748 if executing_eagerly:

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/tensorflow/python/ops/map_fn.py:485, in map_fn.<locals>.compute(i, tas) 483 ag_ctx = autograph_ctx.control_status_ctx() 484 autographed_fn = autograph.tf_convert(fn, ag_ctx) --> 485 result_value = autographed_fn(elems_value) 486 nest.assert_same_structure(fn_output_signature or elems, result_value) 487 result_value_flat = nest.flatten(result_value)

File /gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py:692, in convert.<locals>.decorator.<locals>.wrapper(*args, **kwargs) 690 except Exception as e: # pylint:disable=broad-except 691 if hasattr(e, 'ag_error_metadata'): --> 692 raise e.ag_error_metadata.to_exception(e) 693 else: 694 raise

InvalidArgumentError: in user code:

File "/gpfswork/rech/ynx/uch76qv/.local/lib/python3.10/site-packages/wf_psf/psf_models/tf_layers.py", line 965, in calc_index  *
    return tf.where(tf.equal(self.obs_pos, idx_pos))[0, 0]

InvalidArgumentError: slice index 0 of dimension 0 out of bounds. [Op:StridedSlice] name: strided_slice/`
nadamoukaddem commented 1 month ago

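Ezequiel reproduced this error and solved it by setting the centroid correction option to false in the training configuration file before calling `predict_opd`. For reference, the traceback shows where the failure originates: `TFPhysicalLayer.call` looks up each input position in `self.obs_pos` with `tf.where(tf.equal(self.obs_pos, idx_pos))[0, 0]`. A position with no matching entry in `self.obs_pos` therefore produces an empty `tf.where` result, and indexing that result with `[0, 0]` raises the "slice index 0 of dimension 0 out of bounds" error.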