tom-andersson opened this issue 1 year ago
Hey @tom-andersson :)
Is something like this what you're after?
```python
from plum import dispatch

import lab.tensorflow as B
import tensorflow as tf
import neuralprocesses.tensorflow as nps


@dispatch
def to_tf(*xs):
    # Convert several arguments at once.
    return tuple(to_tf(x) for x in xs)


@dispatch
def to_tf(x: list):
    return [to_tf(xi) for xi in x]


@dispatch
def to_tf(x: tuple):
    return tuple(to_tf(xi) for xi in x)


@dispatch
def to_tf(x):
    # Fallback: cast anything else to a TF tensor.
    # Move to GPU here? Do more?
    return B.cast(tf.float32, x)


@nps.loglik.dispatch
def loglik(model: nps.Model, contexts: list, xt: B.NPNumeric, yt: B.NPNumeric):
    # For NumPy inputs: convert to TF, defer to the existing method, convert back.
    return B.to_numpy(loglik(model, to_tf(contexts), to_tf(xt), to_tf(yt)))


model = nps.construct_convgnp(dtype=tf.float32)
xc = B.randn(16, 1, 10)
yc = B.randn(16, 1, 10)
xt = B.randn(16, 1, 10)
yt = B.randn(16, 1, 10)
print(nps.loglik(model, [(xc, yc)], xt, yt))  # Returns a NumPy thing
```
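If it helps intuition: the new method only fires when `xt` and `yt` are NumPy, because of the `B.NPNumeric` annotations, so the recursive `loglik` call dispatches back to the existing TensorFlow method rather than looping. A hypothetical sanity check, assuming the snippet above has been run:

```python
# Hypothetical check: TF inputs skip the new NumPy method and behave as before.
xc_tf = tf.random.normal((16, 1, 10))
yc_tf = tf.random.normal((16, 1, 10))
xt_tf = tf.random.normal((16, 1, 10))
yt_tf = tf.random.normal((16, 1, 10))
print(nps.loglik(model, [(xc_tf, yc_tf)], xt_tf, yt_tf))  # Should still be a TF tensor.
```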
I think this would do the job nicely @wesselb! Let me know when you've pushed it and I can update backends and test it out.

I don't think more is needed than the `tf.cast` you propose - this should fix the dtype errors when running the `nps.loglik` (etc.) methods. The only issue I can foresee is hard-coding float32s...
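If it's useful, here is a minimal, untested sketch of one way to avoid the hard-coded float32: make the target dtype an argument to the converter (the names here are hypothetical, not part of the library):

```python
from functools import partial

import lab.tensorflow as B
import tensorflow as tf


def to_tf(x, dtype=tf.float32):
    # Hypothetical variant of the converter above: take the target dtype as
    # an argument instead of hard-coding float32, so float64 models work too.
    if isinstance(x, (list, tuple)):
        return type(x)(to_tf(xi, dtype=dtype) for xi in x)
    return B.cast(dtype, x)


to_tf64 = partial(to_tf, dtype=tf.float64)  # E.g., for a float64 model.
```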
Currently, with a TensorFlow backend, I believe all the array inputs to `nps.loglik`, `nps.ar_loglik` and `nps.ar_sample` need to be TensorFlow tensors, e.g. `tf.EagerTensor`s. It would be convenient if users with `numpy` arrays did not have to explicitly cast all their data to TensorFlow before calling these methods. This dtype behaviour is supported with `model.__call__` functionality in `neuralprocesses`, so my hunch is that this shouldn't be too tricky...

cc @stratisMarkou
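For concreteness, the convenience being asked for would look roughly like this (a hypothetical usage sketch; the call signatures mirror the `nps.loglik` example above):

```python
import numpy as np
import tensorflow as tf
import neuralprocesses.tensorflow as nps

model = nps.construct_convgnp(dtype=tf.float32)

# Plain NumPy float64 arrays, with no manual casting to TensorFlow:
xc = np.random.randn(16, 1, 10)
yc = np.random.randn(16, 1, 10)
xt = np.random.randn(16, 1, 10)
yt = np.random.randn(16, 1, 10)

# Desired behaviour: this casts to the model's dtype internally,
# and similarly for nps.ar_loglik and nps.ar_sample.
print(nps.loglik(model, [(xc, yc)], xt, yt))
```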