FalkonML / falkon

Large-scale, multi-GPU capable, kernel solver
https://falkonml.github.io/falkon/
MIT License

Fix GaussianKernel in multi-sigma case #20

Closed Akatsuki96 closed 4 years ago

Akatsuki96 commented 4 years ago

Fitting with a vector of lengthscales (multi-sigma) is extremely slow compared to the single-sigma case. Example:

#!/usr/bin/env python3
from falkon import Falkon
from falkon.kernels import GaussianKernel
from falkon.options import FalkonOptions
import numpy as np
import time
import torch

def build_dataset():
    # Synthetic regression problem: 10000 points in 28 dimensions.
    X = torch.rand(10000, 28)
    Y = torch.sin(X)
    return X, Y

def make_falkon(sigma):
    # Identical estimator configuration; only the kernel lengthscale varies.
    return Falkon(
        kernel=GaussianKernel(sigma=sigma),
        penalty=1e-5,
        M=200,
        maxiter=10,
        seed=4242,
        options=FalkonOptions(),
    )

def single_sigma():
    # Scalar lengthscale.
    return make_falkon(2.8)

def multi_sigma():
    # One lengthscale per dimension (all equal, to isolate the slowdown).
    return make_falkon(torch.full((28,), 2.8))

def multi_sigma_matrix():
    # Full lengthscale matrix with the sigmas on the diagonal.
    return make_falkon(2.8 * torch.eye(28))

def test_fit(X, Y, flk):
    st = time.time()
    flk.fit(X, Y)
    end = time.time()
    return end - st

X, Y = build_dataset()
print("[->] Single sigma => dataset fitted in {} seconds".format(test_fit(X, Y, single_sigma())))
print("[->] Multi sigma => dataset fitted in {} seconds".format(test_fit(X, Y, multi_sigma())))
print("[->] Multi sigma (using a matrix with sigmas in the diagonal) => dataset fitted in {} seconds".format(test_fit(X, Y, multi_sigma_matrix()))) 
Giodiro commented 4 years ago

The issue is that we treat diagonal lengthscales the same as a full lengthscale matrix, and the full-matrix path breaks KeOps at a much lower data dimension than a single lengthscale does. I'll implement two fixes:

  1. Rework the Gaussian kernel so that it uses division by the lengthscale instead of the full-matrix multiplication in case a vector of length-scales is provided. This is a common use case!
  2. Switch to the matmul implementation earlier when full-matrix lengthscales are used.
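Fix 1 rests on the identity that, for a diagonal lengthscale matrix diag(sigma), multiplying the data by its inverse is the same as elementwise division by the sigma vector, which avoids an O(n·d²) matmul. A minimal NumPy sketch of that equivalence (not Falkon's actual implementation; shapes are illustrative):

```python
import numpy as np

X = np.random.rand(1000, 28)          # n points in d dimensions
sigma = np.full(28, 2.8)              # one lengthscale per dimension

# Full-matrix route: multiply by the inverse of diag(sigma) -- O(n*d^2)
scaled_matmul = X @ np.diag(1.0 / sigma)

# Vector route: elementwise division, broadcast over rows -- O(n*d)
scaled_div = X / sigma

# Both scalings feed the same squared distances into the Gaussian kernel
assert np.allclose(scaled_matmul, scaled_div)
```

The division also keeps the expression simple enough for KeOps to handle at higher dimensions than the matmul formulation.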

In the meantime you can try setting the option use_keops="no" when using full-matrix lengthscales. This should speed things up.