GUDHI / TDA-tutorial

A set of Jupyter notebooks for practicing TDA with the Python Gudhi library together with popular machine learning and data science libraries.
MIT License

Persistence diagram for image data / embedding data #58

Open dgm2 opened 2 years ago

dgm2 commented 2 years ago

Hello, thanks for maintaining this repo. I have two questions about processing image datasets (e.g. torchvision MNIST).

1) Is there an example of computing the persistence diagram of an image?
2) Is it possible to get a persistence diagram after the image has been passed through a convolution (or a linear layer)? That is, instead of a tensor of shape (xxx, 2), we only have a single one-dimensional embedding.

Example:

input X1 : 1x28x28
emb1 = conv(X1) : 1x512
diag = RipsComplex(emb1)

I used RipsComplex here, but any object that computes persistence would be fine.

Many thanks!
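Question 2 leaves open what the diagram of a single 1x512 embedding should be. One possible reading, sketched below under that assumption, treats the 512 entries as a function on a path (entry i adjacent to entries i-1 and i+1) and computes 0-dimensional sublevel-set persistence with union-find and the elder rule. The helper name `sublevel_persistence_1d` is hypothetical, and the result depends on the arbitrary order of the embedding coordinates, so it may not be meaningful for a learned embedding. The other common reading is to stack many embeddings into a point cloud in R^512 and build a `RipsComplex` on that. Passing the same 1D array to `gd.CubicalComplex(top_dimensional_cells=...)` is expected to give the same dimension-0 intervals, but that is not verified here.

```python
import math


def sublevel_persistence_1d(values):
    """0-dimensional sublevel-set persistence of a 1D signal (sketch).

    Entry i is treated as adjacent to entries i - 1 and i + 1 (a path).
    A component is born at its minimum value; when two components meet,
    the younger one (larger birth value) dies at the current value
    (elder rule). Zero-length intervals are dropped, and the component
    that survives to the end gets death = math.inf.
    """
    n = len(values)
    parent = list(range(n))
    birth = list(values)  # birth value of the component rooted at each index
    added = [False] * n

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    pairs = []
    # Sweep the entries by increasing value (ties broken by position).
    for i in sorted(range(n), key=lambda k: (values[k], k)):
        added[i] = True
        for j in (i - 1, i + 1):
            if not (0 <= j < n and added[j]):
                continue
            ri, rj = find(i), find(j)
            if ri == rj:
                continue
            young, old = (ri, rj) if birth[ri] > birth[rj] else (rj, ri)
            if values[i] > birth[young]:
                pairs.append((birth[young], values[i]))
            parent[young] = old
    if n:
        pairs.append((min(values), math.inf))  # the essential component
    return pairs


if __name__ == "__main__":
    emb1 = [3, 1, 4, 1, 5, 0, 2]  # hypothetical stand-in for a 1x512 embedding
    print(sublevel_persistence_1d(emb1))  # [(1, 4), (1, 5), (0, inf)]
```

For an actual PyTorch embedding, flattening it first with something like `emb1.detach().cpu().numpy().ravel().tolist()` would give the plain list this sketch expects.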

VincentRouvreau commented 2 years ago

@dgm2 Why not use a cubical complex?

from sklearn.datasets import fetch_openml
import matplotlib.pyplot as plt
import gudhi as gd

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

# X[17] is an '8'
cc = gd.CubicalComplex(top_dimensional_cells=X[17], dimensions=[28, 28])
diag = cc.persistence()
gd.plot_persistence_diagram(diag, legend=True)
plt.show()

[Figure 1: persistence diagram of X[17], an '8']

# X[1] is a '0'
cc = gd.CubicalComplex(top_dimensional_cells=X[1], dimensions=[28, 28])
diag = cc.persistence()
gd.plot_persistence_diagram(diag, legend=True)
plt.show()

[Figure 2: persistence diagram of X[1], a '0']

Another example is available in this tutorial on cubical complexes.

dgm2 commented 2 years ago

Sounds good, thanks!

What would be the best way to replicate something like the following Dionysus code with GUDHI?

1) Get the persistence diagram, e.g. d1:

from dionysus import *
# image: a 2D numpy array of pixel values (e.g. one 28x28 MNIST digit)
filtered = fill_freudenthal(image)
persistence = homology_persistence(filtered)
diag = init_diagrams(persistence, filtered)

2) Compute the distance between two diagrams, e.g. di.wasserstein_distance(d1, d2). I have tried with CubicalComplex, but I seem to get all zeros for the Wasserstein distance.

References: fill_freudenthal, homology_persistence

Many thanks!

import itertools
import numpy as np
from torchvision import datasets

from gudhi.cubical_complex import CubicalComplex
from gudhi.wasserstein import wasserstein_distance

def pers_diag(pts):
    pers = CubicalComplex(top_dimensional_cells=pts, dimensions=[28, 28]).persistence()
    res = np.array([list(b) for (_, b) in pers])
    return res

dataset2 = datasets.MNIST('../data', train=False, download=True)

diagrams = []
labels = []
n = 10
for dat, lab in zip(dataset2.data[:n], dataset2.train_labels[:n]):
    pts = dat.data.numpy().reshape(-1)
    diagrams.append(pers_diag(pts))
    labels.append(lab.item())

def print_wd(i, j):
    print(labels[i], labels[j], wasserstein_distance(diagrams[i], diagrams[j]))

for i, j in itertools.combinations(range(n), 2):
    print_wd(i, j)

Output:

labels 7 2 | wass 0.0 | bott 24.5
labels 7 1 | wass 0.0 | bott 52.0
labels 7 0 | wass 0.0 | bott 125.5
labels 7 4 | wass 0.0 | bott 18.0
labels 7 1 | wass 0.0 | bott 24.5
labels 7 4 | wass 0.0 | bott 32.0
labels 7 9 | wass 0.0 | bott 92.5
labels 7 5 | wass 0.0 | bott 126.5
labels 7 9 | wass 0.0 | bott 114.5
labels 2 1 | wass 0.0 | bott 52.0
labels 2 0 | wass 0.0 | bott 125.5
labels 2 4 | wass 0.0 | bott 26.5
labels 2 1 | wass 0.0 | bott 3.5
labels 2 4 | wass 0.0 | bott 40.0
labels 2 9 | wass 0.0 | bott 92.5
labels 2 5 | wass 0.0 | bott 126.5
labels 2 9 | wass 0.0 | bott 114.5
labels 1 0 | wass 0.0 | bott 125.5
labels 1 4 | wass 0.0 | bott 45.0
labels 1 1 | wass 0.0 | bott 52.0
labels 1 4 | wass 0.0 | bott 25.0
labels 1 9 | wass 0.0 | bott 92.5
labels 1 5 | wass 0.0 | bott 126.5
labels 1 9 | wass 0.0 | bott 114.5
labels 0 4 | wass 0.0 | bott 125.5
labels 0 1 | wass 0.0 | bott 125.5
labels 0 4 | wass 0.0 | bott 125.5
labels 0 9 | wass 0.0 | bott 66.0
labels 0 5 | wass 0.0 | bott 100.5
labels 0 9 | wass 0.0 | bott 38.0
labels 4 1 | wass 0.0 | bott 26.5
labels 4 4 | wass 0.0 | bott 25.0
labels 4 9 | wass 0.0 | bott 92.5
labels 4 5 | wass 0.0 | bott 126.5
labels 4 9 | wass 0.0 | bott 114.5
labels 1 4 | wass nan | bott 40.0
labels 1 9 | wass 0.0 | bott 92.5
labels 1 5 | wass 0.0 | bott 126.5
labels 1 9 | wass 0.0 | bott 114.5
labels 4 9 | wass 0.0 | bott 92.5
labels 4 5 | wass 0.0 | bott 126.5
labels 4 9 | wass 0.0 | bott 114.5
labels 9 5 | wass 0.0 | bott 100.5
labels 9 9 | wass 0.0 | bott 44.0
labels 5 9 | wass 0.0 | bott 100.5
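One thing worth noting about the `pers_diag` helper above: it keeps only `(birth, death)` and drops the homology dimension, so points from dimensions 0 and 1 end up in the same array, together with essential points whose death is `inf`. A minimal sketch for grouping the `(dimension, (birth, death))` pairs that `persistence()` returns (the format the helper already unpacks) into one diagram per dimension is shown below; `split_by_dimension` is a hypothetical name. With GUDHI itself, `cc.persistence_intervals_in_dimension(d)` returns the intervals of one dimension as a NumPy array.

```python
import math


def split_by_dimension(pers, drop_essential=True):
    """Group (dim, (birth, death)) pairs into one diagram per dimension (sketch).

    With drop_essential=True, points with an infinite death are skipped,
    so that only finite points are compared afterwards.
    """
    diagrams = {}
    for dim, (birth, death) in pers:
        if drop_essential and math.isinf(death):
            continue
        diagrams.setdefault(dim, []).append((birth, death))
    return diagrams


if __name__ == "__main__":
    # A few intervals taken from the '0' digit example in this thread.
    pers = [(1, (173.0, 252.0)), (0, (84.0, 96.0)), (0, (0.0, 165.0)), (0, (0.0, math.inf))]
    print(split_by_dimension(pers))
    # {1: [(173.0, 252.0)], 0: [(84.0, 96.0), (0.0, 165.0)]}
```

Distances could then be computed per dimension, e.g. on `np.array(diags_i[0])` and `np.array(diags_j[0])` (hypothetical variable names), rather than on the mixed array.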

VincentRouvreau commented 2 years ago

🤔 Your second point is strange to me... I changed 2-3 things in your code, but yours was (almost) working:

import itertools
import numpy as np
from torchvision import datasets

from gudhi.cubical_complex import CubicalComplex
from gudhi.wasserstein import wasserstein_distance
from gudhi import bottleneck_distance

def pers_diag(pts):
    pers = CubicalComplex(top_dimensional_cells=pts, dimensions=[28, 28]).persistence()
    res = np.array([list(b) for (_, b) in pers])
    return res

dataset2 = datasets.MNIST('data', train=False, download=True)

diagrams = []
labels = []
n = 10
for dat, lab in zip(dataset2.data[:n], dataset2.train_labels[:n]):
    pts = dat.data.numpy().reshape(-1)
    diagrams.append(pers_diag(pts))
    labels.append(lab.item())

def print_wd(i, j):
    print("labels ", labels[i], labels[j], " | was ", wasserstein_distance(diagrams[i], diagrams[j]), " | bot ", bottleneck_distance(diagrams[i], diagrams[j]))

for i, j in itertools.combinations(range(n), 2):
    print_wd(i, j)

Output:

labels  7 2  | was  107.0  | bot  24.5
labels  7 1  | was  119.0  | bot  52.0
labels  7 0  | was  241.0  | bot  125.5
labels  7 4  | was  90.0  | bot  18.0
labels  7 1  | was  109.0  | bot  24.5
labels  7 4  | was  104.5  | bot  32.0
labels  7 9  | was  132.0  | bot  92.5
labels  7 5  | was  288.5  | bot  126.5
labels  7 9  | was  248.5  | bot  114.5
labels  2 1  | was  123.0  | bot  52.0
labels  2 0  | was  141.0  | bot  125.5
labels  2 4  | was  146.0  | bot  26.5
labels  2 1  | was  8.0  | bot  3.5
labels  2 4  | was  147.5  | bot  40.0
labels  2 9  | was  201.0  | bot  92.5
labels  2 5  | was  282.0  | bot  126.5
labels  2 9  | was  197.0  | bot  114.5
labels  1 0  | was  222.0  | bot  125.5
labels  1 4  | was  115.5  | bot  45.0
labels  1 1  | was  117.5  | bot  52.0
labels  1 4  | was  114.0  | bot  25.0
labels  1 9  | was  196.5  | bot  92.5
labels  1 5  | was  253.0  | bot  126.5
labels  1 9  | was  274.5  | bot  114.5
labels  0 4  | was  274.0  | bot  125.5
labels  0 1  | was  136.5  | bot  125.5
labels  0 4  | was  281.5  | bot  125.5
labels  0 9  | was  187.0  | bot  66.0
labels  0 5  | was  161.0  | bot  100.5
labels  0 9  | was  110.0  | bot  38.0
labels  4 1  | was  140.0  | bot  26.5
labels  4 4  | was  118.0  | bot  25.0
labels  4 9  | was  192.0  | bot  92.5
labels  4 5  | was  323.0  | bot  126.5
labels  4 9  | was  290.0  | bot  114.5
labels  1 4  | was  148.0  | bot  40.0
labels  1 9  | was  205.0  | bot  92.5
labels  1 5  | was  276.5  | bot  126.5
labels  1 9  | was  200.5  | bot  114.5
labels  4 9  | was  175.0  | bot  92.5
labels  4 5  | was  315.5  | bot  126.5
labels  4 9  | was  280.0  | bot  114.5
labels  9 5  | was  250.0  | bot  100.5
labels  9 9  | was  179.5  | bot  44.0
labels  5 9  | was  222.5  | bot  100.5
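To make the `was` and `bot` columns concrete: both distances match each point of one diagram either to a point of the other diagram or to the diagonal. With GUDHI's defaults for `gudhi.wasserstein.wasserstein_distance` (`order=1`, `internal_p=inf`), the Wasserstein distance is the smallest possible sum of L-infinity matching costs, while the bottleneck distance is the smallest possible maximum cost. The brute-force sketch below (hypothetical helper `brute_force_distances`, finite points only, tiny diagrams only since it tries every permutation) is not how GUDHI computes these distances; it only illustrates the definitions.

```python
import itertools
import math


def _linf(p, q):
    """L-infinity distance between two (birth, death) points."""
    return max(abs(p[0] - q[0]), abs(p[1] - q[1]))


def _to_diagonal(p):
    """L-infinity distance from (birth, death) to the diagonal birth == death."""
    return (p[1] - p[0]) / 2


def brute_force_distances(A, B, order=1.0):
    """Exact Wasserstein-`order` and bottleneck distances between two small,
    finite diagrams, by trying every matching (sketch; exponential cost).

    Each diagram is padded with "diagonal slots" so that any point may be
    matched to the diagonal instead of to a point of the other diagram.
    """
    n, m = len(A), len(B)
    size = n + m
    cost = [[0.0] * size for _ in range(size)]  # diagonal-to-diagonal stays 0
    for i in range(size):
        for j in range(size):
            if i < n and j < m:
                cost[i][j] = _linf(A[i], B[j])
            elif i < n:
                cost[i][j] = _to_diagonal(A[i])
            elif j < m:
                cost[i][j] = _to_diagonal(B[j])
    best_w, best_b = math.inf, math.inf
    for perm in itertools.permutations(range(size)):
        costs = [cost[i][perm[i]] for i in range(size)]
        best_w = min(best_w, sum(c ** order for c in costs))  # Wasserstein: sum
        best_b = min(best_b, max(costs, default=0.0))         # bottleneck: max
    return best_w ** (1.0 / order), best_b


if __name__ == "__main__":
    A = [(0.0, 4.0)]
    B = [(1.0, 4.0), (2.0, 3.0)]
    print(brute_force_distances(A, B))  # (1.5, 1.0)
```

On tiny finite diagrams like these, `wasserstein_distance` and `bottleneck_distance` from GUDHI would be expected to return the same two numbers; the sum-versus-max difference is also why the `was` column is always at least as large as the `bot` column above.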

VincentRouvreau commented 2 years ago

@dgm2 What is your GUDHI version? You can check it with: python -c "import gudhi; print(gudhi.__version__)"

VincentRouvreau commented 2 years ago

Here is an example of how to do the same computation with Dionysus and GUDHI:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
import dionysus as d
import gudhi as gd

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

# a zero
a = X[1].reshape((28,28))
#a = np.random.random((10,10))
plt.matshow(a)
plt.colorbar()
plt.show()

f_lower_star = d.fill_freudenthal(a)
p = d.homology_persistence(f_lower_star)
dgms = d.init_diagrams(p, f_lower_star)

for i,dgm in enumerate(dgms):
    print(i)
    for pt in dgm:
        print(pt)

# 0
# (0,inf)
# (0,165)
# (84,96)
# 1
# (0,255)
# (173,252)
# (223,253)
# (225,252)
# (225,253)
# (238,253)
# (240,253)
# (246,253)
# (252,253)
# (252,253)
# (252,253)
# (252,253)
# (253,255)

cc = gd.CubicalComplex(top_dimensional_cells=a)
cc.compute_persistence()
cc.persistence_intervals_in_dimension(0)
# array([[ 84.,  96.],
#        [  0., 165.],
#        [  0.,  inf]])
cc.persistence_intervals_in_dimension(1)
#array([[173., 252.],
#       [225., 252.],
#       [252., 253.],
#       [246., 253.],
#       [240., 253.],
#       [238., 253.],
#       [252., 253.],
#       [252., 253.],
#       [225., 253.],
#       [252., 253.],
#       [237., 253.],
#       [223., 253.],
#       [253., 255.],
#       [  0., 255.]])
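In this example the dimension-0 intervals agree. In dimension 1 the two lists agree up to ordering, except that GUDHI reports one extra interval, (237, 253). One plausible explanation is that the two libraries do not build the same complex on the pixel grid: a Freudenthal triangulation links each pixel to 6 neighbours (4 axis neighbours plus one diagonal direction), while a cubical complex built from top-dimensional cells, as far as we understand it, connects pixels that share a corner (8 neighbours). The sketch below only covers dimension 0, but shows how the choice of neighbours alone can change a diagram. The helper name and neighbour tables are hypothetical, and which diagonal Dionysus actually uses is an assumption.

```python
import math

# Pixel adjacency as (row, column) offsets.
# 8 neighbours: pixels sharing an edge or a corner.
EIGHT_NEIGHBOURS = [(-1, -1), (-1, 0), (-1, 1), (0, -1),
                    (0, 1), (1, -1), (1, 0), (1, 1)]
# Freudenthal-style: 4 axis neighbours plus ONE diagonal direction
# (the (1, 1) / (-1, -1) diagonal is an assumption made for this sketch).
FREUDENTHAL_NEIGHBOURS = [(-1, -1), (-1, 0), (0, -1), (0, 1), (1, 0), (1, 1)]


def sublevel_persistence_2d(image, offsets):
    """0-dimensional sublevel-set persistence of a 2D list of pixel values,
    using union-find and the elder rule (sketch). Returns sorted
    (birth, death) pairs; zero-length pairs are dropped."""
    rows, cols = len(image), len(image[0])
    parent, birth = {}, {}

    def find(p):
        while parent[p] != p:
            parent[p] = parent[parent[p]]  # path halving
            p = parent[p]
        return p

    pairs = []
    pixels = sorted(((r, c) for r in range(rows) for c in range(cols)),
                    key=lambda p: (image[p[0]][p[1]], p))
    for p in pixels:
        value = image[p[0]][p[1]]
        parent[p], birth[p] = p, value
        for dr, dc in offsets:
            q = (p[0] + dr, p[1] + dc)
            if q not in parent:  # outside the image or not yet in the sublevel set
                continue
            rp, rq = find(p), find(q)
            if rp == rq:
                continue
            young, old = (rp, rq) if birth[rp] > birth[rq] else (rq, rp)
            if value > birth[young]:
                pairs.append((birth[young], value))
            parent[young] = old
    roots = {find(p) for p in parent}
    pairs.extend((birth[r], math.inf) for r in roots)
    return sorted(pairs)


if __name__ == "__main__":
    image = [[5, 0],
             [0, 5]]  # two dark pixels touching only at a corner
    print(sublevel_persistence_2d(image, EIGHT_NEIGHBOURS))        # [(0, inf)]
    print(sublevel_persistence_2d(image, FREUDENTHAL_NEIGHBOURS))  # [(0, 5), (0, inf)]
```

If the 8-neighbour reading is right, `sublevel_persistence_2d(a.tolist(), EIGHT_NEIGHBOURS)` on the '0' digit above should reproduce `cc.persistence_intervals_in_dimension(0)`; that comparison has not been run here.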