tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0
4.25k stars 1.1k forks source link

[Jit compile error] unsupported operation: MatrixDeterminant #1447

Open anvvalade opened 2 years ago

anvvalade commented 2 years ago

While testing different momentum distributions for the `PreconditionedHamiltonianMonteCarlo` kernel, I found that a chain using the `MultivariateNormalDiagPlusLowRank` momentum distribution cannot be run under XLA compilation.

Here is a minimal script that reproduces the error:

#!/usr/bin/env python3

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import experimental as tfpe
from tensorflow_probability import distributions as tfd

dtype = "float32"

##############################
# Building a covariance matrix
##############################

# Random symmetric matrix: entries drawn from N(0, 0.25), clipped to
# [-0.9, 0.9], symmetrized, then given a unit diagonal.
M = np.clip(np.random.normal(0, 0.25, (100, 100)), -0.9, 0.9)
M = 0.5 * (M + M.T)
np.fill_diagonal(M, 1.0)

# We make sure it's positive definite. M is symmetric, so eigvalsh is the
# appropriate solver: it always returns real eigenvalues, whereas
# np.linalg.eigvals may return a complex array (which cannot be subtracted
# in place from the real-valued M).
eigvals = np.linalg.eigvalsh(M)
min_eigval = eigvals.min()
if min_eigval < 0:
    # Shifting the diagonal by -1.1 * min_eigval raises every eigenvalue by
    # 1.1 * |min_eigval|, so the smallest one becomes 0.1 * |min_eigval| > 0.
    M -= 1.1 * min_eigval * np.eye(100)

M = tf.constant(M, dtype)

###############################################
# Creating the different momentum distributions
###############################################

L = tf.linalg.cholesky(M)

# Full covariance matrix: scale_tril is the Cholesky factor, so the
# covariance is L @ L^T == M.
true_dist = tfd.MultivariateNormalTriL(
    loc=tf.zeros(100, dtype), scale_tril=L, validate_args=True
)

# Just the diagonal. `scale_diag` holds standard deviations, so take the
# square root of the variances on the diagonal of M.
diag_dist = tfd.MultivariateNormalDiag(
    loc=tf.zeros(100, dtype),
    scale_diag=tf.sqrt(tf.linalg.diag_part(M)),
    validate_args=True,
)

# A low rank approximation of the scale: S ~= diag(D) + U @ U^T, where S is
# the symmetric square root of M (S @ S^T == M). pivoted_cholesky requires a
# symmetric positive (semi-)definite input; the triangular Cholesky factor L
# is not symmetric, so the symmetric square root is approximated instead.
eigvals_M, eigvecs_M = tf.linalg.eigh(M)
# V * sqrt(e) scales column j of V by sqrt(e_j), i.e. V @ diag(sqrt(e)).
S = (eigvecs_M * tf.sqrt(eigvals_M)) @ tf.transpose(eigvecs_M)
U = tfp.math.pivoted_cholesky(S, 10)
S_approx = U @ tf.transpose(U)
# The residual S - U @ U^T of a pivoted Cholesky factorization is PSD, so
# this diagonal correction is non-negative.
D = tf.linalg.diag_part(S) - tf.linalg.diag_part(S_approx)

lr_dist = tfd.MultivariateNormalDiagPlusLowRank(
    loc=tf.zeros(100, dtype), scale_diag=D, scale_perturb_factor=U, validate_args=True
)

##################################################
# Run the kernel for a given momentum distribution
##################################################

@tf.function(jit_compile=True)  # `experimental_compile` is the deprecated spelling
def _run(dist):
    """Sample a preconditioned HMC chain targeting `true_dist`.

    The function is XLA-compiled (`jit_compile=True`). That compilation is
    what fails with "unsupported operation: MatrixDeterminant" when `dist` is
    a `MultivariateNormalDiagPlusLowRank`.

    Args:
      dist: momentum distribution handed to
        `PreconditionedHamiltonianMonteCarlo`.

    Returns:
      `(chain, ess)`: the list of sampled states from `tfp.mcmc.sample_chain`
      and the effective sample size of its first (and only) state part.
    """
    num_results = 100
    num_burnin_steps = 100
    chain, _ = tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=[tf.zeros(100, dtype)],
        kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
            tfpe.mcmc.PreconditionedHamiltonianMonteCarlo(
                target_log_prob_fn=true_dist.log_prob,
                step_size=0.1,
                num_leapfrog_steps=8,
                momentum_distribution=dist,
                store_parameters_in_results=True,
            ),
            # Adapt the step size over the first 80% of the burn-in steps.
            num_adaptation_steps=int(num_burnin_steps * 0.8),
        ),
        # No kernel results are needed; trace a constant placeholder.
        trace_fn=(lambda a, b: tf.constant(0.0, dtype)),
    )
    ess = tfp.mcmc.effective_sample_size(chain)[0]
    return (chain, ess)

