While testing different momentum distributions for the `PreconditionedHamiltonianMonteCarlo` kernel, I found that it is impossible to run a chain under XLA with the momentum distribution `MultivariateNormalDiagPlusLowRank`.
Here is a simple script that reproduces the error:
#!/usr/bin/env python3
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import experimental as tfpe
from tensorflow_probability import distributions as tfd
dtype = "float32"

##############################
# Building a covariance matrix
##############################
# Seeded generator so every run of this reproduction builds the same matrix.
rng = np.random.default_rng(0)
M = np.clip(rng.normal(0, 0.25, (100, 100)), -0.9, 0.9)
# Symmetrize, then put ones on the diagonal.
M = 0.5 * (M + M.T)
np.fill_diagonal(M, 1.0)
# We make sure it's positive definite. M is symmetric, so use eigvalsh, which
# always returns real eigenvalues (np.linalg.eigvals may return a complex
# array for a real matrix).
eigvals = np.linalg.eigvalsh(M)
if eigvals.min() < 0:
    # Shifting the spectrum by -1.1 * min eigenvalue leaves the smallest
    # eigenvalue at -0.1 * min eigenvalue, which is > 0.
    M -= 1.1 * eigvals.min() * np.eye(100)
M = tf.constant(M, dtype)

###############################################
# Creating the different momentum distributions
###############################################
L = tf.linalg.cholesky(M)

# Full covariance matrix (M = L @ L^T).
true_dist = tfd.MultivariateNormalTriL(
    loc=tf.zeros(100, dtype), scale_tril=L, validate_args=True
)

# Just the diagonal. `scale_diag` holds standard deviations (the covariance
# is diag(scale_diag)**2), so take the square root of the variances on the
# diagonal of M rather than the variances themselves.
diag_dist = tfd.MultivariateNormalDiag(
    loc=tf.zeros(100, dtype),
    scale_diag=tf.sqrt(tf.linalg.diag_part(M)),
    validate_args=True,
)

# A low rank approximation
# The low rank approximation is on the scale (M = L @ L^T and L = D + U @ U^T)
# NOTE(review): L is the lower-triangular Cholesky factor, so it is not
# symmetric and cannot be written exactly as D + U @ U^T. As far as I know,
# tfp.math.pivoted_cholesky expects a symmetric positive-definite matrix, so
# this approximation is presumably rough. It is kept as-is because this
# script only needs to reproduce the XLA failure; verify before reusing it.
U = tfp.math.pivoted_cholesky(L, 10)
L_approx = U @ tf.transpose(U)
D = tf.linalg.diag_part(L) - tf.linalg.diag_part(L_approx)  # correcting for the diagonal
lr_dist = tfd.MultivariateNormalDiagPlusLowRank(
    loc=tf.zeros(100, dtype), scale_diag=D, scale_perturb_factor=U, validate_args=True
)
##################################################
# Run the kernel for a given momentum distribution
##################################################
@tf.function(jit_compile=True)
def _run(dist, num_results=100, num_burnin_steps=100):
    """Sample `true_dist` with preconditioned HMC, using `dist` as momentum.

    The whole chain is XLA-compiled (`jit_compile=True`); this is the setting
    under which the low-rank momentum distribution fails in this report.

    Args:
      dist: momentum distribution passed to
        `PreconditionedHamiltonianMonteCarlo`.
      num_results: number of samples kept after burn-in.
      num_burnin_steps: number of burn-in steps; step-size adaptation runs
        for the first 80% of them.

    Returns:
      A `(chain, ess)` tuple: the list of sampled state parts and the
      effective sample size of the first state part.
    """
    chain, _ = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=[tf.zeros(100, dtype)],
        kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
            tfpe.mcmc.PreconditionedHamiltonianMonteCarlo(
                target_log_prob_fn=true_dist.log_prob,
                step_size=0.1,
                num_leapfrog_steps=8,
                momentum_distribution=dist,
                store_parameters_in_results=True,
            ),
            # Derived from burn-in so adaptation never extends into the
            # kept samples when num_burnin_steps changes.
            num_adaptation_steps=int(num_burnin_steps * 0.8),
        ),
        # No trace is needed; return a cheap dummy value.
        trace_fn=lambda _state, _kernel_results: tf.constant(0.0, dtype),
    )
    ess = tfp.mcmc.effective_sample_size(chain)[0]
    return chain, ess
def _announce_and_run(label, dist):
    """Print `label`, then run the chain with `dist` as momentum distribution."""
    print(label)
    return _run(dist)


