pyRiemann / pyRiemann

Machine learning for multivariate data through the Riemannian geometry of positive definite matrices in Python
https://pyriemann.readthedocs.io
BSD 3-Clause "New" or "Revised" License
640 stars 164 forks source link

Add TSA to transfer learning module #319

Open qbarthelemy opened 2 months ago

qbarthelemy commented 2 months ago

In the transfer learning module, add Tangent Space Alignment (TSA); see https://www.frontiersin.org/articles/10.3389/fnhum.2022.1049985/pdf

qbarthelemy commented 2 months ago

Code shared by the first author of the paper:

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.decomposition import PCA
from pyriemann.estimation import Covariances
from pyriemann.transfer import decode_domains
from sklearn.preprocessing import LabelEncoder

class TSAxPCA(BaseEstimator, TransformerMixin):
    """Align a target domain onto the source domains with an SVD rotation.

    Anchor points are computed in each domain: one mean per class, followed,
    for every class and every component of a PCA fitted on the source samples
    of that class, by ``n_clusters`` means obtained by sorting the samples
    along the component and splitting them into consecutive groups. The
    rotation comes from the SVD of ``source_means.T @ target_means``.

    Parameters
    ----------
    target_domain : str
        Domain (as returned by ``decode_domains``) used as target. Samples of
        every other domain are pooled as source.
    eVar : float or int, default=0.999
        If lower than 2, the number of singular vectors kept is the number of
        cumulated singular values strictly below ``eVar * sum(s)``, plus one.
        Otherwise ``eVar`` itself is the number of singular vectors kept.
    n_clusters : int, default=3
        Number of groups each class is split into along a PCA component.
        A class with fewer samples marks its domain as incomplete, and the
        rotation then falls back to the class means only.
    n_components : int or "max", default=1
        Number of PCA components per class; "max" uses ``X.shape[1]``.

    Notes
    -----
    NOTE(review): class labels are re-encoded independently for source and
    target, so both are presumably expected to contain the same classes --
    confirm against callers.
    """

    def __init__(self, target_domain, eVar=0.999, n_clusters=3, n_components=1):
        self.target_domain = target_domain
        self.eVar = eVar
        self.n_clusters = n_clusters
        self.n_components = n_components

        # Fitted state, filled by fit_source / fit_target.
        self.source_means = None
        self.target_means = None
        self.u = None
        self.vh = None
        self.source_pcas = None

        self.incomplete_source = None
        self.incomplete_target = None

        self.name = "TSAxPCA"

    def fit(self, X, y_enc):
        """Fit the rotation from domain-encoded data.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Samples of all domains.
        y_enc : ndarray, shape (n_samples,)
            Labels carrying the domain, as decoded by ``decode_domains``.

        Returns
        -------
        self : TSAxPCA
            The fitted instance, so that ``fit_transform`` and chaining work.
        """
        X, y, domains = decode_domains(X, y_enc)
        X_source = X[domains != self.target_domain]
        y_source = y[domains != self.target_domain]
        X_target = X[domains == self.target_domain]
        y_target = y[domains == self.target_domain]
        # Drop the state of a previous fit so that fit_source does not pair
        # the new source anchors with stale target anchors.
        self.source_means = None
        self.target_means = None
        self.u = None
        self.vh = None
        self.fit_source(X_source, y_source)
        self.fit_target(X_target, y_target)
        return self

    def fit_source(self, X, y):
        """Fit the per-class PCAs and compute the source anchor points.

        The rotation is (re)computed when target anchors are already known.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Source samples.
        y : ndarray, shape (n_samples,)
            Source class labels.

        Returns
        -------
        self : TSAxPCA
        """
        y = LabelEncoder().fit_transform(y=y)
        classes = np.unique(y)
        self.incomplete_source = False
        n_pca = X.shape[1] if self.n_components == "max" else self.n_components
        self.source_pcas = [PCA(n_components=n_pca) for _ in classes]
        means = [X[y == c].mean(axis=0) for c in classes]
        for c in classes:
            Xc = X[y == c]
            if len(Xc) < self.n_clusters:
                self.incomplete_source = True
                break
            pca = self.source_pcas[c]
            pca.fit(Xc)
            # Iterate over the columns actually produced by the PCA, which
            # also covers n_components="max".
            means.extend(self._cluster_means(Xc, pca.transform(Xc).T))
        self.source_means = np.array(means)
        self._update_rotation(len(classes))
        return self

    def fit_target(self, X, y):
        """Compute the target anchor points using the source PCAs.

        The rotation is (re)computed when source anchors are already known.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Target samples.
        y : ndarray, shape (n_samples,)
            Target class labels.

        Returns
        -------
        self : TSAxPCA

        Raises
        ------
        ValueError
            If cluster anchors are needed but ``fit_source`` was never called.
        """
        y = LabelEncoder().fit_transform(y=y)
        classes = np.unique(y)
        self.incomplete_target = False
        means = [X[y == c].mean(axis=0) for c in classes]
        for c in classes:
            Xc = X[y == c]
            if len(Xc) < self.n_clusters:
                self.incomplete_target = True
                break
            if self.incomplete_source:
                # Some source PCAs were never fitted, and only the class
                # means are used by the fallback rotation anyway.
                continue
            if not self.source_pcas:
                raise ValueError(
                    "The source is not fitted yet. Please use 'fit_source' "
                    "on source data before fitting target data."
                )
            scores = self.source_pcas[c].transform(Xc)
            means.extend(self._cluster_means(Xc, scores.T))
        self.target_means = np.array(means)
        self._update_rotation(len(classes))
        return self

    def _cluster_means(self, Xc, score_columns):
        """Return the group means of ``Xc`` along each score vector.

        For each 1-D score vector, samples are sorted by score and split into
        ``n_clusters`` consecutive groups; the mean of each group is returned.
        """
        n = len(Xc)
        bounds = [round(k * n / self.n_clusters)
                  for k in range(self.n_clusters + 1)]
        means = []
        for scores in score_columns:
            order = np.argsort(scores, axis=0)
            for lo, hi in zip(bounds[:-1], bounds[1:]):
                means.append(Xc[order[lo:hi]].mean(axis=0))
        return means

    def _update_rotation(self, n_classes):
        """Recompute ``u`` and ``vh`` once both domains have anchor points."""
        if self.source_means is None or self.target_means is None:
            return
        if self.incomplete_source or self.incomplete_target:
            print("There is not enough point for either target or source. Compute basic TSA rotation...")
            self.u, self.vh = self.compute_incomplete_rotation(n_classes)
        else:
            self.u, self.vh = self.compute_rotations()

    def _rotation_from_means(self, source_means, target_means):
        """Return the truncated SVD factors of ``source_means.T @ target_means``."""
        c12 = np.dot(source_means.T, target_means)
        u, s, vh = np.linalg.svd(c12)
        if self.eVar < 2:
            n_features = int(np.sum(np.cumsum(s) < self.eVar * np.sum(s))) + 1
        else:
            # A float such as 3.0 would otherwise break the slicing below.
            n_features = int(self.eVar)
        n_features = min(n_features, len(s))
        return u[:, :n_features], vh[:n_features, :]

    def compute_rotations(self):
        """Compute ``(u, vh)`` from all source and target anchor points."""
        return self._rotation_from_means(self.source_means, self.target_means)

    def compute_incomplete_rotation(self, n_classes):
        """Compute ``(u, vh)`` from the first ``n_classes`` anchors only.

        Those rows are the class means, which exist even when a class has
        too few samples to build cluster anchors.
        """
        return self._rotation_from_means(
            self.source_means[:n_classes], self.target_means[:n_classes]
        )

    def transform(self, X, y=None):
        """Return ``X @ vh.T @ u.T``, i.e. target samples after rotation."""
        return np.dot(np.dot(X, self.vh.T), self.u.T)

    def transform_source(self, X, y=None):
        """Return ``X @ u @ u.T``, i.e. source samples projected on ``u``."""
        return np.dot(np.dot(X, self.u), self.u.T)