# The full-covariance and diagonal momentum distributions both compile and
# run under XLA.
print("Chain with true mass matrix")
chain_true_dist, ess_true_dist = _run(true_dist)
print("Chain with diagonal mass matrix")
chain_diag_dist, ess_diag_dist = _run(diag_dist)

# The low-rank momentum distribution runs when _run is not XLA-compiled, but
# fails under jit_compile=True: its log_prob goes through
# LinearOperatorLowRankUpdate.log_abs_det, which emits a MatrixDeterminant op
# that had no registered XLA kernel in the reported TF version.
print("Chain with low rank mass matrix")
chain_lr_dist, ess_lr_dist = _run(lr_dist)

The full output is very long; here is an excerpt:

InvalidArgumentError                      Traceback (most recent call last)
~/scripts/tests/jit_det_bug_report.py in <module>
     70 chain_diag_dist, ess_diag_dist = _run(diag_dist)
     71 print("Chain with low rank mass matrix")
---> 72 chain_lr_dist, ess_lr_dist = _run(lr_dist)

~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    887
    888       with OptionalXlaContext(self._jit_compile):
--> 889         result = self._call(*args, **kwds)
    890
    891       new_tracing_count = self.experimental_get_tracing_count()

~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    922       # In this case we have not created variables on the first call. So we can
    923       # run the first trace but we should fail if variables are created.
--> 924       results = self._stateful_fn(*args, **kwds)
    925       if self._created_variables:
    926         raise ValueError("Creating variables on a non-first call to a function"

~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   3021       (graph_function,
   3022        filtered_flat_args) = self._maybe_define_function(args, kwargs)
-> 3023     return graph_function._call_flat(
   3024         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   3025

~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1958         and executing_eagerly):
   1959       # No tape is watching; skip to running the function.
-> 1960       return self._build_call_outputs(self._inference_function.call(
   1961           ctx, args, cancellation_manager=cancellation_manager))
   1962     forward_backward = self._select_forward_and_backward_functions(

~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    589       with _InterpolateFunctionError(self):
    590         if cancellation_manager is None:
--> 591           outputs = execute.execute(
    592               str(self.signature.name),
    593               num_outputs=self._num_outputs,

~/.conda/envs/perso/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57   try:
     58     ctx.ensure_initialized()
---> 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:

InvalidArgumentError: Detected unsupported operations when trying to compile graph mcmc_sample_chain_trace_scan_while_smart_for_loop_while_body_5806_const_0[] on XLA_GPU_JIT: MatrixDeterminant (No registered 'MatrixDeterminant' OpKernel for XLA_GPU_JIT devices compatible with node {{node mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/mh_one_step/phmc_kernel_one_step/compute_log_acceptance_correction/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_dual_averaging_step_size_adaptation___init____one_step_mh_one_step_phmc_kernel_one_step_JointDistributionSequential/log_prob/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_dual_averaging_step_size_adaptation___init____one_step_mh_one_step_phmc_kernel_one_step_BatchBroadcastmcmc_sample_chain_trace_scan_while_smart_for_loop_while_BatchBroadcastMultivariateNormalDiagPlusLowRank_4/log_prob/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_BatchBroadcastMultivariateNormalDiagPlusLowRank_4/log_prob/MultivariateNormalDiagPlusLowRank/log_prob/chain_of_shift_of_scale_matvec_linear_operator_2/inverse_log_det_jacobian/scale_matvec_linear_operator/inverse_log_det_jacobian/LinearOperatorLowRankUpdate/log_abs_det/MatrixDeterminant}}){{node 
mcmc_sample_chain/trace_scan/while/smart_for_loop/while/dual_averaging_step_size_adaptation___init__/_one_step/mh_one_step/phmc_kernel_one_step/compute_log_acceptance_correction/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_dual_averaging_step_size_adaptation___init____one_step_mh_one_step_phmc_kernel_one_step_JointDistributionSequential/log_prob/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_dual_averaging_step_size_adaptation___init____one_step_mh_one_step_phmc_kernel_one_step_BatchBroadcastmcmc_sample_chain_trace_scan_while_smart_for_loop_while_BatchBroadcastMultivariateNormalDiagPlusLowRank_4/log_prob/mcmc_sample_chain_trace_scan_while_smart_for_loop_while_BatchBroadcastMultivariateNormalDiagPlusLowRank_4/log_prob/MultivariateNormalDiagPlusLowRank/log_prob/chain_of_shift_of_scale_matvec_linear_operator_2/inverse_log_det_jacobian/scale_matvec_linear_operator/inverse_log_det_jacobian/LinearOperatorLowRankUpdate/log_abs_det/MatrixDeterminant}}

My packages are:

Python 3.8.10
tensorboard             2.5.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit  1.8.0
tensorflow-estimator    2.5.0
tensorflow-gpu          2.5.0
tensorflow-probability  0.13.0
srvasude commented 2 years ago

@SiegeLordEx Could you take a look at this? The original error no longer triggers (I believe an XLA kernel for the op has since been registered), but I now seem to get an error from `tfp.mcmc` instead.