# The full-covariance and diagonal momentum distributions run fine.
chain_true_dist, ess_true_dist = _announce_and_run("Chain with true mass matrix", true_dist)
chain_diag_dist, ess_diag_dist = _announce_and_run("Chain with diagonal mass matrix", diag_dist)
# The low-rank momentum distribution runs with jit_compile=False but fails
# with jit_compile=True.
chain_lr_dist, ess_lr_dist = _announce_and_run("Chain with low rank mass matrix", lr_dist)
The output is incredibly long, but here is an extract:
InvalidArgumentError Traceback (most recent call last)
~/scripts/tests/jit_det_bug_report.py in <module>
70 chain_diag_dist, ess_diag_dist = _run(diag_dist)
71 print("Chain with low rank mass matrix")
---> 72 chain_lr_dist, ess_lr_dist = _run(lr_dist)
~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
887
888 with OptionalXlaContext(self._jit_compile):
--> 889 result = self._call(*args, **kwds)
890
891 new_tracing_count = self.experimental_get_tracing_count()
~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
922 # In this case we have not created variables on the first call. So we can
923 # run the first trace but we should fail if variables are created.
--> 924 results = self._stateful_fn(*args, **kwds)
925 if self._created_variables:
926 raise ValueError("Creating variables on a non-first call to a function"
~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
3021 (graph_function,
3022 filtered_flat_args) = self._maybe_define_function(args, kwargs)
-> 3023 return graph_function._call_flat(
3024 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
3025
~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1958 and executing_eagerly):
1959 # No tape is watching; skip to running the function.
-> 1960 return self._build_call_outputs(self._inference_function.call(
1961 ctx, args, cancellation_manager=cancellation_manager))
1962 forward_backward = self._select_forward_and_backward_functions(
~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
589 with _InterpolateFunctionError(self):
590 if cancellation_manager is None:
--> 591 outputs = execute.execute(
592 str(self.signature.name),
593 num_outputs=self._num_outputs,
~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
57 try:
58 ctx.ensure_initialized()
---> 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
InvalidArgumentError: Detected unsupported operations when trying to compile graph mcmc_sample_chain_trace_scan_while_smart_for_loop_while_body_5806_const_0[] on XLA_GPU_JIT: MatrixDeterminant (No registered 'MatrixDeterminant' OpKernel for XLA_GPU_JIT devices compatible with node {{node mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/mh_one_step/phmc_kernel_one_step/compute_log_acceptance_correction/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_dual_averaging_step_size_adaptation___init____one_step_mh_one_step_phmc_kernel_one_step_JointDistributionSequential/log_prob/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_dual_averaging_step_size_adaptation___init____one_step_mh_one_step_phmc_kernel_one_step_BatchBroadcastmcmc_sample_chain_trace_scan_while_smart_for_loop_while_BatchBroadcastMultivariateNormalDiagPlusLowRank_4/log_prob/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_BatchBroadcastMultivariateNormalDiagPlusLowRank_4/log_prob/MultivariateNormalDiagPlusLowRank/log_prob/chain_of_shift_of_scale_matvec_linear_operator_2/inverse_log_det_jacobian/scale_matvec_linear_operator/inverse_log_det_jacobian/LinearOperatorLowRankUpdate/log_abs_det/MatrixDeterminant}}){{node 
mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/mh_one_step/phmc_kernel_one_step/compute_log_acceptance_correction/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_dual_averaging_step_size_adaptation___init____one_step_mh_one_step_phmc_kernel_one_step_JointDistributionSequential/log_prob/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_dual_averaging_step_size_adaptation___init____one_step_mh_one_step_phmc_kernel_one_step_BatchBroadcastmcmc_sample_chain_trace_scan_while_smart_for_loop_while_BatchBroadcastMultivariateNormalDiagPlusLowRank_4/log_prob/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_BatchBroadcastMultivariateNormalDiagPlusLowRank_4/log_prob/MultivariateNormalDiagPlusLowRank/log_prob/chain_of_shift_of_scale_matvec_linear_operator_2/inverse_log_det_jacobian/scale_matvec_linear_operator/inverse_log_det_jacobian/LinearOperatorLowRankUpdate/log_abs_det/MatrixDeterminant}}
@SiegeLordEx Could you take a look at this? The underlying issue no longer triggers (I believe there is now an XLA registration for the op), but I still get an error from tfp.mcmc.
While testing different momentum distributions for the `PreconditionedHamiltonianMonteCarlo` kernel, I found that it is impossible to run a chain under XLA with the momentum distribution `MultivariateNormalDiagPlusLowRank`.
Here is a simple script that reproduces the error:
The output is incredibly long, but here is an extract:
My packages are: