wesselb / neuralprocesses

A framework for composing Neural Processes in Python
MIT License

[BUG] `ValueError` when using TensorFlow graph mode / `B.jit` for models with multiple context sets #12

Open tom-andersson opened 1 year ago

tom-andersson commented 1 year ago

Currently, calling `B.jit` on a `logpdf`-based objective works for a model with a single context set and a single target set. However, a `ValueError` is raised when the model takes multiple context sets as input. See the MWE on Google Colab and below:

```python
import neuralprocesses.tensorflow as nps
import tensorflow as tf
import lab.tensorflow as B
import time


def test_jit(n_context_sets=1):
    model = nps.construct_convgnp(dim_x=1, dim_yc=(1,) * n_context_sets, dim_yt=1)

    def objective(xt, yt, *context_data):
        """Context data to be passed as xc1, yc1, xc2, yc2, ..."""
        # Convert to list of (x, y) tuples format.
        context_data = [
            (context_data[2 * i], context_data[2 * i + 1]) for i in range(n_context_sets)
        ]
        return -model(context_data, xt).logpdf(yt)

    def test(objective):
        """Generate random data to test objective."""
        xcs = [B.randn(tf.float32, 16, 1, 10) for _ in range(n_context_sets)]
        ycs = [B.randn(tf.float32, 16, 1, 10) for _ in range(n_context_sets)]
        # Flatten the context sets to xc1, yc1, xc2, yc2, ...
        context_data = []
        for i in range(n_context_sets):
            context_data.extend([xcs[i], ycs[i]])
        xt = B.randn(tf.float32, 16, 1, 20)
        yt = B.randn(tf.float32, 16, 1, 20)
        return objective(xt, yt, *context_data)

    its = 10

    # Average time per call in eager mode.
    s = time.time()
    for _ in range(its):
        test(objective)
    print(f"Without JIT ({n_context_sets} context sets):", (time.time() - s) / its)

    # Average time per call of the compiled objective.
    objective_compiled = B.jit(objective)
    test(objective_compiled)  # Run once to compile.
    s = time.time()
    for _ in range(its):
        test(objective_compiled)
    print(f"With JIT ({n_context_sets} context sets):", (time.time() - s) / its)


test_jit(n_context_sets=1)
test_jit(n_context_sets=2)  # Raises ValueError at the JIT step.
```
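The objective in the MWE receives its contexts as a flat `xc1, yc1, xc2, yc2, ...` argument list and rebuilds the `(x, y)` tuples internally, presumably so that every tensor reaches the compiled function as its own positional argument. The packing and unpacking can be isolated into two standalone helpers, as in the minimal sketch below. `flatten_contexts` and `unflatten_contexts` are hypothetical names and are not part of `neuralprocesses`. Plain strings stand in for tensors.

```python
from typing import Any, List, Sequence, Tuple


def flatten_contexts(contexts: Sequence[Tuple[Any, Any]]) -> List[Any]:
    """Turn [(xc1, yc1), (xc2, yc2), ...] into [xc1, yc1, xc2, yc2, ...]."""
    flat = []
    for xc, yc in contexts:
        flat.extend([xc, yc])
    return flat


def unflatten_contexts(flat: Sequence[Any]) -> List[Tuple[Any, Any]]:
    """Inverse of flatten_contexts: pair up consecutive arguments."""
    if len(flat) % 2 != 0:
        raise ValueError("Expected an even number of context arguments.")
    return [(flat[2 * i], flat[2 * i + 1]) for i in range(len(flat) // 2)]


contexts = [("xc1", "yc1"), ("xc2", "yc2")]
flat = flatten_contexts(contexts)
print(flat)  # ['xc1', 'yc1', 'xc2', 'yc2']
assert unflatten_contexts(flat) == contexts
```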


Running the MWE above produces:

```
Without JIT (1 context sets): 0.27810795307159425
With JIT (1 context sets): 0.027799010276794434
Without JIT (2 context sets): 0.2577114820480347
```
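For a sense of scale, the single-context-set timings above imply that the compiled objective is roughly ten times faster per call. The ratio below is computed only from the reported numbers. Each number is an average over 10 iterations, so the ratio is noisy.

```python
# Per-call times reported above for one context set, in seconds.
without_jit_1 = 0.27810795307159425
with_jit_1 = 0.027799010276794434

speedup = without_jit_1 / with_jit_1
print(f"JIT speed-up with 1 context set: {speedup:.1f}x")  # JIT speed-up with 1 context set: 10.0x
```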

However, when the model with two context sets is run with JIT, a `ValueError` is raised:

```
/usr/local/lib/python3.9/dist-packages/neuralprocesses/materialise.py in _merge(z1, z2)
     70 def _merge(z1: B.Numeric, z2: B.Numeric):
     71     if B.jit_to_numpy(B.mean(B.abs(z1 - z2))) > B.epsilon:
---> 72         raise ValueError("Cannot merge inputs.")
     73     return z1

ValueError: Cannot merge inputs.
```
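The failing check in `_merge` is small. It takes two inputs that are expected to coincide, computes their mean absolute difference, and raises unless that difference is at most `B.epsilon`. The error therefore means that this comparison judged the two inputs to differ. A standalone NumPy re-creation of the condition might look like the sketch below, which can help show which inputs trip it. `merge` and `EPSILON` are illustrative stand-ins, and the tolerance value is an assumption. In the library, the comparison also passes through `B.jit_to_numpy`, which this sketch does not model.

```python
import numpy as np

# Assumed tolerance; the real check compares against lab's `B.epsilon`,
# whose value is not shown in the traceback.
EPSILON = 1e-12


def merge(z1: np.ndarray, z2: np.ndarray) -> np.ndarray:
    """Return z1 if z1 and z2 agree on average, otherwise refuse to merge.

    Mirrors the condition in the traceback: the mean absolute difference
    between the two inputs must not exceed the tolerance.
    """
    if np.mean(np.abs(z1 - z2)) > EPSILON:
        raise ValueError("Cannot merge inputs.")
    return z1


grid = np.linspace(-1.0, 1.0, 5)
print(merge(grid, grid.copy()))  # Identical inputs merge without error.

try:
    merge(grid, grid + 0.1)
except ValueError as e:
    print(e)  # Cannot merge inputs.
```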