class TSAxPCA_generalized(BaseEstimator, TransformerMixin):
    """Variant of TSAxPCA building anchors from per-domain eigenvectors.

    Anchor points are the class means followed by cluster means. With
    ``equal_dims=False`` each domain sorts its samples along the eigenvectors
    of its own per-class OAS covariance estimate; with ``equal_dims=True``
    both domains sort along the components of PCAs fitted on the source
    classes. The rotation comes from the SVD of
    ``source_means.T @ target_means``.

    Parameters
    ----------
    target_domain : str
        Domain (as returned by ``decode_domains``) used as target. Samples of
        every other domain are pooled as source.
    eVar : float or int, default=0.999
        If lower than 2, the number of singular vectors kept is the number of
        cumulated singular values strictly below ``eVar * sum(s)``, plus one.
        Otherwise ``eVar`` itself is the number of singular vectors kept.
    n_clusters : int, default=3
        Number of groups each class is split into along a direction. A class
        with fewer samples marks its domain as incomplete, and the rotation
        then falls back to the class means only.
    n_components : int or "max", default=1
        Number of directions used per class; "max" uses all of them.
    equal_dims : bool, default=False
        Whether the target reuses the source PCAs instead of its own
        eigenvectors.
    """

    def __init__(self, target_domain, eVar=0.999, n_clusters=3, n_components=1, equal_dims=False):
        self.target_domain = target_domain
        self.eVar = eVar
        self.n_clusters = n_clusters
        self.n_components = n_components
        self.equal_dims = equal_dims

        # Fitted state, filled by fit_source / fit_target.
        self.source_means = None
        self.target_means = None
        self.u = None
        self.vh = None
        self.source_eig = []
        self.source_vals = []
        self.target_eig = []
        self.target_vals = []
        self.source_pcas = []

        self.incomplete_source = None
        self.incomplete_target = None

        self.name = "TSAxPCA_generalized"

    def fit(self, X, y_enc):
        """Fit the rotation from domain-encoded data.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Samples of all domains.
        y_enc : ndarray, shape (n_samples,)
            Labels carrying the domain, as decoded by ``decode_domains``.

        Returns
        -------
        self : TSAxPCA_generalized
            The fitted instance, so that ``fit_transform`` and chaining work.
        """
        X, y, domains = decode_domains(X, y_enc)
        X_source = X[domains != self.target_domain]
        y_source = y[domains != self.target_domain]
        X_target = X[domains == self.target_domain]
        y_target = y[domains == self.target_domain]
        # Drop the state of a previous fit so that fit_source does not pair
        # the new source anchors with stale target anchors.
        self.source_means = None
        self.target_means = None
        self.u = None
        self.vh = None
        self.fit_source(X_source, y_source)
        self.fit_target(X_target, y_target)
        return self

    def fit_source(self, X, y):
        """Compute the source anchor points (and the source PCAs).

        The rotation is (re)computed when target anchors are already known.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Source samples.
        y : ndarray, shape (n_samples,)
            Source class labels.

        Returns
        -------
        self : TSAxPCA_generalized
        """
        classes = np.unique(y)
        self.incomplete_source = False
        # Start from empty lists: indexing by class position below must not
        # hit eigenvectors appended by an earlier fit.
        self.source_vals = []
        self.source_eig = []
        n_pca = X.shape[1] if self.n_components == "max" else self.n_components
        self.source_pcas = [PCA(n_components=n_pca) for _ in classes]
        means = [X[y == yi].mean(axis=0) for yi in classes]
        for i_class, yi in enumerate(classes):
            Xi = X[y == yi]
            vals, eig = self._sorted_eig(Xi)
            self.source_vals.append(vals)
            self.source_eig.append(eig)
            if len(Xi) < self.n_clusters:
                # Flag the *source* as incomplete (the original flagged the
                # target here, and fit_target then reset the flag).
                self.incomplete_source = True
                break
            if not self.equal_dims:
                means.extend(self._eig_cluster_means(Xi, eig))
            else:
                pca = self.source_pcas[i_class]
                pca.fit(Xi)
                means.extend(self._cluster_means(Xi, pca.transform(Xi).T))
        self.source_means = np.array(means)
        self._update_rotation(len(classes))
        return self

    def fit_target(self, X, y):
        """Compute the target anchor points.

        The rotation is (re)computed when source anchors are already known.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Target samples.
        y : ndarray, shape (n_samples,)
            Target class labels.

        Returns
        -------
        self : TSAxPCA_generalized

        Raises
        ------
        ValueError
            If ``equal_dims`` is set and the source was not fitted first.
        """
        classes = np.unique(y)
        self.incomplete_target = False
        self.target_vals = []
        self.target_eig = []
        means = [X[y == yi].mean(axis=0) for yi in classes]
        for i_class, yi in enumerate(classes):
            Xi = X[y == yi]
            vals, eig = self._sorted_eig(Xi)
            self.target_vals.append(vals)
            self.target_eig.append(eig)
            if self.equal_dims and len(self.source_eig) == 0:
                raise ValueError("The source is not fitted yet. Please use the 'fit' function on Source data before fitting target data.")
            if len(Xi) < self.n_clusters:
                self.incomplete_target = True
                break
            if not self.equal_dims:
                means.extend(self._eig_cluster_means(Xi, eig))
            elif not self.incomplete_source:
                # With an incomplete source some PCAs were never fitted, and
                # only the class means are used by the fallback rotation.
                scores = self.source_pcas[i_class].transform(Xi)
                means.extend(self._cluster_means(Xi, scores.T))
        self.target_means = np.array(means)
        self._update_rotation(len(classes))
        return self

    def _sorted_eig(self, Xi):
        """Eigen-decompose the OAS covariance of ``Xi``.

        Returns the eigenvalues and eigenvectors (as columns) ordered by
        decreasing absolute eigenvalue.
        """
        # Covariances expects (n_matrices, n_channels, n_times): the features
        # play the role of channels and the samples the role of time points.
        cov = Covariances(estimator='oas').fit_transform(
            Xi.T.reshape(1, Xi.shape[1], Xi.shape[0])
        )
        vals, eig = np.linalg.eigh(cov[0])
        order = np.argsort(np.abs(vals))[::-1]
        return vals[order], eig[:, order]

    def _eig_cluster_means(self, Xi, eig):
        """Return the cluster means of ``Xi`` along its leading eigenvectors."""
        n_dirs = None if self.n_components == "max" else self.n_components
        projections = np.dot(Xi, eig)
        scores = []
        for i, vect in enumerate(eig.T[:n_dirs]):
            # Eigenvectors are defined up to sign; orient each one by the sign
            # of the sum of its coordinates so both domains sort the same way.
            sign = np.sign(np.dot(vect, np.ones(len(vect))))
            scores.append(sign * projections[:, i])
        return self._cluster_means(Xi, scores)

    def _cluster_means(self, Xi, score_columns):
        """Return the group means of ``Xi`` along each score vector.

        For each 1-D score vector, samples are sorted by score and split into
        ``n_clusters`` consecutive groups; the mean of each group is returned.
        """
        ni = len(Xi)
        bounds = [round(k * ni / self.n_clusters)
                  for k in range(self.n_clusters + 1)]
        means = []
        for scores in score_columns:
            order = np.argsort(scores, axis=0)
            for lo, hi in zip(bounds[:-1], bounds[1:]):
                means.append(Xi[order[lo:hi]].mean(axis=0))
        return means

    def _update_rotation(self, n_classes):
        """Recompute ``u`` and ``vh`` once both domains have anchor points."""
        if self.source_means is None or self.target_means is None:
            return
        if self.incomplete_source or self.incomplete_target:
            print("There is not enough point for either target or source. Compute basic TSA rotation.")
            self.u, self.vh = self.compute_incomplete_rotation(n_classes)
        else:
            self.u, self.vh = self.compute_rotations()

    def _rotation_from_means(self, source_means, target_means):
        """Return the truncated SVD factors of ``source_means.T @ target_means``."""
        c12 = np.dot(source_means.T, target_means)
        u, s, vh = np.linalg.svd(c12)
        if self.eVar < 2:
            n_features = int(np.sum(np.cumsum(s) < self.eVar * np.sum(s))) + 1
        else:
            # A float such as 3.0 would otherwise break the slicing below.
            n_features = int(self.eVar)
        n_features = min(n_features, len(u), len(vh))
        return u[:, :n_features], vh[:n_features, :]

    def compute_rotations(self):
        """Compute ``(u, vh)`` from all source and target anchor points."""
        return self._rotation_from_means(self.source_means, self.target_means)

    def compute_incomplete_rotation(self, n_classes):
        """Compute ``(u, vh)`` from the first ``n_classes`` anchors only.

        Those rows are the class means, which exist even when a class has
        too few samples to build cluster anchors.
        """
        return self._rotation_from_means(
            self.source_means[:n_classes], self.target_means[:n_classes]
        )

    def transform(self, X, y=None):
        """Return ``X @ vh.T @ u.T``, i.e. target samples after rotation."""
        return np.dot(np.dot(X, self.vh.T), self.u.T)

    def transform_source(self, X, y=None):
        """Return ``X @ u @ u.T``, i.e. source samples projected on ``u``."""
        return np.dot(np.dot(X, self.u), self.u.T)

