cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License
3.46k stars 545 forks source link

[Bug] Multitask-ExactGPs seem to not use mBCG algorithm as Singletask-ExactGPs do #2523

Open OliverAh opened 1 month ago

OliverAh commented 1 month ago

🐛 Bug

This might also explain #2306. It is a potential bug; since the contributing guidelines ask for the issue type to be specified, I labelled it as a bug right away. Please feel free to change the label if that is not correct.

Using the `gpytorch.settings.verbose_linalg(state=True)` context revealed that, when switching from single-task to multi-task GPs, the linalg output changes from "CG" to "symeig". This is unexpected to me: from the paper gardner2018gpytorch I do not see a reason why the mBCG algorithm (which I assume is what the output below calls "CG") should not be applicable in the multi-task case. Of course I could be missing something; if so, please point me to it.

To reproduce

I mainly adapted the GPyTorch Regression Tutorial (GPU) from the documentation. To make it easy to switch back and forth between single- and multi-task GPs, tensor shapes, and CPU/GPU, I wrapped the reproduction code in a function `run_gpytorch(...)`, which is called with the desired kwargs. Its signature is:

`run_gpytorch(dims_in: int, dims_out: int, num_samples: int, device: str)`, where `device` is either `'cpu'` or `'gpu'`.

Code snippet to reproduce

import torch
import gpytorch
import numpy as np
def run_gpytorch(dims_in: int, dims_out: int, num_samples: int, device: str = 'cpu'):
    """Fit and query a single- or multi-task exact GP while logging linear-algebra calls.

    Adapted from the GPyTorch regression tutorial. Everything runs inside
    ``gpytorch.settings.verbose_linalg`` so the log shows which routine
    (e.g. "CG" or "symeig") is dispatched during training and prediction.

    Args:
        dims_in: Number of input dimensions; must be >= 1.
        dims_out: Number of outputs/tasks; must be >= 1. ``1`` builds a
            single-task model with ``ScaleKernel(RBFKernel())``; larger values
            build a ``MultitaskKernel`` model with ``dims_out`` tasks.
        num_samples: Number of training points; must be >= 1. About
            ``num_samples / 2.67`` test points are generated (at least one).
        device: ``'cpu'`` or ``'gpu'``. ``'gpu'`` moves data, model and
            likelihood to CUDA with ``.cuda()``.

    Returns:
        None. Progress is printed and linear-algebra calls are logged.

    Raises:
        ValueError: If ``dims_in``, ``dims_out`` or ``num_samples`` is < 1, or
            if ``device`` is neither ``'cpu'`` nor ``'gpu'``.
    """
    # Validate up front, so bad arguments fail with a clear message instead of
    # an unbound ``train_y`` later on or a silent CPU run for an unknown device.
    if dims_in < 1 or dims_out < 1 or num_samples < 1:
        raise ValueError(
            'dims_in, dims_out and num_samples must all be >= 1, got '
            f'dims_in={dims_in}, dims_out={dims_out}, num_samples={num_samples}'
        )
    if device not in ('cpu', 'gpu'):
        raise ValueError(f"device must be 'cpu' or 'gpu', got {device!r}")

    # verbose_linalg logs every linear-algebra call; fast_computations enables
    # GPyTorch's fast code paths for root decompositions, log-probs and solves.
    with gpytorch.settings.verbose_linalg(state=True), \
            gpytorch.settings.fast_computations(covar_root_decomposition=True, log_prob=True, solves=True):

        # Inputs: evenly spaced points along the diagonal of the unit
        # hypercube, shape (n, dims_in).
        num_test = max(1, int(num_samples / 2.67))
        samples_train = np.linspace(start=[0.] * dims_in, stop=[1.] * dims_in, num=num_samples)
        samples_test = np.linspace(start=[0.] * dims_in, stop=[1.] * dims_in, num=num_test)

        # For a single input dimension, flatten (n, 1) to (n,), the input shape
        # used by the single-input GPyTorch tutorials.
        if dims_in == 1:
            samples_train = samples_train.reshape(-1)
            samples_test = samples_test.reshape(-1)
        train_x = torch.tensor(samples_train, dtype=torch.float)
        test_x = torch.tensor(samples_test, dtype=torch.float)

        # Targets: sin(2*pi*x) of the first input coordinate. For several
        # tasks the same signal is stacked column-wise -> shape (n, dims_out).
        first_coord = train_x if dims_in == 1 else train_x[:, 0]
        signal = torch.sin(2 * torch.pi * first_coord)
        train_y = signal if dims_out == 1 else torch.stack([signal] * dims_out, dim=1)

        print(f'Shape of train_x: {train_x.shape}')
        print(f'Shape of test_x:  {test_x.shape}')
        print(f'Shape of train_y: {train_y.shape}' + '\n')

        if dims_out == 1:
            print(f'Using single-task GP as dims_out = {dims_out}' + '\n')

            class ExactGPModel(gpytorch.models.ExactGP):
                """Single-task exact GP: constant mean, scaled RBF kernel."""

                def __init__(self, train_x, train_y, likelihood):
                    super().__init__(train_x, train_y, likelihood)
                    self.mean_module = gpytorch.means.ConstantMean()
                    self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

                def forward(self, x):
                    mean_x = self.mean_module(x)
                    covar_x = self.covar_module(x)
                    return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

            likelihood = gpytorch.likelihoods.GaussianLikelihood()
            model = ExactGPModel(train_x, train_y, likelihood)

        else:
            print(f'Using multi-task GP as dims_out = {dims_out}' + '\n')

            class MultitaskGPModel(gpytorch.models.ExactGP):
                """Multi-task exact GP: per-task constant means, MultitaskKernel over an RBF data kernel."""

                def __init__(self, train_x, train_y, likelihood):
                    super().__init__(train_x, train_y, likelihood)
                    self.mean_module = gpytorch.means.MultitaskMean(
                        gpytorch.means.ConstantMean(), num_tasks=dims_out
                    )
                    # NOTE(review): in the issue's log, the multitask run shows symeig
                    # on 1000x1000 (data) and 2x2 (task) matrices instead of CG,
                    # presumably due to the Kronecker structure of MultitaskKernel;
                    # confirm against linear_operator.
                    self.covar_module = gpytorch.kernels.MultitaskKernel(
                        gpytorch.kernels.RBFKernel(), num_tasks=dims_out
                    )

                def forward(self, x):
                    mean_x = self.mean_module(x)
                    covar_x = self.covar_module(x)
                    return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

            likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=dims_out)
            model = MultitaskGPModel(train_x, train_y, likelihood)

        if device == 'gpu':
            print('Move all structures to GPU since device=gpu' + '\n')
            train_x = train_x.cuda()
            test_x = test_x.cuda()
            train_y = train_y.cuda()
            model = model.cuda()
            likelihood = likelihood.cuda()
        else:
            print('Do not move structures to GPU since device=cpu' + '\n')

        model.train()
        likelihood.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

        print('Start training')
        training_iterations = 1
        for i in range(training_iterations):
            optimizer.zero_grad()
            output = model(train_x)
            loss = -mll(output, train_y)
            loss.backward()
            optimizer.step()
            print(f'  Finished training iteration {i + 1}/{training_iterations}')
        print('Finished training' + '\n')

        # Evaluation: query the posterior on the test inputs.
        model.eval()
        likelihood.eval()
        print('Start testing')
        with torch.no_grad():
            observed_model = model(test_x)
            print('  Testing: Finished evaluation')
            observed_pred = likelihood(observed_model)
            print('  Testing: Finished likelihood')
            # Computed only to exercise the prediction path (and its linalg
            # calls); the values themselves are not used.
            mean = observed_pred.mean
            lower, upper = observed_pred.confidence_region()
        print('Finished testing' + '\n')
# Reproduce the multi-task case from the issue. Fall back to the CPU when CUDA
# is not available, so the script does not crash on `.cuda()` on GPU-less
# machines.
if __name__ == '__main__':
    kwargs_profile = {
        'dims_in': 1,
        'dims_out': 2,
        'num_samples': 1000,
        'device': 'gpu' if torch.cuda.is_available() else 'cpu',
    }
    run_gpytorch(**kwargs_profile)

Outputs

The code above produces the following output, which shows that `symeig` is used:

Shape of train_x: torch.Size([1000])
Shape of test_x:  torch.Size([374])
Shape of train_y: torch.Size([1000, 2])

Using multi-task GP as dims_out = 2

Move all structures to GPU since device=gpu

LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([1000, 1000]).
Start training
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([2, 2]).
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([1000, 1000]).
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([2, 2]).
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([1000, 1000]).
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([2, 2]).
c:\Users\{user}\micromamba\envs\gpytorch_mwe\Lib\site-packages\linear_operator\utils\interpolation.py:71: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\tensor\python_tensor.cpp:80.)
  summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size)
c:\Users\{user}\micromamba\envs\gpytorch_mwe\Lib\site-packages\linear_operator\utils\interpolation.py:71: UserWarning: torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated.  Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=). (Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\utils\tensor_new.cpp:623.)
  summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size)
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([1000, 1000]).
  Finished training iteration 1/1
Finished training

Start testing
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([2, 2]).
  Testing: Finished evaluation
  Testing: Finished likelihood
Finished testing

Expected Behavior

I expected the output lines starting with `LinAlg (Verbose)` to read something like `LinAlg (Verbose) - DEBUG - Running CG ...` instead of `LinAlg (Verbose) - DEBUG - Running symeig ...`.

This expectation comes from testing a single-task GP, run with the following call (note that `dims_out` is changed to 1):

kwargs_profile = {'dims_in':1, 'dims_out':1, 'num_samples':1000, 'device':'gpu'}
run_gpytorch(**kwargs_profile)

This call produces the following output, which shows that CG is used:

Shape of train_x: torch.Size([1000])
Shape of test_x:  torch.Size([374])
Shape of train_y: torch.Size([1000])

Using single-task GP as dims_out = 1

Move all structures to GPU since device=gpu

LinAlg (Verbose) - DEBUG - Running CG on a torch.Size([1000, 11]) RHS for 1000 iterations (tol=1). Output: torch.Size([1000, 11]).
Start training
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([10, 11, 11]).
LinAlg (Verbose) - DEBUG - Running CG on a torch.Size([1000, 1]) RHS for 1000 iterations (tol=0.01). Output: torch.Size([1000, 1]).
LinAlg (Verbose) - DEBUG - Running CG on a torch.Size([1000, 374]) RHS for 1000 iterations (tol=0.01). Output: torch.Size([1000, 374]).
  Finished training iteration 1/1
Finished training

Start testing
  Testing: Finished evaluation
  Testing: Finished likelihood
Finished testing

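System information

GPyTorch 1.11, linear_operator 0.5.2, PyTorch 2.3.0 (CUDA 12.1 build), Python 3.12.3, on Windows (full environment export below).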

Additional context

I was profiling GPyTorch the other day (using cProfile) and noticed that, for multi-task GPs, the linalg routine called by GPyTorch was `torch._C._linalg.linalg_eigh`. That led to the investigation above. If you are interested, I can also provide the profiling information. I am using a Jupyter notebook inside VS Code with a mamba environment. The environment was created from the first YAML file below with `micromamba create -f .\{file_name}.yml`. The second listing below is the resulting environment, exported with `micromamba env export > {other_file_name}.yml`.

name: gpytorch_mwe
channels:
  - conda-forge
  - pytorch
  - nvidia
  - gpytorch
dependencies:
  - python
  - pytorch 
  - torchvision
  - torchaudio
  - pytorch-cuda==12.1
  - gpytorch::gpytorch
  - pandas
  - matplotlib
  - ipykernel
name: gpytorch_mwe
channels:
- conda-forge
- gpytorch
- nvidia
- pytorch
dependencies:
- asttokens=2.4.1=pyhd8ed1ab_0
- blas=1.0=mkl
- brotli=1.1.0=hcfcfb64_1
- brotli-bin=1.1.0=hcfcfb64_1
- brotli-python=1.1.0=py312h53d5487_1
- bzip2=1.0.8=hcfcfb64_5
- ca-certificates=2024.2.2=h56e8100_0
- cccl=2.3.1=h84bb9a4_0
- certifi=2024.2.2=pyhd8ed1ab_0
- charset-normalizer=3.3.2=pyhd8ed1ab_0
- colorama=0.4.6=pyhd8ed1ab_0
- comm=0.2.2=pyhd8ed1ab_0
- contourpy=1.2.1=py312h0d7def4_0
- cuda-cccl=12.4.127=h57928b3_1
- cuda-cccl_win-64=12.4.127=h57928b3_1
- cuda-cudart=12.1.105=0
- cuda-cudart-dev=12.1.105=0
- cuda-cupti=12.1.105=0
- cuda-libraries=12.1.0=0
- cuda-libraries-dev=12.1.0=0
- cuda-nvrtc=12.1.105=0
- cuda-nvrtc-dev=12.1.105=0
- cuda-nvtx=12.1.105=0
- cuda-opencl=12.4.127=h63175ca_0
- cuda-opencl-dev=12.4.127=h63175ca_0
- cuda-profiler-api=12.4.127=h57928b3_1
- cuda-runtime=12.1.0=0
- cuda-version=12.4=h3060b56_3
- cycler=0.12.1=pyhd8ed1ab_0
- debugpy=1.8.1=py312h53d5487_0
- decorator=5.1.1=pyhd8ed1ab_0
- exceptiongroup=1.2.0=pyhd8ed1ab_2
- executing=2.0.1=pyhd8ed1ab_0
- filelock=3.14.0=pyhd8ed1ab_0
- fonttools=4.51.0=py312he70551f_0
- freetype=2.12.1=hdaf720e_2
- gettext=0.22.5=h5728263_2
- gettext-tools=0.22.5=h7d00a51_2
- glib=2.80.0=h39d0aa6_6
- glib-tools=2.80.0=h0a98069_6
- gpytorch=1.11=0
- gst-plugins-base=1.24.1=h001b923_1
- gstreamer=1.24.1=hb4038d2_1
- icu=73.2=h63175ca_0
- idna=3.7=pyhd8ed1ab_0
- importlib-metadata=7.1.0=pyha770c72_0
- importlib_metadata=7.1.0=hd8ed1ab_0
- intel-openmp=2024.1.0=h57928b3_965
- ipykernel=6.29.3=pyha63f2e9_0
- ipython=8.22.2=pyh7428d3b_0
- jaxtyping=0.2.28=pyhd8ed1ab_0
- jedi=0.19.1=pyhd8ed1ab_0
- jinja2=3.1.3=pyhd8ed1ab_0
- joblib=1.4.0=pyhd8ed1ab_0
- jupyter_client=8.6.1=pyhd8ed1ab_0
- jupyter_core=5.7.2=py312h2e8e312_0
- khronos-opencl-icd-loader=2023.04.17=h64bf75a_0
- kiwisolver=1.4.5=py312h0d7def4_1
- krb5=1.21.2=heb0366b_0
- lcms2=2.16=h67d730c_0
- lerc=4.0.0=h63175ca_0
- libasprintf=0.22.5=h5728263_2
- libasprintf-devel=0.22.5=h5728263_2
- libblas=3.9.0=1_h8933c1f_netlib
- libbrotlicommon=1.1.0=hcfcfb64_1
- libbrotlidec=1.1.0=hcfcfb64_1
- libbrotlienc=1.1.0=hcfcfb64_1
- libcblas=3.9.0=5_hd5c7e75_netlib
- libclang13=18.1.3=default_hf64faad_0
- libcublas=12.1.0.26=0
- libcublas-dev=12.1.0.26=0
- libcufft=11.0.2.4=0
- libcufft-dev=11.0.2.4=0
- libcurand=10.3.5.147=h63175ca_1
- libcurand-dev=10.3.5.147=h63175ca_1
- libcusolver=11.4.4.55=0
- libcusolver-dev=11.4.4.55=0
- libcusparse=12.0.2.55=0
- libcusparse-dev=12.0.2.55=0
- libdeflate=1.20=hcfcfb64_0
- libexpat=2.6.2=h63175ca_0
- libffi=3.4.2=h8ffe710_5
- libgettextpo=0.22.5=h5728263_2
- libgettextpo-devel=0.22.5=h5728263_2
- libglib=2.80.0=h39d0aa6_6
- libhwloc=2.10.0=default_h2fffb23_1000
- libiconv=1.17=hcfcfb64_2
- libintl=0.22.5=h5728263_2
- libintl-devel=0.22.5=h5728263_2
- libjpeg-turbo=3.0.0=hcfcfb64_1
- liblapack=3.9.0=5_hd5c7e75_netlib
- libnpp=12.0.2.50=0
- libnpp-dev=12.0.2.50=0
- libnvjitlink=12.1.105=0
- libnvjitlink-dev=12.1.105=0
- libnvjpeg=12.1.1.14=0
- libnvjpeg-dev=12.1.1.14=0
- libogg=1.3.4=h8ffe710_1
- libpng=1.6.43=h19919ed_0
- libsodium=1.0.18=h8d14728_1
- libsqlite=3.45.3=hcfcfb64_0
- libtiff=4.6.0=hddb2be6_3
- libuv=1.48.0=hcfcfb64_0
- libvorbis=1.3.7=h0e60522_0
- libwebp-base=1.4.0=hcfcfb64_0
- libxcb=1.15=hcd874cb_0
- libxml2=2.12.6=hc3477c8_2
- libzlib=1.2.13=hcfcfb64_5
- linear_operator=0.5.2=pyhd8ed1ab_0
- m2w64-gcc-libgfortran=5.3.0=6
- m2w64-gcc-libs=5.3.0=7
- m2w64-gcc-libs-core=5.3.0=7
- m2w64-gmp=6.1.0=2
- m2w64-libwinpthread-git=5.0.0.4634.697f757=2
- markupsafe=2.1.5=py312he70551f_0
- matplotlib=3.8.4=py312h2e8e312_0
- matplotlib-base=3.8.4=py312h26ecaf7_0
- matplotlib-inline=0.1.7=pyhd8ed1ab_0
- mkl=2023.1.0=h6a75c08_48682
- mpmath=1.3.0=pyhd8ed1ab_0
- msys2-conda-epoch=20160418=1
- munkres=1.1.4=pyh9f0ad1d_0
- nest-asyncio=1.6.0=pyhd8ed1ab_0
- networkx=3.3=pyhd8ed1ab_1
- numpy=1.26.4=py312h8753938_0
- openjpeg=2.5.2=h3d672ee_0
- openssl=3.3.0=hcfcfb64_0
- packaging=24.0=pyhd8ed1ab_0
- pandas=2.2.2=py312h2ab9e98_0
- parso=0.8.4=pyhd8ed1ab_0
- pcre2=10.43=h17e33f8_0
- pickleshare=0.7.5=py_1003
- pillow=10.3.0=py312h6f6a607_0
- pip=24.0=pyhd8ed1ab_0
- platformdirs=4.2.1=pyhd8ed1ab_0
- ply=3.11=pyhd8ed1ab_2
- prompt-toolkit=3.0.42=pyha770c72_0
- psutil=5.9.8=py312he70551f_0
- pthread-stubs=0.4=hcd874cb_1001
- pthreads-win32=2.9.1=hfa6e2cd_3
- pure_eval=0.2.2=pyhd8ed1ab_0
- pygments=2.17.2=pyhd8ed1ab_0
- pyparsing=3.1.2=pyhd8ed1ab_0
- pyqt=5.15.9=py312he09f080_5
- pyqt5-sip=12.12.2=py312h53d5487_5
- pysocks=1.7.1=pyh0701188_6
- python=3.12.3=h2628c8c_0_cpython
- python-dateutil=2.9.0=pyhd8ed1ab_0
- python-tzdata=2024.1=pyhd8ed1ab_0
- python_abi=3.12=4_cp312
- pytorch=2.3.0=py3.12_cuda12.1_cudnn8_0
- pytorch-cuda=12.1=hde6ce7c_5
- pytorch-mutex=1.0=cuda
- pytz=2024.1=pyhd8ed1ab_0
- pywin32=306=py312h53d5487_2
- pyyaml=6.0.1=py312he70551f_1
- pyzmq=26.0.2=py312hd7027bb_0
- qt-main=5.15.8=hcef0176_21
- requests=2.31.0=pyhd8ed1ab_0
- scikit-learn=1.4.2=py312hcacafb1_0
- scipy=1.13.0=py312h8753938_0
- setuptools=69.5.1=pyhd8ed1ab_0
- sip=6.7.12=py312h53d5487_0
- six=1.16.0=pyh6c4a22f_0
- stack_data=0.6.2=pyhd8ed1ab_0
- sympy=1.12=pyh04b8f61_3
- tbb=2021.12.0=h91493d7_0
- threadpoolctl=3.5.0=pyhc1e730c_0
- tk=8.6.13=h5226925_1
- toml=0.10.2=pyhd8ed1ab_0
- tomli=2.0.1=pyhd8ed1ab_0
- torchaudio=2.3.0=py312_cu121
- torchvision=0.18.0=py312_cu121
- tornado=6.4=py312he70551f_0
- traitlets=5.14.3=pyhd8ed1ab_0
- typeguard=2.13.3=pyhd8ed1ab_0
- typing-extensions=4.11.0=hd8ed1ab_0
- typing_extensions=4.11.0=pyha770c72_0
- tzdata=2024a=h0c530f3_0
- ucrt=10.0.22621.0=h57928b3_0
- urllib3=2.2.1=pyhd8ed1ab_0
- vc=14.3=hcf57466_18
- vc14_runtime=14.38.33130=h82b7239_18
- vs2015_runtime=14.38.33130=hcb4865c_18
- wcwidth=0.2.13=pyhd8ed1ab_0
- wheel=0.43.0=pyhd8ed1ab_1
- win_inet_pton=1.1.0=pyhd8ed1ab_6
- xorg-libxau=1.0.11=hcd874cb_0
- xorg-libxdmcp=1.1.3=hcd874cb_0
- xz=5.2.6=h8d14728_0
- yaml=0.2.5=h8ffe710_2
- zeromq=4.3.5=h63175ca_1
- zipp=3.17.0=pyhd8ed1ab_0
- zstd=1.5.5=h12be248_0