ginward opened 4 years ago
Or rather, is there a way to train with mini-batches?
Hello ginward,
currently we have a mini-batch function for constructing kernel elements, but not for inference. However, even with mini-batching our library still constructs the full kernel, and with 1.25M data points it will need 6.2 TB of memory. (Since the kernel is symmetric it should take only ~3 TB, but the library computes the full matrix at the moment.)
Moreover, our predict function performs inference via either a Cholesky decomposition or an eigendecomposition of the full kernel, which scales as O(n**3). With sizable RAM one could try a dataset size of ~100k, but scaling to larger datasets would be challenging.
So there are a few ways of scaling up GP or kernel computations. Our library does not support these methods yet, but we are interested in using our kernels with them for scaling up kernel computations.
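The memory figures above follow from simple arithmetic; a back-of-the-envelope sketch (assuming float32 storage, which is where the 6.2 TB comes from):

```python
# Estimate the memory needed to store a dense n x n kernel matrix.
def kernel_memory_tb(n, bytes_per_element=4):  # float32 by default
    return n * n * bytes_per_element / 1e12   # in terabytes

full = kernel_memory_tb(1_250_000)  # full matrix: ~6.25 TB
half = full / 2                     # one triangle only (kernel is symmetric)
print(f"full: {full:.2f} TB, symmetric storage: {half:.2f} TB")
```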
1) Exact Gaussian Processes on a Million Data Points, Wang et al., https://arxiv.org/abs/1903.08114
2) Kernel machines that adapt to GPUs for effective large batch training, Ma and Belkin, https://arxiv.org/abs/1806.06144
These methods avoid constructing the full kernel matrix and compute only the entries that are required. They are also iterative in nature. As far as I can tell, kernel approximation methods are another way to scale up GP computation. One could try replacing the kernel functions of those libraries with our kernel function. For reference:
3) Gaussian processes for big data, Hensman et al., https://arxiv.org/abs/1309.6835
4) When Gaussian process meets big data: A review of scalable GPs, Liu et al., https://arxiv.org/abs/1807.01065
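The core trick in the iterative methods above is that conjugate-gradient-style solvers only ever need kernel-vector products K @ v, which can be accumulated one block of rows at a time so the full n x n matrix never materializes. A minimal NumPy sketch (the RBF kernel and batch size here are illustrative choices, not taken from the papers):

```python
import numpy as np

def rbf(x1, x2):
    # One block of an RBF kernel between point sets of shape (a, d) and (b, d).
    sq = ((x1[:, None, :] - x2[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq)

def kernel_matvec(x, v, batch_size=128):
    # Compute K @ v without holding the full n x n kernel:
    # build one (batch_size, n) row-block of K at a time.
    out = np.empty(len(x))
    for i in range(0, len(x), batch_size):
        out[i:i + batch_size] = rbf(x[i:i + batch_size], x) @ v
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(500, 4))
v = rng.normal(size=500)
# Matches the dense computation, using only O(batch_size * n) memory at a time.
assert np.allclose(kernel_matvec(x, v), rbf(x, x) @ v)
```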
Just to add to Jaehoon's reply, one thing we are interested in testing out is integration with the excellent GPyTorch package (https://gpytorch.ai/), which can scale GP inference to 1M+ datapoints. We haven't explored this direction much ourselves, but I think it should be doable, especially since JAX supports DLPack (https://jax.readthedocs.io/en/latest/jax.dlpack.html), which should enable zero-copy sharing of tensors between JAX and PyTorch. In principle, therefore, you should be able to construct a GP kernel using minibatching with NT and then use GPyTorch to do inference.
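The zero-copy exchange can be demonstrated without PyTorch installed, since NumPy implements the same DLPack protocol via np.from_dlpack; the JAX/PyTorch direction would use jax.dlpack and torch.utils.dlpack analogously (those calls are noted in the comments as intended usage, not run here):

```python
import numpy as np

# DLPack lets two frameworks view the same buffer without copying.
# With JAX and PyTorch the exchange looks roughly like:
#   torch_tensor = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(jax_array))
# Here the zero-copy semantics are shown with NumPy, which speaks the
# same __dlpack__ protocol:
a = np.arange(10, dtype=np.float32)
b = np.from_dlpack(a)  # a view over a's buffer; no copy is made
assert np.shares_memory(a, b)
```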
If you do decide to explore this direction, please keep us posted. I would be very curious to hear about your experience and if you run into trouble we would love to help you resolve problems as you come across them.
Thanks @sschoenholz and @jaehlee ! I will explore and let you know.
@sschoenholz I assume that, even if I use GPyTorch to access the NT mini-batch kernel, I would still need 6 TB of memory, as pointed out by @jaehlee?
@ginward GPyTorch's scalable version will probably compute kernel elements as needed, so it won't require 6 TB at any given instant. This will result in computing the same kernel elements multiple times. If you are using either FC kernels or CNN kernels without pooling, the overhead of kernel computation shouldn't be too bad. If you use a pooling layer or a kernel that keeps track of feature-feature covariance, you may be overwhelmed by the kernel computation. For the latter case, one should cache computations on disk and reuse them whenever possible.
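One simple way to realize the on-disk caching suggested above is a NumPy memmap: compute each kernel block once, write it into a file-backed array, and reload it on later passes instead of recomputing. A hedged sketch (the block size, file name, and placeholder linear kernel are illustrative, not from the library):

```python
import numpy as np, tempfile, os

def cached_kernel(x, kernel_fn, path, block=256):
    n = len(x)
    # File-backed array: each block is computed once and written to disk.
    K = np.lib.format.open_memmap(path, mode="w+", dtype=np.float32, shape=(n, n))
    for i in range(0, n, block):
        for j in range(0, n, block):
            K[i:i + block, j:j + block] = kernel_fn(x[i:i + block], x[j:j + block])
    K.flush()
    return K  # later runs can np.load(path, mmap_mode="r") instead of recomputing

linear = lambda a, b: (a @ b.T).astype(np.float32)  # placeholder kernel
x = np.random.default_rng(1).normal(size=(300, 8)).astype(np.float32)
path = os.path.join(tempfile.mkdtemp(), "kernel.npy")
K = cached_kernel(x, linear, path)
assert np.allclose(K, x @ x.T, atol=1e-4)
```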
Thanks!
Hello guys,
I also have the requirement to compute a covariance matrix over 300k datapoints. Have you succeeded in obtaining something working in the end? @ginward
@jaehlee @sschoenholz There is another question about multi-dimensional outputs: if I have a fairly large number of classes, like 50, besides treating each output dimension as independent (giving a block-diagonal covariance matrix), are there other ways to alleviate the computational and storage burden?
Best, Jianxiang
Hello,
I currently have a dataset with 1246064 observations and 94 features. It is my understanding that the GP would have to generate a kernel of size 1246064 * 1246064, and I am not sure if that is the reason I am currently running into the following memory error:
I was wondering if there is a way around this (for example, creating a kernel approximation of some sort, similar to this one).
Thanks!
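One common kernel approximation of the kind asked about is the Nyström method: pick m << n landmark points and approximate K ~ K_nm K_mm^-1 K_mn, so only an n x m block is ever stored. A hedged NumPy sketch (the RBF kernel, landmark count, and jitter are illustrative choices):

```python
import numpy as np

def rbf(x1, x2):
    sq = ((x1[:, None, :] - x2[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq)

def nystrom(x, m=50, jitter=1e-6, seed=0):
    # Low-rank factor L with L @ L.T approximating K, needing only (n, m) storage.
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(x), size=m, replace=False)
    K_nm = rbf(x, x[idx])                            # (n, m) block
    K_mm = rbf(x[idx], x[idx]) + jitter * np.eye(m)  # (m, m) landmark kernel
    # L = K_nm @ K_mm^{-1/2}, via the Cholesky factor of K_mm.
    L = np.linalg.solve(np.linalg.cholesky(K_mm), K_nm.T).T
    return L

x = np.random.default_rng(2).normal(size=(400, 3))
L = nystrom(x, m=100)
err = np.abs(L @ L.T - rbf(x, x)).max()
print(f"max elementwise approximation error: {err:.3f}")
```

Here L @ L.T equals K_nm K_mm^-1 K_mn, the standard Nyström approximation; m trades accuracy against the n x m memory footprint.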