def rot_data(X, y_enc, **kwargs):
    """Rotate the target samples onto the source samples with TSAxPCA.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Samples of all domains.
    y_enc : ndarray, shape (n_samples,)
        Labels carrying the domain, as decoded by ``decode_domains``.
    **kwargs
        Forwarded to ``TSAxPCA``; must contain ``target_domain``.

    Returns
    -------
    Xt_rot : ndarray
        Rotated target samples, as returned by ``TSAxPCA.transform``.
    """
    # The estimator owns target_domain: this is a plain function, so there is
    # no ``self`` to read it from.
    tsa = TSAxPCA(**kwargs)
    X, y, domains = decode_domains(X, y_enc)
    Xs = X[domains != tsa.target_domain]
    ys = y[domains != tsa.target_domain]
    Xt = X[domains == tsa.target_domain]
    yt = y[domains == tsa.target_domain]
    # The data are already decoded and split, so fit each domain directly
    # instead of calling ``fit``, which expects domain-encoded labels.
    tsa.fit_source(Xs, ys)
    tsa.fit_target(Xt, yt)
    # ``transform`` returns a single array, not a 3-tuple.
    return tsa.transform(Xt)

def align_data(X, y_enc, **kwargs):
    """Recenter, rotate and rescale the target samples onto the source.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Samples of all domains.
    y_enc : ndarray, shape (n_samples,)
        Labels carrying the domain, as decoded by ``decode_domains``.
    **kwargs
        Forwarded to ``TSAxPCA``; must contain ``target_domain``.

    Returns
    -------
    Xt_sca : ndarray
        Target samples after recentering on the source mean, rotation, and
        rescaling so that their mean norm equals the source mean norm.
    """
    # The estimator owns target_domain: this is a plain function, so there is
    # no ``self`` to read it from.
    tsa = TSAxPCA(**kwargs)
    X, y, domains = decode_domains(X, y_enc)
    Xs = X[domains != tsa.target_domain]
    ys = y[domains != tsa.target_domain]
    Xt = X[domains == tsa.target_domain]
    yt = y[domains == tsa.target_domain]

    # Recenter: move the target mean onto the source mean.
    Xt_rec = Xt - np.mean(Xt, axis=0) + np.mean(Xs, axis=0)

    # Rotate: fit on the decoded source and the recentered target directly.
    tsa.fit_source(Xs, ys)
    tsa.fit_target(Xt_rec, yt)
    Xt_rot = tsa.transform(Xt_rec)

    # Rescale: match the mean sample norm of the source.
    target_scale = np.mean(np.linalg.norm(Xt_rot, axis=1), axis=0)
    source_scale = np.mean(np.linalg.norm(Xs, axis=1), axis=0)
    Xt_sca = Xt_rot / target_scale * source_scale
    return Xt_sca

