google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0
2.26k stars 227 forks

Memory Constraint (Approximation Available?) #28

Open ginward opened 4 years ago

ginward commented 4 years ago

Hello,

I currently have a dataset with 1246064 observations and 94 features. My understanding is that the GP would have to construct a kernel of size 1246064 * 1246064, and I am not sure if that is the reason I am running into the following memory error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-81-5095b4194ced> in <module>()
     29     r_mean, r_covariance = nt.predict.gp_inference(
     30         kernel_fn, z_train, r_train, z_test,
---> 31         diag_reg=1e-4, get='ntk', compute_cov=True)
     32     r_mean = np.reshape(r_mean, (-1,))[np.newaxis, ...]
     33     out_rsq_list.append((r_test.detach().cpu().numpy(), r_test.detach().cpu().numpy()))

8 frames
/usr/local/lib/python3.6/dist-packages/jaxlib/xla_client.py in compile(self, c_computation, compile_options)
    148                                         compile_options.argument_layouts,
    149                                         options, self.client,
--> 150                                         compile_options.device_assignment)
    151 
    152   def get_default_device_assignment(self, num_replicas, num_partitions=None):

RuntimeError: Resource exhausted: Out of memory while trying to allocate 6210718745600 bytes.
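(For scale, the requested allocation is roughly what a dense float32 kernel of that size would take; quick back-of-the-envelope check:)

```python
n = 1246064
print(n * n * 4)  # ~6.21e12 bytes for a dense float32 n x n kernel, i.e. ~6.2 TB
```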

I was wondering if there is a way around this (for example, by creating a kernel approximation of some sort, similar to this one).

Thanks!

ginward commented 4 years ago

Or rather, is there a way to train with mini-batches?

jaehlee commented 4 years ago

Hello ginward,

Currently we have a mini-batch function for constructing kernel elements, but not for inference. However, even with mini-batching our library still constructs the full kernel, and with 1.25M data points that will need about 6.2 TB of memory. (Since the kernel is symmetric it could in principle take ~3 TB, but the library computes the full matrix at the moment.)
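For reference, the batched kernel construction looks roughly like the sketch below; the architecture, subset size, and batch size are placeholders. `nt.batch` tiles the computation, but the assembled kernel is still n x n:

```python
import numpy as np
import neural_tangents as nt
from neural_tangents import stax

# Placeholder fully-connected architecture for 94-dimensional inputs.
_, _, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1),
)

# Compute the kernel in batch_size x batch_size tiles, accumulating the
# result in host memory rather than on the accelerator.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=1000, store_on_device=False)

x_subset = np.random.randn(10000, 94)            # a subset of the data
k_dd = batched_kernel_fn(x_subset, None, 'ntk')  # full 10000 x 10000 NTK, in host RAM
```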

Moreover, our predict function performs inference via either a Cholesky decomposition or an eigendecomposition of the full kernel, which scales as O(n**3). With sizable RAM one could try a dataset size of ~100k, but scaling to larger datasets would be challenging.

There are a few ways of scaling up GP or kernel computations. Our library does not support these methods yet, but we are interested in using our kernels with them for scaling up kernel computations:

1) Exact Gaussian Processes on a Million Data Points, Wang et al., https://arxiv.org/abs/1903.08114
2) Kernel machines that adapt to GPUs for effective large batch training, Ma and Belkin, https://arxiv.org/abs/1806.06144

These methods avoid constructing the full kernel matrix and only compute the entries that are required; they are also iterative in nature (see the sketch after the references below). As far as I can tell, kernel approximation methods are another way to scale up GP computation. One could try replacing the kernel functions of those libraries with our kernel function. For reference:

3) Gaussian processes for big data, Hensman et al., https://arxiv.org/abs/1309.6835
4) When Gaussian process meets big data: A review of scalable GPs, Liu et al., https://arxiv.org/abs/1807.01065
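As a concrete illustration of the matrix-free, iterative idea in (1) and (2), here is a minimal sketch (not part of neural-tangents) that solves the regularized kernel regression system with conjugate gradients, recomputing kernel strips on the fly. `kernel_fn` would be an NT kernel function, possibly wrapped in `nt.batch`; the block size and regularizer are placeholders:

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

def kernel_matvec(kernel_fn, x_train, v, block_size=5000, diag_reg=1e-4):
    """Compute (K + diag_reg * I) @ v without ever storing the full K."""
    n = x_train.shape[0]
    out = np.zeros_like(v)
    for start in range(0, n, block_size):
        rows = slice(start, min(start + block_size, n))
        # Recompute a (block_size, n) strip of the kernel for this block of rows.
        k_strip = np.asarray(kernel_fn(x_train[rows], x_train, 'ntk'))
        out[rows] = k_strip @ v
    return out + diag_reg * v

def solve_alpha(kernel_fn, x_train, y_train, **kwargs):
    """Solve (K + diag_reg * I) alpha = y iteratively with conjugate gradients."""
    n = x_train.shape[0]
    op = LinearOperator(
        (n, n), matvec=lambda v: kernel_matvec(kernel_fn, x_train, v, **kwargs))
    alpha, info = cg(op, y_train)
    return alpha
```

Predictions on test points would then be `kernel_fn(x_test, x_train, 'ntk') @ alpha`, again computed in blocks.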

sschoenholz commented 4 years ago

Just to add to Jaehoon's reply, one thing we are interested in testing out is integration with the excellent GPyTorch package (https://gpytorch.ai/), which can scale GP inference to 1M+ datapoints. We haven't explored this direction much ourselves, but I think it should be doable, especially since JAX supports DLPack (https://jax.readthedocs.io/en/latest/jax.dlpack.html), which should enable zero-cost sharing of tensors between JAX and PyTorch. In principle, therefore, you should be able to construct a GP kernel using minibatching with NT and then use GPyTorch to do inference.
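For what it's worth, the JAX-to-PyTorch handoff could look something like this sketch; the architecture and data are placeholders, and the GPyTorch side is left out:

```python
import numpy as np
import torch.utils.dlpack
import jax.dlpack
import neural_tangents as nt
from neural_tangents import stax

# Placeholder architecture; any NT kernel_fn would work the same way.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
kernel_fn = nt.batch(kernel_fn, batch_size=100, store_on_device=False)

x1 = np.random.randn(200, 94)
x2 = np.random.randn(200, 94)
k_jax = kernel_fn(x1, x2, 'ntk')  # JAX array

# DLPack hands the buffer to PyTorch without a copy; the resulting tensor
# could then be wrapped for GPyTorch's inference routines.
k_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(k_jax))
```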

If you do decide to explore this direction, please keep us posted. I would be very curious to hear about your experience and if you run into trouble we would love to help you resolve problems as you come across them.

ginward commented 4 years ago

Thanks @sschoenholz and @jaehlee ! I will explore and let you know.

ginward commented 4 years ago

@sschoenholz I assume that, even if I use GPyTorch to access the NT mini-batch kernel, I would still need 6 TB of memory, as pointed out by @jaehlee?

jaehlee commented 4 years ago

@ginward GPyTorch's scalable version will probably compute kernel elements as needed, so it won't hold 6 TB at any given instant. This does mean computing the same kernel elements multiple times. If you are using either FC kernels or CNN kernels without pooling, the overhead of the kernel computation wouldn't be so bad. If you use a pooling layer, or a kernel that keeps track of the feature-feature covariance, you may be overwhelmed by the kernel computation. In the latter case, one should cache the computation on disk and reuse it whenever possible (a rough sketch below).
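Something along these lines; the block indexing and file layout are just placeholders:

```python
import os
import numpy as np

def cached_kernel_block(kernel_fn, x1, x2, i, j, cache_dir='nt_kernel_cache'):
    """Compute the (i, j) kernel block once and reload it from disk afterwards."""
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, 'block_{}_{}.npy'.format(i, j))
    if os.path.exists(path):
        return np.load(path)
    block = np.asarray(kernel_fn(x1, x2, 'ntk'))
    np.save(path, block)
    return block
```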

ginward commented 4 years ago

Thanks!

JianxiangFENG commented 3 years ago

Hello guys,

I also have the requirement to compute a covariance matrix over 300k datapoints. Did you succeed in getting something working in the end? @ginward

@jaehlee @sschoenholz There is another question about multi-dimensional outputs: if I have a fairly large number of classes, say 50, then besides treating each output dimension as independent (giving a block-diagonal covariance matrix), are there any other ways to alleviate the computational and storage burden?

Best, Jianxiang