Learn2Solve / gp-scaling

Pilot study #1

Open Learn2Solve opened 7 months ago

Learn2Solve commented 7 months ago

GPNN fits a GP on a randomly subsampled dataset to estimate the hyperparameters. Then, for each test point, it predicts using only that point's nearest neighbors in the training set. This work aims to study scaling rules for GPs. Analogous to muP for neural networks, a similar scaling rule should exist for GPs: with a careful model setup, the optimal hyperparameters found at small scale should transfer directly to large-scale models under proper scaling.
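The nearest-neighbor prediction step described above can be sketched as follows. This is a minimal illustration, not the code from the referenced repository: the function names (`rbf_kernel`, `gpnn_predict`), the brute-force neighbor search, the RBF kernel, and the toy data are all assumptions, and the hyperparameters are passed in as fixed values standing in for whatever the subsample fit would return.

```python
import numpy as np

def rbf_kernel(a, b, lengthscale, kernelscale):
    """Squared-exponential kernel matrix between the rows of a and the rows of b."""
    sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return kernelscale * np.exp(-0.5 * sq_dist / lengthscale**2)

def gpnn_predict(x_train, y_train, x_test, m, lengthscale, kernelscale, noisevar):
    """Predict each test point from a GP conditioned only on its m nearest training points."""
    means = np.empty(len(x_test))
    variances = np.empty(len(x_test))
    for i, x_star in enumerate(x_test):
        # Brute-force neighbor search (assumption: a large-scale run would use an index).
        dist = np.linalg.norm(x_train - x_star, axis=1)
        idx = np.argsort(dist)[:m]
        x_nn, y_nn = x_train[idx], y_train[idx]
        k_nn = rbf_kernel(x_nn, x_nn, lengthscale, kernelscale) + noisevar * np.eye(m)
        k_star = rbf_kernel(x_star[None, :], x_nn, lengthscale, kernelscale)[0]
        alpha = np.linalg.solve(k_nn, k_star)
        means[i] = alpha @ y_nn
        # Predictive variance of a noisy observation at x_star.
        variances[i] = kernelscale + noisevar - alpha @ k_star
    return means, variances

# Toy data (hypothetical): a smooth function of the first input plus noise.
rng = np.random.default_rng(0)
x_train = rng.uniform(-3, 3, size=(500, 2))
y_train = np.sin(x_train[:, 0]) + 0.1 * rng.standard_normal(500)
x_test = rng.uniform(-2, 2, size=(20, 2))

means, variances = gpnn_predict(
    x_train, y_train, x_test, m=30, lengthscale=1.0, kernelscale=1.0, noisevar=0.01
)
```

Each test point costs one m-by-m solve, so the per-prediction cost depends on the neighbor count m rather than on the full training size.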

Learn2Solve commented 7 months ago

References: Code: https://github.com/ant-stephenson/gpnn-experiments/ Paper: https://arxiv.org/pdf/2306.14731.pdf

Learn2Solve commented 7 months ago

| Dataset | n | d | NLL (GPNN) | NLL (ExactGP) | RMSE (GPNN) | RMSE (ExactGP) | Time (ExactGP) | Compute (ExactGP) | Time (GPNN) |
|---|---|---|---|---|---|---|---|---|---|
| Poletele | 4.6e+03 | 19 | -0.214 ± 0.019 | -0.18 | 0.195 ± 0.0042 | 0.151 | 41.5 | 1-GPU exact | 28.8 ± 0.22 |
| Bike | 1.4e+04 | 13 | 0.953 ± 0.013 | 0.119 | 0.624 ± 0.0079 | 0.220 | 41.2 | 1-GPU exact | 28.4 ± 0.12 |
| Protein | 3.6e+04 | 9 | 1.01 ± 0.0016 | 1.018 | 0.666 ± 0.0014 | 0.536 | 47.9 | 1-GPU exact | 27.7 ± 0.19 |
| Ctslice | 4.2e+04 | 378 | -1.26 ± 0.01 | -0.073 | 0.132 ± 0.00062 | 0.218 | 129.6 | 1-GPU SGPR | 76.1 ± 4.6 |
| Road3D | 3.4e+05 | 2 | 0.371 ± 0.004 | 0.909 | 0.351 ± 0.0014 | 0.101 | 720.5 | 8-GPU SGPR | 27.9 ± 1.3 |
| Song | 4.6e+05 | 90 | 1.18 ± 0.0045 | 1.21 | 0.787 ± 0.0045 | 0.803 | 253.4 | 8-GPU exact | 138.0 ± 5.8 |
| HouseE | 1.6e+06 | 8 | -1.56 ± 0.0065 | -0.152 | 0.0506 ± 0.00072 | 0.055 | 4317.3 | 8-GPU exact | 32.0 ± 0.34 |

Learn2Solve commented 7 months ago

TODOs:

Learn2Solve commented 7 months ago

| Train Size | Test Size | Dimension | NN | Seed | True Kernel | Assumed Kernel | True ks | True ls | True nv | Assumed ks | Assumed ls | Assumed nv | Vary Par | MSE | NLL | MSCAL | EMSE | ENLL | EMSCAL | Learned ls | Learned ks | Learned nv |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 0.1 | 0.2 | lenscale | 0.358 | 0.947 | 0.565 | 0.382 | 0.906 | 0.871 | 0.159 | 1.062 | 0.137 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 0.680 | 0.2 | lenscale | 0.174 | 0.557 | 0.812 | 0.204 | 0.625 | 0.991 | 0.945 | 0.573 | 0.201 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 1.26 | 0.2 | lenscale | 0.173 | 0.549 | 0.836 | 0.202 | 0.617 | 0.969 | 0.964 | 1.053 | 0.193 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 1.84 | 0.2 | lenscale | 0.176 | 0.555 | 0.859 | 0.226 | 0.677 | 1.109 | 1.447 | 1.093 | 0.200 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 2.42 | 0.2 | lenscale | 0.179 | 0.561 | 0.874 | 0.203 | 0.621 | 1.002 | 1.981 | 1.098 | 0.237 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 3.0 | 0.2 | lenscale | 0.181 | 0.566 | 0.884 | 0.269 | 0.785 | 1.332 | 2.537 | 1.097 | 0.273 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.1 | 1.0 | 0.2 | kernelscale | 0.179 | 0.561 | 0.871 | 0.195 | 0.603 | 0.949 | 0.736 | 0.158 | 0.193 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.68 | 1.0 | 0.2 | kernelscale | 0.173 | 0.549 | 0.829 | 0.215 | 0.652 | 1.054 | 1.199 | 0.953 | 0.209 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 1.26 | 1.0 | 0.2 | kernelscale | 0.172 | 0.549 | 0.820 | 0.176 | 0.555 | 0.856 | 1.126 | 1.052 | 0.206 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 1.84 | 1.0 | 0.2 | kernelscale | 0.172 | 0.550 | 0.815 | 0.200 | 0.615 | 0.981 | 1.295 | 1.467 | 0.214 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 2.42 | 1.0 | 0.2 | kernelscale | 0.172 | 0.550 | 0.812 | 0.211 | 0.642 | 1.031 | 1.259 | 2.016 | 0.205 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 3.0 | 1.0 | 0.2 | kernelscale | 0.172 | 0.551 | 0.810 | 0.173 | 0.548 | 0.837 | 1.249 | 2.570 | 0.213 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.01 | noisevar | 0.174 | 6.596 | 15.861 | 0.204 | 6.540 | 15.550 | 0.715 | 1.106 | 0.016 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.208 | noisevar | 0.172 | 0.553 | 0.795 | 0.188 | 0.588 | 0.881 | 0.989 | 0.642 | 0.202 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.406 | noisevar | 0.173 | 0.692 | 0.413 | 0.249 | 0.780 | 0.602 | 0.973 | 0.981 | 0.267 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.604 | noisevar | 0.174 | 0.822 | 0.280 | 0.172 | 0.813 | 0.281 | 1.187 | 0.683 | 0.408 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.802 | noisevar | 0.175 | 0.929 | 0.212 | 0.237 | 0.960 | 0.290 | 1.089 | 0.887 | 0.558 |
| 2000 | 200 | 3 | 100 | 42 | RBF | RBF | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 1.0 | noisevar | 0.176 | 1.018 | 0.171 | 0.219 | 1.033 | 0.216 | 1.062 | 0.754 | 0.714 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 0.1 | 0.2 | lenscale | 0.334 | 0.976 | 0.451 | 0.274 | 0.825 | 0.664 | 0.160 | 0.557 | 0.127 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 0.680 | 0.2 | lenscale | 0.175 | 0.578 | 0.720 | 0.225 | 0.670 | 0.999 | 0.948 | 0.562 | 0.193 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 1.26 | 0.2 | lenscale | 0.172 | 0.553 | 0.783 | 0.182 | 0.574 | 0.860 | 1.621 | 0.601 | 0.219 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 1.84 | 0.2 | lenscale | 0.172 | 0.549 | 0.810 | 0.218 | 0.655 | 1.024 | 1.870 | 0.973 | 0.194 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 2.42 | 0.2 | lenscale | 0.173 | 0.550 | 0.829 | 0.221 | 0.665 | 1.046 | 2.010 | 1.081 | 0.195 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 3.0 | 0.2 | lenscale | 0.175 | 0.553 | 0.844 | 0.221 | 0.665 | 1.061 | 2.540 | 1.096 | 0.197 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.1 | 1.0 | 0.2 | kernelscale | 0.177 | 0.559 | 0.844 | 0.208 | 0.631 | 0.969 | 0.857 | 0.157 | 0.207 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.68 | 1.0 | 0.2 | kernelscale | 0.172 | 0.558 | 0.770 | 0.197 | 0.611 | 0.925 | 1.323 | 0.485 | 0.210 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 1.26 | 1.0 | 0.2 | kernelscale | 0.173 | 0.567 | 0.745 | 0.179 | 0.568 | 0.808 | 1.334 | 0.933 | 0.201 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 1.84 | 1.0 | 0.2 | kernelscale | 0.174 | 0.574 | 0.729 | 0.259 | 0.741 | 1.128 | 1.331 | 1.447 | 0.200 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 2.42 | 1.0 | 0.2 | kernelscale | 0.175 | 0.581 | 0.716 | 0.209 | 0.635 | 0.897 | 1.334 | 1.983 | 0.156 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 3.0 | 1.0 | 0.2 | kernelscale | 0.176 | 0.588 | 0.705 | 0.223 | 0.665 | 0.937 | 1.336 | 2.537 | 0.172 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.01 | noisevar | 0.191 | 4.894 | 12.023 | 0.261 | 3.960 | 9.592 | 0.716 | 1.100 | 0.016 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.208 | noisevar | 0.172 | 0.564 | 0.736 | 0.217 | 0.658 | 0.975 | 1.317 | 0.573 | 0.212 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.406 | noisevar | 0.173 | 0.708 | 0.389 | 0.209 | 0.735 | 0.493 | 1.331 | 0.564 | 0.267 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.604 | noisevar | 0.173 | 0.837 | 0.266 | 0.243 | 0.877 | 0.388 | 1.336 | 0.560 | 0.408 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.802 | noisevar | 0.174 | 0.942 | 0.203 | 0.199 | 0.943 | 0.241 | 1.336 | 0.560 | 0.557 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Matern | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 1.0 | noisevar | 0.175 | 1.030 | 0.165 | 0.190 | 1.025 | 0.185 | 1.332 | 0.563 | 0.714 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 0.1 | 0.2 | lenscale | 0.301 | 0.996 | 0.358 | 0.318 | 0.900 | 0.592 | 0.160 | 0.556 | 0.126 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 0.680 | 0.2 | lenscale | 0.182 | 0.669 | 0.512 | 0.204 | 0.659 | 0.723 | 0.948 | 0.561 | 0.132 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 1.26 | 0.2 | lenscale | 0.177 | 0.616 | 0.596 | 0.195 | 0.617 | 0.746 | 1.623 | 0.563 | 0.137 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 1.84 | 0.2 | lenscale | 0.175 | 0.594 | 0.640 | 0.241 | 0.706 | 0.972 | 2.251 | 0.566 | 0.148 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 2.42 | 0.2 | lenscale | 0.174 | 0.582 | 0.670 | 0.198 | 0.622 | 0.837 | 2.848 | 0.570 | 0.154 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 3.0 | 0.2 | lenscale | 0.173 | 0.574 | 0.691 | 0.198 | 0.615 | 0.819 | 3.313 | 0.648 | 0.189 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.1 | 1.0 | 0.2 | kernelscale | 0.178 | 0.566 | 0.792 | 0.213 | 0.648 | 0.938 | 1.288 | 0.159 | 0.193 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.68 | 1.0 | 0.2 | kernelscale | 0.177 | 0.621 | 0.588 | 0.226 | 0.677 | 0.864 | 1.328 | 0.471 | 0.136 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 1.26 | 1.0 | 0.2 | kernelscale | 0.182 | 0.678 | 0.499 | 0.197 | 0.650 | 0.657 | 1.335 | 0.933 | 0.129 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 1.84 | 1.0 | 0.2 | kernelscale | 0.185 | 0.729 | 0.439 | 0.226 | 0.718 | 0.656 | 1.337 | 1.441 | 0.128 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 2.42 | 1.0 | 0.2 | kernelscale | 0.188 | 0.774 | 0.394 | 0.263 | 0.800 | 0.695 | 1.340 | 1.978 | 0.127 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 3.0 | 1.0 | 0.2 | kernelscale | 0.191 | 0.815 | 0.358 | 0.232 | 0.774 | 0.529 | 1.341 | 2.533 | 0.127 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.01 | noisevar | 0.207 | 0.996 | 2.505 | 0.251 | 0.813 | 1.559 | 0.726 | 1.085 | 0.016 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.208 | noisevar | 0.178 | 0.639 | 0.549 | 0.233 | 0.699 | 0.842 | 1.328 | 0.563 | 0.141 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.406 | noisevar | 0.175 | 0.777 | 0.323 | 0.193 | 0.755 | 0.398 | 1.338 | 0.558 | 0.266 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.604 | noisevar | 0.175 | 0.895 | 0.232 | 0.206 | 0.882 | 0.300 | 1.339 | 0.557 | 0.408 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 0.802 | noisevar | 0.175 | 0.991 | 0.182 | 0.216 | 0.983 | 0.241 | 1.339 | 0.557 | 0.557 |
| 2000 | 200 | 3 | 100 | 42 | RBF | Exp | 0.8 | 1.0 | 0.2 | 0.8 | 1.0 | 1.0 | noisevar | 0.176 | 1.072 | 0.150 | 0.238 | 1.073 | 0.216 | 1.340 | 0.557 | 0.714 |

Learn2Solve commented 6 months ago

The basic ideas for the paper:

  1. Test the calibration method on UCI regression datasets to demonstrate its advantage, and improve it where the results call for it.
  2. Write up the theoretical results for the simplest case, building on existing results from the statistics community.
  3. Learn how to write a good paper, finish the first draft within one month, and aim for NeurIPS 2024.
Learn2Solve commented 5 months ago

The structure of the paper:

  1. Introduction: misspecification, calibration, mean and variance, small-sample training, exact GP.
  2. Misspecification: kernel type and hyperparameters.
  3. Calibration: even when the right kernel is chosen, some calibration is needed because of small-sample training, epistemic error, or distribution shift. Calibration is especially necessary for uncertainty quantification (UQ).
  4. Two approaches for GP: 1) scalable GP; 2) exact GP
  5. Numerical experiments
  6. Conclusions
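To make the calibration item in the outline concrete, here is a minimal sketch of one simple scheme: rescale all predictive variances by a single factor chosen so that the mean squared standardized residual equals 1 on a held-out calibration split. The function names, the synthetic data, and the split are assumptions made for illustration. Reading the MSCAL column of the synthetic table as this mean squared standardized residual is also an assumption, not something the thread states.

```python
import numpy as np

def calibration_factor(y, mean, var):
    """Mean squared standardized residual; 1.0 means variances match errors on average."""
    return float(np.mean((y - mean) ** 2 / var))

def gaussian_nll(y, mean, var):
    """Average Gaussian negative log-likelihood per point."""
    return float(np.mean(0.5 * np.log(2 * np.pi * var) + 0.5 * (y - mean) ** 2 / var))

# Synthetic predictions (hypothetical): correct means, overconfident variances.
rng = np.random.default_rng(0)
n = 2000
mean = rng.standard_normal(n)
true_var = 0.2
y = mean + np.sqrt(true_var) * rng.standard_normal(n)
var = np.full(n, 0.01)  # e.g. the result of a badly misspecified noise variance

# Estimate the factor on a calibration split and apply it everywhere.
cal, test = slice(0, 1000), slice(1000, None)
alpha = calibration_factor(y[cal], mean[cal], var[cal])
var_calibrated = alpha * var

nll_before = gaussian_nll(y[test], mean[test], var[test])
nll_after = gaussian_nll(y[test], mean[test], var_calibrated[test])
```

A single multiplicative factor leaves the predictive mean, and therefore the MSE, unchanged. It only moves the NLL and the calibration statistic.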
Learn2Solve commented 4 months ago

Last section

  1. Use a first GP to fit the mean. Then use a second GP to fit the uncertainty for better uncertainty quantification.
  2. Add this experiment.
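
One possible reading of the two-GP idea is sketched below. The thread does not say how the second GP should model the uncertainty, so regressing on log squared residuals is an assumption, as are the function names, the fixed hyperparameters, and the toy heteroscedastic data.

```python
import numpy as np

def rbf_kernel(a, b, lengthscale, kernelscale):
    """Squared-exponential kernel matrix between the rows of a and the rows of b."""
    sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return kernelscale * np.exp(-0.5 * sq_dist / lengthscale**2)

def gp_mean(x_train, y_train, x_test, lengthscale, kernelscale, noisevar):
    """Posterior mean of a zero-mean GP regression with fixed hyperparameters."""
    k = rbf_kernel(x_train, x_train, lengthscale, kernelscale)
    k += noisevar * np.eye(len(x_train))
    k_star = rbf_kernel(x_test, x_train, lengthscale, kernelscale)
    return k_star @ np.linalg.solve(k, y_train)

# Toy data (hypothetical): the noise level jumps for x > 0.
rng = np.random.default_rng(0)
x = np.sort(rng.uniform(-3, 3, size=(300, 1)), axis=0)
noise_std = 0.05 + 0.3 * (x[:, 0] > 0)
y = np.sin(x[:, 0]) + noise_std * rng.standard_normal(300)

# Stage 1: a GP for the mean, evaluated at the training inputs.
mu = gp_mean(x, y, x, lengthscale=1.0, kernelscale=1.0, noisevar=0.05)

# Stage 2: a GP on the centered log squared residuals (assumed target).
log_r2 = np.log((y - mu) ** 2 + 1e-8)
offset = log_r2.mean()
x_test = np.array([[-2.0], [2.0]])
log_var = offset + gp_mean(
    x, log_r2 - offset, x_test, lengthscale=1.0, kernelscale=4.0, noisevar=5.0
)

# E[log z^2] is about -1.2704 for standard normal z, so add it back before exponentiating.
var_test = np.exp(log_var + 1.2704)
```

Here `var_test` estimates the input-dependent noise variance at the two query points. In-sample residuals from stage 1 are somewhat shrunk, so a held-out split for stage 2 would likely be a better design for the actual experiment.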