Hello, I would like to know if it is possible to access and change the values of the variables of a module that has been vectorized.
My use case is the following: I want to optimize the parameters of a kernel, and I can compute the derivative of the loss function with respect to the kernel parameters exactly. I want to vectorize the kernel for faster computation, but I do not know how to access the kernel parameters once the kernel module has been vectorized.
import jax.numpy as np
import objax


def exp_quadratic(x1, x2):
    return np.exp(-np.sum((x1 - x2) ** 2))


class Kernel(objax.Module):
    def __call__(self, X1, X2):
        return self.K(X1, X2)

    def K(self, X1, X2):
        raise NotImplementedError('kernel function not implemented')


class RBFKernel(Kernel):
    def __init__(self, variance, lengthscale):
        self.lengthscale = objax.StateVar(np.array(lengthscale))
        self.variance = objax.StateVar(np.array(variance))

    def K(self, X1, X2):
        return self.variance.value * exp_quadratic(
            X1 / (np.sqrt(2) * self.lengthscale.value),
            X2 / (np.sqrt(2) * self.lengthscale.value))


class Inference(objax.Module):
    def __init__(self, kernel):
        self.kernel = objax.Vectorize(objax.Vectorize(kernel, batch_axis=(0, None)),
                                      batch_axis=(None, 0))

    def update(self, data):
        exact_grad = self.exact_grad(self.kernel)
        # update the variance and lengthscale of self.kernel following exact_grad
        ...
Generally speaking, you can access all variables of any module by calling module.vars().
Is there anything specific that prevents you from using the .vars() method?