qbarthelemy opened 2 months ago
Code shared by the first author:
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.decomposition import PCA
from pyriemann.estimation import Covariances
from pyriemann.transfer import decode_domains
from sklearn.preprocessing import LabelEncoder
class TSAxPCA(BaseEstimator, TransformerMixin):
def __init__(self, target_domain, eVar=0.999, n_clusters=3, n_components=1):
self.target_domain = target_domain
self.eVar = eVar
self.source_means = None
self.target_means = None
self.u = None
self.vh = None
self.incomplete_source = None
self.incomplete_target = None
self.n_clusters = n_clusters
self.n_components = n_components
self.name = "TSAxPCA"
def fit(self, X, y_enc):
X, y, domains = decode_domains(X, y_enc)
X_source = X[domains != self.target_domain]
y_source = y[domains != self.target_domain]
X_target = X[domains == self.target_domain]
y_target = y[domains == self.target_domain]
self.fit_source(X_source, y_source)
        self.fit_target(X_target, y_target)
        return self
def fit_source(self, X, y):
y = LabelEncoder().fit_transform(y=y)
self.incomplete_source = False
self.source_pcas = [PCA(n_components=(self.n_components if self.n_components != "max" else X.shape[1])) for i in range(len(np.unique(y)))]
self.source_means = np.array([np.ndarray.mean(X[y == i, :], axis=0) for i in np.unique(y)])
for i in np.unique(y):
Xi = X[y == i]
ni = len(Xi)
if ni < self.n_clusters:
self.incomplete_source = True
break
self.source_pcas[i].fit(Xi)
X_pca = self.source_pcas[i].transform(Xi)
for j in range(self.n_components):
args = np.argsort(X_pca[:, j], axis=0)
self.source_means = np.vstack([self.source_means,
[np.ndarray.mean(Xi[args[round(k*ni/self.n_clusters):round((k+1)*ni/self.n_clusters)]], axis=0) for k in range(self.n_clusters)]])
if self.target_means is not None:
if self.incomplete_source or self.incomplete_target:
print("There is not enough point for either target or source. Compute basic TSA rotation...")
self.u, self.vh = self.compute_incomplete_rotation(len(np.unique(y)))
else:
self.u, self.vh = self.compute_rotations()
return self
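    # SVD of the cross-product of source and target anchors; the rotation keeps the leading
    # components covering a fraction eVar of the singular values (or exactly eVar components
    # when eVar >= 2).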
def compute_rotations(self):
c12 = np.dot(self.source_means.T, self.target_means)
u, s, vh = np.linalg.svd(c12)
if self.eVar < 2:
n_features = np.ndarray.sum(np.ndarray.cumsum(s) < self.eVar * np.ndarray.sum(s)) + 1
else:
n_features = self.eVar
u = u[:, :n_features]
vh = vh[:n_features, :]
return u, vh
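    # Fallback when a class has fewer points than n_clusters: use only the per-class means
    # (the first n_classes anchors) to estimate the rotation.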
def compute_incomplete_rotation(self, n_classes):
c12 = np.dot(self.source_means[:n_classes].T, self.target_means[:n_classes])
u, s, vh = np.linalg.svd(c12)
if self.eVar < 2:
n_features = np.ndarray.sum(np.ndarray.cumsum(s) < self.eVar * np.ndarray.sum(s)) + 1
else:
n_features = self.eVar
u = u[:, :n_features]
vh = vh[:n_features, :]
return u, vh
def fit_target(self, X, y):
y = LabelEncoder().fit_transform(y=y)
self.incomplete_target = False
self.target_means = np.array([np.ndarray.mean(X[y == i, :], axis=0) for i in np.unique(y)])
for i in np.unique(y):
Xi = X[y == i]
ni = len(Xi)
if ni < self.n_clusters:
self.incomplete_target = True
break
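            # Project the target class on the PCA fitted on the corresponding source class,
            # so that clusters are ordered consistently across domains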
X_pca = self.source_pcas[i].transform(Xi)
for j in range(self.n_components):
args = np.argsort(X_pca[:, j], axis=0)
self.target_means = np.vstack([self.target_means,
[np.ndarray.mean(Xi[args[round(k*ni/self.n_clusters):round((k+1)*ni/self.n_clusters)]], axis=0) for k in range(self.n_clusters)]])
if self.source_means is not None:
if self.incomplete_source or self.incomplete_target:
print("There is not enough point for either target or source. Compute basic TSA rotation...")
self.u, self.vh = self.compute_incomplete_rotation(len(np.unique(y)))
else:
self.u, self.vh = self.compute_rotations()
return self
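    # transform maps target-domain vectors into the source space; transform_source projects
    # source-domain vectors onto the retained subspace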
def transform(self, X, y=None):
return np.dot(np.dot(X, self.vh.T), self.u.T)
def transform_source(self, X, y=None):
return np.dot(np.dot(X, self.u), self.u.T)
class TSAxPCA_generalized(BaseEstimator, TransformerMixin):
def __init__(self, target_domain, eVar=0.999, n_clusters=3, n_components=1, equal_dims=False):
self.target_domain = target_domain
self.eVar = eVar
self.n_clusters = n_clusters
self.n_components = n_components
self.equal_dims = equal_dims
self.source_means = None
self.target_means = None
self.u = None
self.vh = None
self.source_eig = []
self.source_vals = []
self.target_eig = []
self.target_vals = []
self.source_pcas = []
self.incomplete_source = None
self.incomplete_target = None
self.name = "TSAxPCA_generalized"
def fit(self, X, y_enc):
X, y, domains = decode_domains(X, y_enc)
X_source = X[domains != self.target_domain]
y_source = y[domains != self.target_domain]
X_target = X[domains == self.target_domain]
y_target = y[domains == self.target_domain]
self.fit_source(X_source, y_source)
        self.fit_target(X_target, y_target)
        return self
def fit_source(self, X, y):
self.incomplete_source = False
self.source_pcas = [PCA(n_components=(self.n_components if self.n_components != "max" else X.shape[1])) for i in range(len(np.unique(y)))]
self.source_means = np.array([np.ndarray.mean(X[y == yi, :], axis=0) for yi in np.unique(y)])
for i_class, yi in enumerate(np.unique(y)):
Xi = X[y == yi]
ni = len(Xi)
cov = Covariances(estimator='oas').fit_transform(Xi.T.reshape(1, Xi.shape[1], Xi.shape[0]))
vals, eig = np.linalg.eigh(cov[0])
arg = np.argsort(np.abs(vals))[::-1]
self.source_vals.append(vals[arg])
self.source_eig.append(eig[:, arg])
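            # When domains may differ in dimension (equal_dims=False), order samples along this
            # domain's own covariance eigenvectors, with a sign convention so the ordering is
            # comparable between source and target; otherwise reuse the PCA fitted on the source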
if not self.equal_dims:
Xi_pca = np.dot(Xi, self.source_eig[i_class])
for i, vect in enumerate(self.source_eig[i_class].T[:self.n_components]):
sign = np.sign(np.dot(vect, np.ones(len(vect))))
args = np.argsort(sign*Xi_pca[:, i], axis=0)
self.source_means = np.vstack([self.source_means,
[np.ndarray.mean(Xi[args[round(k * ni / self.n_clusters):round(
(k + 1) * ni / self.n_clusters)]], axis=0) for k in
range(self.n_clusters)]])
else:
Xi = X[y == yi]
ni = len(Xi)
if ni < self.n_clusters:
                    self.incomplete_source = True
break
self.source_pcas[i_class].fit(Xi)
X_pca = self.source_pcas[i_class].transform(Xi)
for j in range(self.n_components):
args = np.argsort(X_pca[:, j], axis=0)
self.source_means = np.vstack([self.source_means,
[np.ndarray.mean(Xi[args[round(k * ni / self.n_clusters):round(
(k + 1) * ni / self.n_clusters)]], axis=0) for k in
range(self.n_clusters)]])
if self.target_means is not None:
self.u, self.vh = self.compute_rotations()
return self
def compute_rotations(self):
c12 = np.dot(self.source_means.T, self.target_means)
u, s, vh = np.linalg.svd(c12)
if self.eVar < 2:
n_features = np.min([np.ndarray.sum(np.ndarray.cumsum(s) < self.eVar * np.ndarray.sum(s))+1, len(u), len(vh)])
else:
n_features = np.min([self.eVar, len(u), len(vh)])
u = u[:, :n_features]
vh = vh[:n_features, :]
return u, vh
def compute_incomplete_rotation(self, n_classes):
c12 = np.dot(self.source_means[:n_classes].T, self.target_means[:n_classes])
u, s, vh = np.linalg.svd(c12)
if self.eVar < 2:
n_features = np.min([np.ndarray.sum(np.ndarray.cumsum(s) < self.eVar * np.ndarray.sum(s))+1, len(u), len(vh)])
else:
n_features = self.eVar
u = u[:, :n_features]
vh = vh[:n_features, :]
return u, vh
def fit_target(self, X, y):
self.incomplete_target = False
self.target_means = np.array([np.ndarray.mean(X[y == i, :], axis=0) for i in np.unique(y)])
for i_class, yi in enumerate(np.unique(y)):
Xi = X[y == yi]
ni = len(Xi)
cov = Covariances(estimator='oas').fit_transform(Xi.T.reshape(1, Xi.shape[1], Xi.shape[0]))
vals, eig = np.linalg.eigh(cov[0])
arg = np.argsort(np.abs(vals))[::-1]
self.target_vals.append(vals[arg])
self.target_eig.append(eig[:, arg])
if not self.equal_dims:
Xi_pca = np.dot(Xi, self.target_eig[i_class])
for i, vect in enumerate(self.target_eig[i_class].T[:self.n_components]):
sign = np.sign(np.dot(vect, np.ones(len(vect))))
args = np.argsort(sign * Xi_pca[:, i], axis=0)
self.target_means = np.vstack([self.target_means,
[np.ndarray.mean(Xi[args[round(k * ni / self.n_clusters):round(
(k + 1) * ni / self.n_clusters)]], axis=0) for k in
range(self.n_clusters)]])
elif len(self.source_eig) == 0:
                raise ValueError("The source is not fitted yet. Please fit the source data before fitting the target data.")
else:
Xi = X[y == yi]
ni = len(Xi)
if ni < self.n_clusters:
self.incomplete_target = True
break
X_pca = self.source_pcas[i_class].transform(Xi)
for j in range(self.n_components):
args = np.argsort(X_pca[:, j], axis=0)
self.target_means = np.vstack([self.target_means,
[np.ndarray.mean(Xi[args[round(k * ni / self.n_clusters):round(
(k + 1) * ni / self.n_clusters)]], axis=0) for k in
range(self.n_clusters)]])
if self.source_means is not None:
if self.incomplete_source or self.incomplete_target:
print("There is not enough point for either target or source. Compute basic TSA rotation.")
self.u, self.vh = self.compute_incomplete_rotation(len(np.unique(y)))
else:
self.u, self.vh = self.compute_rotations()
return self
def transform(self, X, y=None):
return np.dot(np.dot(X, self.vh.T), self.u.T)
def transform_source(self, X, y=None):
return np.dot(np.dot(X, self.u), self.u.T)
def rot_data(Xs, ys, Xt, yt, **kwargs):
    # Fit the source set, then the target set, and return the target data rotated into the source space
    tsa = TSAxPCA(**kwargs)
    tsa.fit_source(Xs, ys)
    tsa.fit_target(Xt, yt)
    Xt_rot = tsa.transform(Xt)
    return Xt_rot
def align_data(X, y_enc, **kwargs):
    # Split domains, re-center the target on the source mean, rotate it with TSA,
    # then rescale it to the source dispersion
    X, y, domains = decode_domains(X, y_enc)
    target_domain = kwargs["target_domain"]
    Xs = X[domains != target_domain]
    ys = y[domains != target_domain]
    Xt = X[domains == target_domain]
    yt = y[domains == target_domain]
    Xt_rec = Xt - np.mean(Xt, axis=0) + np.mean(Xs, axis=0)
    Xt_rot = rot_data(Xs, ys, Xt_rec, yt, **kwargs)
    Xt_sca = Xt_rot / np.mean(np.linalg.norm(Xt_rot, axis=1), axis=0) * np.mean(np.linalg.norm(Xs, axis=1), axis=0)
    return Xt_sca
def align_test(Xs, ys, Xt_train, yt_train, Xt_test, yt_test=None, **kwargs):
    # Re-center train and test target data on the source mean, fit TSA on the training split,
    # then rotate and rescale the test data to the source dispersion
    Xt_rec_train = Xt_train - np.mean(Xt_train, axis=0) + np.mean(Xs, axis=0)
    Xt_rec_test = Xt_test - np.mean(Xt_train, axis=0) + np.mean(Xs, axis=0)
    tsa = TSAxPCA(**kwargs)
    tsa.fit_source(Xs, ys)
    tsa.fit_target(Xt_rec_train, yt_train)
    Xt_rot_train = tsa.transform(Xt_rec_train)
    Xt_rot_test = tsa.transform(Xt_rec_test)
    Xt_sca_test = Xt_rot_test / np.mean(np.linalg.norm(Xt_rot_train, axis=1), axis=0) * np.mean(np.linalg.norm(Xs, axis=1), axis=0)
    return Xt_sca_test
Example
import matplotlib.pyplot as plt
import numpy as np
from pyriemann.tangentspace import tangent_space, untangent_space
from pyriemann.utils.mean import mean_riemann, mean_logeuclid
from pyriemann.estimation import Covariances
from pyriemann.transfer import encode_domains
from _tsa import TSAxPCA
n_classes = 2
n_clusters = 500
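# note: this n_clusters is the number of trials per class of the synthetic data,
# not the n_clusters parameter of TSAxPCA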
n_dims = 2
source_label = "source"
target_label = "target"
n_trials_source = n_classes*n_clusters
n_trials_target = n_trials_source
fake_mean = 0.5
fake_std = 0.5
theta = np.pi/5
fake_rot = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
############# ADDED FALSE SOURCE ##################
Xs_means = 5 * np.array([[-1 / np.sqrt(n_classes * (n_classes - 1)) for _ in range(0, i)] +
[np.sqrt((n_classes - 1) / n_classes)] +
[-1 / np.sqrt(n_classes * (n_classes - 1)) for _ in range(i + 1, n_classes)] +
[0 for _ in range(n_dims - n_classes)] for i in range(n_classes)])
Xs_ts_list = [Xs_means[0] + (-(n_clusters - 1) + 2 * i) / (n_clusters - 1) * np.array(
[0, 5] + [0 for _ in range(n_dims - n_classes)]) for i in
range(n_clusters)] + \
[Xs_means[1] + (-(n_clusters - 1) + 2 * i) / (n_clusters - 1) * np.array(
[5, 0] + [0 for _ in range(n_dims - n_classes)]) for i in
range(n_clusters)]
ys = np.array([0 for _ in range(n_clusters)] + [1 for _ in range(n_clusters)])
Xs_ts = np.array(Xs_ts_list)
Xs_ts = Xs_ts + np.random.randn(n_clusters*n_classes, n_dims)
mean_source = np.mean(Xs_ts, axis=0)
Xs_ts = Xs_ts - mean_source
std_source = np.mean(np.linalg.norm(Xs_ts, axis=1), axis=0)
Xs_ts = Xs_ts / std_source
###### Create the target as a modification of the source data #############
yt = ys
# Xt_ts = np.multiply(1+5*np.array([[np.cos((i+1)*np.pi/2), np.sin((i+1)*np.pi/2), 0] for i in yt]), (np.random.normal(scale=0.15, size=(n_trials_target, 3)) + 0.4 * np.array([[i - 0.5, i - 0.5, 0] for i in yt])))
Xt_ts = Xs_ts
Xt_ts = np.dot(Xt_ts, fake_rot)
X = np.append(Xs_ts, Xt_ts, axis=0)
y = np.append(ys, yt)
yDom = np.append(np.repeat(source_label, n_trials_source), np.repeat(target_label, n_trials_target))
X_enc, y_enc = encode_domains(X, y, yDom)
tsa = TSAxPCA(target_label, n_clusters=5, eVar=1.0)
tsa.fit(X_enc, y_enc=y_enc)
Xt_ts_aligned = tsa.transform(X[yDom == target_label])
plt.figure()
plt.title("Source")
plt.scatter(Xs_ts[:, 0], Xs_ts[:, 1])
plt.xlim([-2, 2])
plt.ylim([-2, 2])
plt.figure()
plt.title("Target")
plt.scatter(Xt_ts[:, 0], Xt_ts[:, 1])
plt.xlim([-2, 2])
plt.ylim([-2, 2])
plt.figure()
plt.title("Target aligned")
plt.scatter(Xt_ts_aligned[:, 0], Xt_ts_aligned[:, 1])
plt.xlim([-2, 2])
plt.ylim([-2, 2])
plt.show()
cc @apmellot @antoinecollas
In the transfer learning module, add Tangent Space Alignment (TSA), see https://www.frontiersin.org/articles/10.3389/fnhum.2022.1049985/pdf
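For the integration, here is a minimal sketch of how the TSAxPCA transformer shared above could slot into a pyriemann tangent-space workflow (Covariances → TangentSpace → domain encoding → TSA → classifier). The random signals, the "source"/"target" domain names and the LogisticRegression classifier are only illustrative placeholders, not part of the proposed API.

import numpy as np
from sklearn.linear_model import LogisticRegression
from pyriemann.estimation import Covariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.transfer import encode_domains

rng = np.random.default_rng(42)
n_trials, n_channels, n_times = 40, 4, 128

# placeholder source and target sessions
X_source = rng.standard_normal((n_trials, n_channels, n_times))
X_target = 1.5 * rng.standard_normal((n_trials, n_channels, n_times))
y_source = np.tile([0, 1], n_trials // 2)
y_target = np.tile([0, 1], n_trials // 2)

# covariance estimation and projection to the tangent space
covs = Covariances(estimator="oas").fit_transform(np.concatenate([X_source, X_target]))
X_ts = TangentSpace(metric="riemann").fit_transform(covs)

# encode domains and fit the alignment on the tangent vectors
y = np.concatenate([y_source, y_target])
domains = np.array(["source"] * n_trials + ["target"] * n_trials)
X_enc, y_enc = encode_domains(X_ts, y, domains)
tsa = TSAxPCA(target_domain="target", n_clusters=3)
tsa.fit(X_enc, y_enc)

# train on source tangent vectors, evaluate on aligned target tangent vectors
clf = LogisticRegression().fit(X_ts[domains == "source"], y_source)
X_target_aligned = tsa.transform(X_ts[domains == "target"])
print("target accuracy after alignment:", clf.score(X_target_aligned, y_target))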