google / objax

Apache License 2.0

Accessing variable of a vectorized module #217

Closed noashin closed 3 years ago

noashin commented 3 years ago

Hello, I would like to know whether it is possible to access and change the values of the variables of a module that has been vectorized. My use case is the following: I want to optimize the parameters of a kernel, and I can compute the derivative of the loss function with respect to the kernel parameters exactly. I want to vectorize the kernel for faster computation, but I do not know how to access the values of the kernel parameters once the kernel module is wrapped in objax.Vectorize.

import jax.numpy as np  # assumption: `np` in this snippet refers to jax.numpy
import objax


def exp_quadratic(x1, x2):
    # Exponential of the negative squared Euclidean distance between two points.
    return np.exp(-np.sum((x1 - x2) ** 2))


class Kernel(objax.Module):
    def __call__(self, X1, X2):
        return self.K(X1, X2)

    def K(self, X1, X2):
        raise NotImplementedError('kernel function not implemented')


class RBFKernel(Kernel):
    def __init__(self, variance, lengthscale):
        self.lengthscale = objax.StateVar(np.array(lengthscale))
        self.variance = objax.StateVar(np.array(variance))

    def K(self, X1, X2):
        # Rescale inputs so the exponent is ||X1 - X2||^2 / (2 * lengthscale^2).
        scale = np.sqrt(2) * self.lengthscale.value
        return self.variance.value * exp_quadratic(X1 / scale, X2 / scale)


class Inference(objax.Module):
    def __init__(self, kernel):
        # The inner Vectorize maps over the rows of X1, the outer one over the rows of X2.
        self.kernel = objax.Vectorize(objax.Vectorize(kernel, batch_axis=(0, None)), batch_axis=(None, 0))

    def update(self, data):
        exact_grad = self.exact_grad(self.kernel)
        # update the variance and lengthscale of self.kernel following exact_grad
        ...
AlexeyKurakin commented 3 years ago

Generally speaking, you can access all variables of any module by calling module.vars(). Is there anything specific that prevents you from using the .vars() method?

AlexeyKurakin commented 3 years ago

Closing now because of inactivity