Currently, calling `B.jit` on a `logpdf`-based objective works for a model with a single context set and a single target set. However, a `ValueError` is raised when the model takes multiple context sets as input. See the MWE on Google Colab and below:
```python
import neuralprocesses.tensorflow as nps
import tensorflow as tf
import lab.tensorflow as B
import time


def test_jit(n_context_sets=1):
    model = nps.construct_convgnp(dim_x=1, dim_yc=(1,) * n_context_sets, dim_yt=1)

    def objective(xt, yt, *context_data):
        """
        Context data to be passed as xc1, yc1, xc2, yc2, ...
        """
        # Convert to list of (x, y) tuples format
        context_data = [(context_data[2 * i], context_data[2 * i + 1]) for i in range(n_context_sets)]
        return -model(context_data, xt).logpdf(yt)

    def test(objective):
        """Generate random data to test objective"""
        xcs = [B.randn(tf.float32, 16, 1, 10) for _ in range(n_context_sets)]
        ycs = [B.randn(tf.float32, 16, 1, 10) for _ in range(n_context_sets)]
        context_data = []
        for i in range(n_context_sets):
            context_data.append(xcs[i])
            context_data.append(ycs[i])
        xt = B.randn(tf.float32, 16, 1, 20)
        yt = B.randn(tf.float32, 16, 1, 20)
        return objective(xt, yt, *context_data)

    its = 10

    s = time.time()
    for _ in range(its):
        test(objective)
    print(f"Without JIT ({n_context_sets} context sets):", (time.time() - s) / its)

    objective_compiled = B.jit(objective)
    test(objective_compiled)  # Run once to compile.

    s = time.time()
    for _ in range(its):
        test(objective_compiled)
    print(f"With JIT ({n_context_sets} context sets):", (time.time() - s) / its)


test_jit(n_context_sets=1)
test_jit(n_context_sets=2)
```
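The objective receives the context data flattened as `xc1, yc1, xc2, yc2, ...` and regroups it into the list of `(x, y)` tuples that the model is called with. A minimal pure-Python sketch of that round trip, using hypothetical helper names `flatten_contexts` / `unflatten_contexts` (not part of `neuralprocesses` or `lab`), makes the expected ordering explicit. Unlike the MWE, which reads `n_context_sets` from the enclosing scope, this sketch infers the number of pairs from the argument length.

```python
from typing import Any, List, Sequence, Tuple


def flatten_contexts(contexts: Sequence[Tuple[Any, Any]]) -> List[Any]:
    """[(xc1, yc1), (xc2, yc2), ...] -> [xc1, yc1, xc2, yc2, ...] (hypothetical helper)."""
    flat: List[Any] = []
    for xc, yc in contexts:
        flat.extend((xc, yc))
    return flat


def unflatten_contexts(flat: Sequence[Any]) -> List[Tuple[Any, Any]]:
    """[xc1, yc1, xc2, yc2, ...] -> [(xc1, yc1), (xc2, yc2), ...] (hypothetical helper)."""
    if len(flat) % 2 != 0:
        raise ValueError("expected an even number of context arrays (x, y pairs)")
    # Pair even-indexed entries (inputs) with odd-indexed entries (outputs).
    return list(zip(flat[0::2], flat[1::2]))


# Usage with placeholder strings standing in for tensors:
flat = flatten_contexts([("xc1", "yc1"), ("xc2", "yc2")])
print(flat)                      # ['xc1', 'yc1', 'xc2', 'yc2']
print(unflatten_contexts(flat))  # [('xc1', 'yc1'), ('xc2', 'yc2')]
```

Inside `objective`, `unflatten_contexts(context_data)` would then play the role of the list comprehension.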
Running the above produces:
```
Without JIT (1 context sets): 0.27810795307159425
With JIT (1 context sets): 0.027799010276794434
Without JIT (2 context sets): 0.2577114820480347
```
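The timings above follow a warm-up-then-average pattern: one untimed call so that compilation is excluded, then the mean over `its` calls. A minimal sketch of that pattern as a reusable helper, assuming a hypothetical name `mean_call_time` and using `time.perf_counter` (a monotonic, high-resolution clock) in place of `time.time`:

```python
import time
from typing import Any, Callable


def mean_call_time(fn: Callable[[], Any], its: int = 10, warmup: bool = True) -> float:
    """Average wall-clock seconds per call of `fn` over `its` calls (hypothetical helper)."""
    if its <= 0:
        raise ValueError("its must be positive")
    if warmup:
        fn()  # Untimed call, e.g. to trigger JIT tracing/compilation.
    start = time.perf_counter()
    for _ in range(its):
        fn()
    return (time.perf_counter() - start) / its
```

Inside `test_jit`, this could be used as `mean_call_time(lambda: test(objective), its, warmup=False)` and `mean_call_time(lambda: test(objective_compiled), its)`, matching the MWE, where only the JIT path gets a warm-up call.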
However, when the model with two context sets is run with JIT, it raises a `ValueError`: