google / qhbm-library

Quantum Hamiltonian-Based Models built on TensorFlow Quantum
https://qhbm-library.readthedocs.io/en/latest/
Apache License 2.0

Separate EBM #111

Closed: zaqqwerty closed this 2 years ago

zaqqwerty commented 2 years ago

Separate the ebm module into energy_model and energy_infer submodules.

First part of resolving #110. The theme is separating model code from inference code.

zaqqwerty commented 2 years ago

Thanks for the detailed review!

So, my understanding is that EBMs fit poorly into the tfp.distributions framework. The intention of tfp.distributions.Distribution is that subclasses implement efficient, non-approximate methods to compute properties; in particular, it requires subclasses to implement _sample_n. For EBMs we assume the opposite: in general, distributions will need approximate or inefficient methods for sampling, with Bernoulli being the special case.
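To illustrate why Bernoulli is the easy case, here is a minimal sketch under the assumption of an energy that is linear in the bits, E(x) = sum_i theta_i x_i (the helper names are hypothetical, not part of qhbmlib or TFP). The Boltzmann distribution p(x) = exp(-E(x)) / Z then factorizes into independent bits with P(x_i = 1) = sigmoid(-theta_i), so exact sampling costs one draw per bit. The brute-force check, by contrast, enumerates all 2**n bitstrings, which is the cost a general EBM would face.

```python
import itertools
import math
import random


def linear_energy(thetas, bits):
  """Hypothetical linear energy E(x) = sum_i theta_i * x_i."""
  return sum(t * b for t, b in zip(thetas, bits))


def boltzmann_by_enumeration(energy_fn, num_bits):
  """Exact p(x) = exp(-E(x)) / Z by enumerating all 2**num_bits bitstrings."""
  bitstrings = list(itertools.product([0, 1], repeat=num_bits))
  weights = [math.exp(-energy_fn(b)) for b in bitstrings]
  z = sum(weights)  # partition function
  return {b: w / z for b, w in zip(bitstrings, weights)}


def bernoulli_marginals(thetas):
  """For a linear energy, P(x_i = 1) = sigmoid(-theta_i) = 1 / (1 + exp(theta_i))."""
  return [1.0 / (1.0 + math.exp(t)) for t in thetas]


def sample_linear_ebm(thetas, num_samples, rng):
  """Exact sampling with one independent Bernoulli draw per bit."""
  probs = bernoulli_marginals(thetas)
  return [tuple(int(rng.random() < p) for p in probs)
          for _ in range(num_samples)]


thetas = [0.5, -1.0, 2.0]
exact = boltzmann_by_enumeration(
    lambda b: linear_energy(thetas, b), len(thetas))
marginals = bernoulli_marginals(thetas)
# The product of the Bernoulli marginals matches every enumerated probability.
for bits, p in exact.items():
  product = math.prod(m if x else 1.0 - m for m, x in zip(marginals, bits))
  assert abs(product - p) < 1e-12
samples = sample_linear_ebm(thetas, 5, random.Random(0))
```

Once couplings between bits are added, this factorization no longer holds in general, which is where approximate samplers come in.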

The part of TFP that I think is much closer to our use case is tfp.mcmc. The energy_model module is meant to help create callables that act like energy functions for use with approximate sampling methods. The rough analogy is: BitstringDistribution is to BitstringSampler what target_log_prob_fn is to Hamiltonian Monte Carlo. The energy functions are classes rather than plain functions only so that they can carry state, namely the trainable variables.
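A rough sketch of this analogy, assuming a toy single-bit-flip Metropolis sampler as a stand-in for tfp.mcmc (it does not use the TFP API, and `ToyBitstringEnergy` and `metropolis_bit_flip` are hypothetical names). The sampler only needs a callable returning log p(x) up to a constant, here -E(x). The energy is a class only so that it can hold its parameters, which the sampler never inspects.

```python
import itertools
import math
import random


class ToyBitstringEnergy:
  """Hypothetical stand-in for an energy model: a callable that carries state.

  `biases` and `couplings` play the role of trainable variables; a sampler
  never inspects them, it only calls the object on a bitstring.
  """

  def __init__(self, biases, couplings):
    self.biases = list(biases)  # one bias per bit
    self.couplings = dict(couplings)  # {(i, j): J_ij} pairwise terms

  def __call__(self, bits):
    spins = [1 - 2 * b for b in bits]  # map bit 0/1 to spin +1/-1
    energy = sum(h * s for h, s in zip(self.biases, spins))
    energy += sum(j * spins[a] * spins[b]
                  for (a, b), j in self.couplings.items())
    return energy


def metropolis_bit_flip(log_prob_fn, initial_bits, num_steps, rng):
  """Single-bit-flip Metropolis chain driven only by an unnormalized log_prob_fn."""
  bits = list(initial_bits)
  current = log_prob_fn(bits)
  samples = []
  for _ in range(num_steps):
    i = rng.randrange(len(bits))
    bits[i] ^= 1  # propose flipping one bit (a symmetric proposal)
    proposed = log_prob_fn(bits)
    if rng.random() < math.exp(min(0.0, proposed - current)):
      current = proposed  # accept the flip
    else:
      bits[i] ^= 1  # reject: undo the flip
    samples.append(tuple(bits))
  return samples


energy = ToyBitstringEnergy(biases=[0.3, -0.2], couplings={(0, 1): 0.5})
# Like target_log_prob_fn in tfp.mcmc: log p(x) up to a constant is -E(x).
chain = metropolis_bit_flip(lambda x: -energy(x), [0, 0], 100000,
                            random.Random(1))

# Exact probabilities by enumeration, for comparison with the chain.
states = list(itertools.product([0, 1], repeat=2))
weights = {s: math.exp(-energy(s)) for s in states}
z = sum(weights.values())
exact = {s: w / z for s, w in weights.items()}
empirical = {s: chain.count(s) / len(chain) for s in states}
```

Because the sampler closes over the energy object rather than copying its parameters, updating `energy.biases` during training changes what the next chain targets without touching the sampler code.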

Do these general statements make sense? I want to make sure we agree on the general design principles before I respond to the individual review comments.

farice commented 2 years ago

> So, my understanding is that EBMs fit poorly into the tfp.distributions framework. The intention of tfp.distributions.Distribution is that subclasses implement efficient, non-approximate methods to compute properties; in particular, it requires subclasses to implement _sample_n. For EBMs we assume the opposite: in general, distributions will need approximate or inefficient methods for sampling, with Bernoulli being the special case.

Agreed. The only natural subclasses of tfp.distributions.Distribution I see so far are those inference special cases.

> The part of TFP that I think is much closer to our use case is tfp.mcmc. The energy_model module is meant to help create callables that act like energy functions for use with approximate sampling methods. The rough analogy is: BitstringDistribution is to BitstringSampler what target_log_prob_fn is to Hamiltonian Monte Carlo. The energy functions are classes rather than plain functions only so that they can carry state, namely the trainable variables.

Likewise agreed. Inheriting from tfp.mcmc will eventually make sense on the inference side, where the QHBM model is used to override target_log_prob_fn. Hence, in my mental model, there remains no useful TFP inheritance on the model side.

farice commented 2 years ago

For inference, I would consider something like this (it still needs to be cleaned up, organized, and abstracted, of course):

import itertools

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


class AnalyticInferenceLayer(tf.keras.layers.Layer):
  """Sampler which calculates all probabilities and samples as categorical.

  Compares very abstractly to e.g. tfp.mcmc.HamiltonianMonteCarlo, which
  constructs an inference object.
  """

  def __init__(self, bit_string_energy: "BitstringEnergy", name=None):
    """Instantiates an AnalyticInferenceLayer object.

    The caller passes in all possible bitstrings (see `analytic_infer`); their
    energies become the logits of a categorical distribution for sampling.
    """
    super().__init__(name=name)
    self._bit_string_energy = bit_string_energy
    # Open question: should beta live here instead of in the losses?
    # p(x) is proportional to exp(-E(x)), so the logits are the negated energies.
    self._partition_estimator = tfp.layers.DistributionLambda(
        make_distribution_fn=lambda t: tfd.Categorical(logits=-1 * t)
    )

  def call(self, inputs):
    # One energy per input bitstring, squeezed to shape [num_bitstrings].
    x = tf.squeeze(self._bit_string_energy(inputs))
    return self._partition_estimator(x)

# Other inference objects will perform selective sampling of bitstrings, still
# working via forward passes.
def analytic_infer(analytic_inference_layer, num_bits):
  """Intuitively, compares to tfp.mcmc.sample_chain.

  However, in this case we are doing exact inference, so we can return a
  tensor-coercible distribution object directly, which is stronger than
  sampling access. Here, inference corresponds to a single forward pass.
  """
  # Enumerate all 2**num_bits bitstrings; only tractable for small num_bits.
  all_bitstrings = tf.constant(
      list(itertools.product([0, 1], repeat=num_bits)), dtype=tf.int8)
  return analytic_inference_layer(all_bitstrings)

This enables

# `be` is a BitstringEnergy over `num_bits` bits, constructed elsewhere.
ail = AnalyticInferenceLayer(be)
cat = analytic_infer(ail, num_bits)
tf.convert_to_tensor(cat)  # draws a single categorical sample when coerced to a tensor
cat.sample(10)  # draws 10 samples; each is an index into the enumerated bitstrings
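Framework aside, the analytic path boils down to a few lines. The sketch below is plain Python with hypothetical helper names (not the qhbmlib or TFP API). It negates the energies to get logits, computes log Z with a numerically stable logsumexp, and draws categorical samples. As with `cat.sample(10)`, those samples are indices into the enumeration order of `itertools.product` and have to be mapped back to bitstrings.

```python
import itertools
import math
import random


def analytic_distribution(energy_fn, num_bits):
  """Enumerates all bitstrings and returns exact Boltzmann probabilities."""
  bitstrings = list(itertools.product([0, 1], repeat=num_bits))
  logits = [-energy_fn(b) for b in bitstrings]  # as in Categorical(logits=-E)
  # Numerically stable log partition function: log Z = logsumexp(logits).
  max_logit = max(logits)
  log_z = max_logit + math.log(
      sum(math.exp(logit - max_logit) for logit in logits))
  probs = [math.exp(logit - log_z) for logit in logits]
  return bitstrings, probs, log_z


def sample_bitstrings(bitstrings, probs, num_samples, rng):
  """Draws categorical indices, then maps each index back to its bitstring."""
  indices = rng.choices(range(len(bitstrings)), weights=probs, k=num_samples)
  return [bitstrings[i] for i in indices]


# Hypothetical energy: the number of 1s, so all-zeros is the most likely state.
# For this energy, log Z = 3 * log(1 + exp(-1)) analytically.
bitstrings, probs, log_z = analytic_distribution(lambda b: float(sum(b)), 3)
samples = sample_bitstrings(bitstrings, probs, 5, random.Random(0))
```

The enumeration makes both memory and time scale as 2**num_bits, which is why this exact path only suits small systems and other inference objects would sample selectively instead.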
zaqqwerty commented 2 years ago

AnalyticInferenceLayer overall lgtm!

Nice, I'll switch the others out to this structure then!

farice commented 2 years ago

LGTM. Please check the open comments before submitting! I tried to resolve the outdated ones.