tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Avoid creating new computation nodes in `kernel_results` from MCMC kernels #498

Open rzu512 opened 5 years ago

rzu512 commented 5 years ago

The one_step method of some MCMC kernels seems to return a kernel_results that contains new computation nodes.

If I run the following loop in graph mode, new computation nodes are created on every iteration:

  1. Bootstrap to get "kernel result 0".
  2. Use "kernel result 0" to run the one_step method and get "kernel result 1".
  3. Use "kernel result 1" to run the one_step method and get "kernel result 2", and so on.

How can I avoid that?

For example, the NUTS kernel builds its results like this:

```python
return result_state, NUTSKernelResults(
    next_.target_log_prob, next_.grads_target_log_prob,
    leapfrogs_taken + new_leapfrogs,
    leapfrogs_computed + tf.math.reduce_max(input_tensor=new_leapfrogs))
```

SiegeLordEx commented 5 years ago

Could you explain why you'd want that? TFP MCMC kernels use kernel_results to reuse some computation from previous steps for efficiency. If you're really bothered by that, you can write a new TransitionKernel that avoids doing that:

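```python
import tensorflow_probability as tfp


class NoKR(tfp.mcmc.TransitionKernel):
    """Wraps `kernel` but carries no kernel results between steps."""

    def __init__(self, kernel):
        self._kernel = kernel

    @property
    def is_calibrated(self):
        # `is_calibrated` is a property on TFP kernels, not a method.
        return self._kernel.is_calibrated

    def one_step(self, current_state, previous_kernel_results, seed=None):
        del previous_kernel_results  # Unused: nothing is cached.
        # Rebuild the inner kernel's results from scratch on every step.
        next_state, _ = self._kernel.one_step(
            current_state,
            self._kernel.bootstrap_results(current_state),
            seed=seed)
        return next_state, ()

    def bootstrap_results(self, init_state):
        return ()
```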

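But that will be very inefficient for some kernels. HMC and NUTS, for example, would have to recompute the target log-probability and its gradient at the start of every step instead of reusing the values cached in kernel_results.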

rzu512 commented 5 years ago

Would the new computation nodes eventually use up all the memory?

SiegeLordEx commented 5 years ago

They would if you iterated the kernel with a Python loop. But if you use tfp.mcmc.sample_chain or tf.while_loop, the loop is built into the graph itself, so the computation nodes for the kernel results are created only once.
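The difference can be illustrated with plain TensorFlow 2. In this sketch, the update 0.5 * x + 1.0 is an arbitrary stand-in for one_step. A Python loop makes the graph grow with the iteration count, while tf.while_loop keeps it the same size:

```python
import tensorflow as tf


def count_ops(num_iters, use_while_loop):
    """Counts graph ops for the same update built two different ways."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.constant(0.0)
        if use_while_loop:
            # The loop lives inside the graph and is built only once,
            # whatever the number of iterations.
            _, x = tf.while_loop(
                cond=lambda i, x: i < num_iters,
                body=lambda i, x: (i + 1, 0.5 * x + 1.0),
                loop_vars=(tf.constant(0), x))
        else:
            # Python unrolls the loop: every iteration appends new ops.
            for _ in range(num_iters):
                x = 0.5 * x + 1.0
    return len(graph.get_operations())


print(count_ops(10, use_while_loop=False), count_ops(20, use_while_loop=False))
print(count_ops(10, use_while_loop=True), count_ops(1000, use_while_loop=True))
```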

rzu512 commented 5 years ago

If I run sample_chain with a large num_results, it stores the state from every step on the GPU, which can fill up GPU memory.

If I use a small num_results instead, I need to call sample_chain repeatedly inside a loop, and each call still needs previous_kernel_results if I don't want to restart the chains on every iteration.

SiegeLordEx commented 5 years ago

See my answer to the other issue: https://github.com/tensorflow/probability/issues/497#issuecomment-511883886. If you're only interested in the final state, you can use tf.while_loop; it'll be both compute- and memory-efficient.

junpenglao commented 5 years ago

Instead of writing a custom tf.while_loop, the easier approach is to set the number of samples to 1 and count the remaining steps as burn-in:

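```python
import tensorflow_probability as tfp

# target_log_prob, step_size, leapfrogs, num_steps, initial_state and seed
# are assumed to be defined elsewhere.
hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob, step_size, leapfrogs)

# Keep only the final state: the first num_steps - 1 steps run as burn-in.
answer = tfp.mcmc.sample_chain(
    num_results=1,
    num_burnin_steps=num_steps - 1,
    current_state=[initial_state],
    kernel=hmc_kernel,
    parallel_iterations=1,
    seed=seed)  # Recent TFP takes the seed here, not in the kernel.
```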