def align_test(Xs, ys, Xt_train, yt_train, Xt_test, yt_test=None, **kwargs):
    """Align held-out target samples using a fit on target training samples.

    Parameters
    ----------
    Xs : ndarray, shape (n_source, n_features)
        Source samples.
    ys : ndarray, shape (n_source,)
        Source class labels (not domain-encoded).
    Xt_train : ndarray, shape (n_train, n_features)
        Target samples used to estimate the recentering, rotation and scale.
    yt_train : ndarray, shape (n_train,)
        Class labels of ``Xt_train``.
    Xt_test : ndarray, shape (n_test, n_features)
        Target samples to align.
    yt_test : None
        Unused; kept for signature compatibility.
    **kwargs
        Forwarded to ``TSAxPCA``; must contain ``target_domain`` because the
        constructor requires it, although it is not used here.

    Returns
    -------
    Xt_sca_test : ndarray
        ``Xt_test`` recentered, rotated and rescaled with statistics
        estimated on ``Xt_train`` and ``Xs`` only.
    """
    # Recenter train and test with the *training* target mean so that no
    # statistic is estimated on the test samples.
    shift = np.mean(Xs, axis=0) - np.mean(Xt_train, axis=0)
    Xt_rec_train = Xt_train + shift
    Xt_rec_test = Xt_test + shift

    tsa = TSAxPCA(**kwargs)
    # Labels are plain class labels here, so fit each domain directly instead
    # of calling ``fit``, which expects domain-encoded labels.
    tsa.fit_source(Xs, ys)
    tsa.fit_target(Xt_rec_train, yt_train)
    # ``transform`` returns a single array, not a 3-tuple.
    Xt_rot_train = tsa.transform(Xt_rec_train)
    Xt_rot_test = tsa.transform(Xt_rec_test)

    train_scale = np.mean(np.linalg.norm(Xt_rot_train, axis=1), axis=0)
    source_scale = np.mean(np.linalg.norm(Xs, axis=1), axis=0)
    Xt_sca_test = Xt_rot_test / train_scale * source_scale
    return Xt_sca_test
