wesselb / neuralprocesses

A framework for composing Neural Processes in Python
https://wesselb.github.io/neuralprocesses
MIT License

[FR] Allow `nps.loglik`, `nps.ar_loglik` and `nps.ar_predict` to work with numpy arrays #7

Open tom-andersson opened 1 year ago

tom-andersson commented 1 year ago

Currently, with a TensorFlow backend, I believe all array inputs to nps.loglik, nps.ar_loglik and nps.ar_sample need to be TensorFlow tensors, e.g. tf.EagerTensors. It would be convenient if users with NumPy arrays did not have to explicitly cast all their data to TensorFlow tensors before calling these methods. Calling the model directly (model.__call__) in neuralprocesses already supports this dtype behaviour, so my hunch is that this shouldn't be too tricky...

cc @stratisMarkou

wesselb commented 1 year ago

Hey @tom-andersson :)

Is something like this what you're after?

from plum import dispatch

import lab.tensorflow as B
import tensorflow as tf

import neuralprocesses.tensorflow as nps

@dispatch
def to_tf(*xs):
    return tuple(to_tf(x) for x in xs)

@dispatch
def to_tf(x: list):
    return [to_tf(xi) for xi in x]

@dispatch
def to_tf(x: tuple):
    return tuple(to_tf(xi) for xi in x)

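@dispatch
def to_tf(x):
    # Fallback for anything that is not a list or tuple: cast to float32.
    # Move to GPU here? Do more?
    return B.cast(tf.float32, x)

# Extend nps.loglik with a method for NumPy targets: convert the inputs,
# call loglik again on the converted inputs, and return the result as NumPy.
@nps.loglik.dispatch
def loglik(model: nps.Model, contexts: list, xt: B.NPNumeric, yt: B.NPNumeric):
    return B.to_numpy(loglik(model, to_tf(contexts), to_tf(xt), to_tf(yt)))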

model = nps.construct_convgnp(dtype=tf.float32)

xc = B.randn(16, 1, 10)
yc = B.randn(16, 1, 10)
xt = B.randn(16, 1, 10)
yt = B.randn(16, 1, 10)

print(nps.loglik(model, [(xc, yc)], xt, yt))  # Returns a NumPy thing

tom-andersson commented 1 year ago

I think this would do the job nicely @wesselb! Let me know when you've pushed it and I can update backends and test it out.

tom-andersson commented 1 year ago

I don't think anything more is needed than the tf.cast you propose; this should fix the dtype errors when calling nps.loglik and the related methods. The only issue I can foresee is the hard-coded float32...
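
One possible way around the hard-coded float32 is to pass the cast in as an argument rather than fixing it inside the conversion helper. The sketch below is only an illustration using the standard library: `convert` and `cast` are hypothetical names, not part of neuralprocesses, lab or plum, and `float` stands in for a real tensor cast.

```python
from functools import singledispatch
from typing import Any, Callable


@singledispatch
def convert(x: Any, cast: Callable[[Any], Any]) -> Any:
    """Leaf case: hand anything that is not a list or tuple to `cast`."""
    return cast(x)


@convert.register(list)
def _(x, cast):
    # Preserve the list structure and convert every element.
    return [convert(xi, cast) for xi in x]


@convert.register(tuple)
def _(x, cast):
    # Preserve the tuple structure, e.g. (xc, yc) context pairs.
    return tuple(convert(xi, cast) for xi in x)


# `float` is a stand-in cast for illustration only.
contexts = convert([(1, 2), (3, 4)], float)
print(contexts)  # [(1.0, 2.0), (3.0, 4.0)]
```

With TensorFlow, the cast could be something like `lambda x: tf.cast(x, dtype)`, with `dtype` chosen by the caller (for example the same dtype that was passed to `nps.construct_convgnp`). Whether the model exposes its dtype for this purpose is an assumption that would need checking.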