qbarthelemy commented 2 months ago

Example

import matplotlib.pyplot as plt
import numpy as np
from pyriemann.tangentspace import tangent_space, untangent_space
from pyriemann.utils.mean import mean_riemann, mean_logeuclid
from pyriemann.estimation import Covariances
from pyriemann.transfer import encode_domains

from _tsa import TSAxPCA

# ---------------------------------------------------------------------------
# Settings of the toy problem
# ---------------------------------------------------------------------------
n_classes = 2
n_clusters = 500
n_dims = 2

source_label = "source"
target_label = "target"

n_trials_source = n_classes * n_clusters
n_trials_target = n_trials_source
fake_mean = 0.5
fake_std = 0.5

# Rotation applied to the source data to create the target data.
theta = np.pi / 5
cos_theta, sin_theta = np.cos(theta), np.sin(theta)
fake_rot = np.array([[cos_theta, -sin_theta],
                     [sin_theta, cos_theta]])

# ---------------------------------------------------------------------------
# Synthetic source data: one elongated, noisy segment per class
# ---------------------------------------------------------------------------
off_centre = -1 / np.sqrt(n_classes * (n_classes - 1))
on_centre = np.sqrt((n_classes - 1) / n_classes)
padding = [0] * (n_dims - n_classes)

# One centre per class, scaled by 5.
Xs_means = 5 * np.array([
    [off_centre] * c + [on_centre] + [off_centre] * (n_classes - c - 1) + padding
    for c in range(n_classes)
])

# Positions evenly spread in [-1, 1] along the segment of each class.
spread = [(2 * i - (n_clusters - 1)) / (n_clusters - 1) for i in range(n_clusters)]
axis_class0 = np.array([0, 5] + padding)
axis_class1 = np.array([5, 0] + padding)
Xs_ts_list = [Xs_means[0] + w * axis_class0 for w in spread]
Xs_ts_list += [Xs_means[1] + w * axis_class1 for w in spread]

ys = np.array([0] * n_clusters + [1] * n_clusters)
Xs_ts = np.array(Xs_ts_list) + np.random.randn(n_clusters * n_classes, n_dims)

# Centre the source, then normalise it to a unit mean norm.
mean_source = np.mean(Xs_ts, axis=0)
Xs_ts = Xs_ts - mean_source

std_source = np.mean(np.linalg.norm(Xs_ts, axis=1), axis=0)
Xs_ts = Xs_ts / std_source

# ---------------------------------------------------------------------------
# Target data: the source data, rotated by ``fake_rot``
# ---------------------------------------------------------------------------
yt = ys
Xt_ts = np.dot(Xs_ts, fake_rot)

X = np.append(Xs_ts, Xt_ts, axis=0)
y = np.append(ys, yt)
yDom = np.append(np.repeat(source_label, n_trials_source),
                 np.repeat(target_label, n_trials_target))

X_enc, y_enc = encode_domains(X, y, yDom)

# ---------------------------------------------------------------------------
# Fit the alignment and apply it to the target samples
# ---------------------------------------------------------------------------
tsa = TSAxPCA(target_label, n_clusters=5, eVar=1.0)

tsa.fit(X_enc, y_enc=y_enc)
Xt_ts_aligned = tsa.transform(X[yDom == target_label])

# ---------------------------------------------------------------------------
# One scatter plot per data set, all with the same axis limits
# ---------------------------------------------------------------------------
for title, points in (
    ("Source", Xs_ts),
    ("Target", Xt_ts),
    ("Target aligned", Xt_ts_aligned),
):
    plt.figure()
    plt.title(title)
    plt.scatter(points[:, 0], points[:, 1])
    plt.xlim([-2, 2])
    plt.ylim([-2, 2])

plt.show()
agramfort commented 2 months ago

cc @apmellot @antoinecollas