lukemelas / deep-spectral-segmentation

[CVPR 2022] Deep Spectral Methods: A Surprisingly Strong Baseline for Unsupervised Semantic Segmentation and Localization

Question about extract_eigs #3

Closed. naoki7090624 closed this issue 2 years ago.

naoki7090624 commented 2 years ago

Thank you for sharing this great work!!

I have two questions about extracting eigenvectors.

  1. python extract.py extract_eigs generates the eigenvectors at the input image path (e.g., /home/naoki/deep-spectral-segmentation/testdata/images/014583.pth) and takes about 10 seconds per image. Is this normal?
  2. How can I get the map of eigenvectors like in the demo?

Thank you in advance.

lukemelas commented 2 years ago

Hi, thanks for your questions!

For the first question, that seems quite slow. The eigenvector extraction code runs on CPU, so how much CPU compute are you using (e.g. how many cores do you have)? Also, what resolution are your images (or are you using PASCAL VOC)?

For the second question, I am simply using matplotlib with the standard color scheme. If you would like, I can also upload a code snippet for that, just let me know.

All the best, Luke

naoki7090624 commented 2 years ago

Thank you for your response.

My images (a custom dataset) are 512×512, and my CPU has 12 cores. Also, I ran the code in a Docker environment with shm-size=2GB.

I would appreciate it if you could share the code.

99991 commented 2 years ago

10 seconds sounds too long. I played around with a few different solvers and got results in 0.2 seconds. Using the GPU does not make it any faster since initializing the GPU takes much longer than the CPU computation, at least for a single image.

import torch
import numpy as np
import scipy.sparse.linalg
import matplotlib.pyplot as plt
import time

# Provide your path to some feature vectors here
# You can obtain features like described here:
# https://github.com/lukemelas/deep-spectral-segmentation#step-1-feature-extraction
path = "sheep.pth"
# Number of eigenvectors to compute plus 1 (zeroth eigenvector will be noise)
K = 6 + 1

f = torch.load(path)

_, _, h, w = f["shape"]
h = h // f["patch_size"]
w = w // f["patch_size"]
features = f["k"][0].cpu().numpy()
W = features @ features.T
W[W < 0] = 0

D = np.diag(W.sum(axis=1))

start_time = time.perf_counter()

# 0.18 seconds
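# Solves the generalized eigenproblem (D - W) v = lambda * D v;
# sigma=0 shift-invert targets the smallest eigenvalues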
eigenvalues1, eigenvectors1 = scipy.sparse.linalg.eigsh(D - W, k=K, sigma=0, which='LM', M=D)

elapsed_time = time.perf_counter() - start_time

print(elapsed_time, "seconds to compute eigenvectors with scipy.sparse.linalg.eigsh")

# Normalize rows; D then becomes the identity, so D - W below is the
# (asymmetric) random-walk Laplacian I - D^{-1} W of the original W
W /= W.sum(axis=1, keepdims=True)
D = np.diag(W.sum(axis=1))

start_time = time.perf_counter()

# 0.24 seconds, but can change iterations to make it go slower or even faster
eigenvalues2, eigenvectors2 = torch.lobpcg(torch.from_numpy(D - W), k=K, largest=False, niter=30)

elapsed_time = time.perf_counter() - start_time

print(elapsed_time, "seconds to compute eigenvectors with torch.lobpcg")

start_time = time.perf_counter()

# 3 seconds
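# eigh computes the full spectrum (all eigenpairs), hence the extra time.
# Note: eigh assumes a symmetric matrix and only reads one triangle; after
# the row normalization above, D - W is asymmetric, which may explain why
# this row of the plot differs from the others.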
eigenvalues3, eigenvectors3 = np.linalg.eigh(D - W)

elapsed_time = time.perf_counter() - start_time

print(elapsed_time, "seconds to compute eigenvectors with np.linalg.eigh")

plt.figure(figsize=(10, 8))
for j, (title, eigenvectors) in enumerate([
    ("scipy.sparse.linalg.eigsh", eigenvectors1),
    ("torch.lobpcg", eigenvectors2),
    ("np.linalg.eigh", eigenvectors3),
]):
    for i in range(1, K):
        plt.subplot(3, K - 1, i + j * (K - 1))
        if i == 1: plt.title(title)

        image = eigenvectors[:, i].reshape(h, w)

        # Eigenvector signs are arbitrary; flip so the image is mostly positive
        if image.mean() < 0: image = -image

        plt.imshow(image)
        plt.axis('off')
plt.tight_layout()
plt.show()

[Image: sheep_eigenvectors, eigenvector maps produced by the three solvers]

The input was the sheep image 2010_001256.jpg, which is also used on https://huggingface.co/spaces/lukemelas/deep-spectral-segmentation and shown in the paper. I used extract.py with --model_name dino_vits8 here.

The third row looks a bit different, which is why I am not super confident in the results, but I think this is the general idea.

I also tried to add a weighted KNN affinity matrix, but it did not improve the results much. In Table 7 of the supplementary material, $\lambda_{KNN}$ did not seem to have a huge impact either, so I think this is expected. It is nice to know that a simpler approach already works so well.

Lastly, I tried a soft kernel instead of the binary kernel we used in PyMatting. For alpha matting, it makes essentially no difference whether you use $k(i, j) = 1 - \|X(i) - X(j)\| / C$ or $k(i, j) = 1$, since with the first variant the values are usually very close to 1 anyway. We did not document this very well previously (sorry about that), but there is now also an option to specify a soft kernel if you want to. Anyway, I am glad to report that the eigenvectors computed with the soft kernel are almost identical to those with the binary kernel, so it does not matter which one you use.

https://github.com/lukemelas/deep-spectral-segmentation/blob/c600861677af1d33817e0c4dcc92ec024ff2f69a/extract/extract_utils.py#L184
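For illustration only, here is a rough sketch of the two kernel choices on a KNN affinity matrix. This is not the repo's implementation; it uses sklearn's NearestNeighbors, the function name knn_affinity_sketch is made up, and the normalization constant C is chosen arbitrarily:

import numpy as np
import scipy.sparse
from sklearn.neighbors import NearestNeighbors

def knn_affinity_sketch(X, n_neighbors=20, kernel="binary"):
    # X: one feature vector per pixel/patch, shape (n, d)
    n = X.shape[0]
    distances, indices = NearestNeighbors(n_neighbors=n_neighbors).fit(X).kneighbors(X)
    if kernel == "binary":
        values = np.ones_like(distances)   # k(i, j) = 1
    else:
        C = distances.max() + 1e-8         # illustrative normalization constant
        values = 1.0 - distances / C       # k(i, j) = 1 - ||X(i) - X(j)|| / C
    rows = np.repeat(np.arange(n), n_neighbors)
    W = scipy.sparse.csr_matrix((values.ravel(), (rows, indices.ravel())), shape=(n, n))
    return (W + W.T) / 2                   # symmetrize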

I think I can also answer the question here:

https://github.com/lukemelas/deep-spectral-segmentation/blob/513d954df2236ecf81921f3d8f1560bf4c55709e/extract/extract.py#L222

Since the matrix W_feat is quite dense (excluding the values previously clamped to 0, but those are not that many), W_comb will also be dense. Working with dense matrices instead of sparse matrices here is the better choice.
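As a quick sanity check, one can measure the density directly (a sketch, reusing W from the script above as a stand-in for W_feat):

# Fraction of nonzero entries in the feature affinity matrix
nonzero_fraction = np.count_nonzero(W) / W.size
print(f"nonzero entries: {100 * nonzero_fraction:.1f}%")
# Dot products of DINO features are mostly positive, so after clamping
# negatives to zero the matrix typically stays dense, and sparse formats
# would save little memory or time.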

naoki7090624 commented 2 years ago

Thank you very much for sharing.

The following single line took 16 seconds: https://github.com/lukemelas/deep-spectral-segmentation/blob/513d954df2236ecf81921f3d8f1560bf4c55709e/extract/extract.py#L227

With np.linalg.eigh, I got the eigenvalues and eigenvectors in about the same time as you.

16.388382498174906 seconds to compute eigenvectors with scipy.sparse.linalg.eigsh
7.680276275612414 seconds to compute eigenvectors with torch.lobpcg
3.948524377308786 seconds to compute eigenvectors with np.linalg.eigh

I think scipy can generate better results, but it is slower. If you know why my scipy (1.8.0) is slower, please let me know.

99991 commented 2 years ago

torch.lobpcg is also much slower than it should be. Maybe the computed features are unlucky somehow. Could you upload them somewhere for testing? I'll upload sheep.pth later so you can test with it, too. EDIT: sheep.pth features

Alternatively, it might be possible to locate the slow part with a profiler: python -m cProfile -s tottime main.py (remove the plotting code first)

SciPy offloads the computation to ARPACK (https://github.com/scipy/scipy/blob/v1.8.1/scipy/sparse/linalg/_eigen/arpack/arpack.py), while PyTorch implements most of the algorithm in Python (https://github.com/pytorch/pytorch/blob/master/torch/_lobpcg.py), so it might be easier to find the issue in the PyTorch code. NumPy uses LAPACK (https://github.com/numpy/numpy/blob/08772f91455db66810995db5e9d0671f91e027ed/numpy/linalg/linalg.py#L1379).
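Since all three libraries ultimately rely on compiled BLAS/LAPACK kernels, it may also be worth checking which backend your environment actually uses. A small diagnostic sketch (standard introspection calls, nothing repo-specific):

import numpy as np
import scipy
import torch

# Which BLAS/LAPACK backends NumPy and SciPy were built against;
# an unoptimized reference BLAS can be dramatically slower than MKL/OpenBLAS
np.show_config()
scipy.show_config()

# PyTorch's threading/OpenMP/MKL settings
print(torch.__config__.parallel_info())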

naoki7090624 commented 2 years ago

Thank you for helping.

I tested with sheep.pth, but it took even more time.

28.736304740421474 seconds to compute eigenvectors with scipy.sparse.linalg.eigsh
10.58264269027859 seconds to compute eigenvectors with torch.lobpcg
13.395108604803681 seconds to compute eigenvectors with np.linalg.eigh

My feature file is here.

I checked the slow part with python -m cProfile -s tottime eigmap.py. The following lines have a large tottime:

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
1   26.445   26.445   26.450   26.450 _decomp_lu.py:15(lu_factor)
1   18.803   18.803   18.812   18.812 linalg.py:1336(eigh)
378    9.412    0.025    9.412    0.025 {built-in method matmul}
1    0.253    0.253   58.435   58.435 eigmap.py:1(<module>)
30    0.023    0.001    9.531    0.318 _lobpcg.py:847(_update_ortho)
30    0.001    0.000    9.559    0.319 _lobpcg.py:706(update)
1    0.000    0.000   27.305   27.305 arpack.py:1350(eigsh)
1    0.000    0.000    9.617    9.617 _lobpcg.py:338(lobpcg)
1    0.000    0.000   26.450   26.450 arpack.py:934(__init__)
1    0.000    0.000   18.812   18.812 <__array_function__ internals>:2(eigh)
1    0.000    0.000   26.450   26.450 arpack.py:1045(get_inv_matvec)
1    0.000    0.000   26.450   26.450 arpack.py:1055(get_OPinv_matvec)

naoki7090624 commented 2 years ago

When I tested on my local PC, scipy and torch worked fast.

0.027801500000000035 seconds to compute eigenvectors with scipy.sparse.linalg.eigsh
0.0345603000000001 seconds to compute eigenvectors with torch.lobpcg
0.18421749999999992 seconds to compute eigenvectors with np.linalg.eigh

I don't know why, but I think it's due to my Docker environment. Thanks a lot for helping. I will check my environment and run your code locally.
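In case it helps anyone else, here is a small sketch of environment checks that often explain this kind of Docker slowdown (thread limits rather than the solver itself; nothing here is specific to this repo):

import os
import torch

# CPUs the OS reports vs. CPUs the container may actually use (Linux)
print("os.cpu_count():", os.cpu_count())
print("sched_getaffinity:", len(os.sched_getaffinity(0)))

# Thread counts used by PyTorch and (via OpenMP) by BLAS/LAPACK
print("torch.get_num_threads():", torch.get_num_threads())
print("OMP_NUM_THREADS:", os.environ.get("OMP_NUM_THREADS"))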

csyanbin commented 1 year ago

Hi @99991, the "# Normalize rows" step seems to generate an asymmetric Laplacian matrix, which is different from $D^{-1/2} (D-W) D^{-1/2}$. I tried to implement the symmetric normalized Laplacian with the following:

D_inv = np.diag(1. / W.sum(axis=1))
D_inv_sqrt = np.sqrt(D_inv)
L = np.matmul(np.matmul(D_inv_sqrt, D - W), D_inv_sqrt)

The results look like the following: [Image: Figure_1, eigenvectors computed with the symmetric normalized Laplacian]
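For what it's worth, the two normalizations are closely related: if $L_{sym} = D^{-1/2} (D - W) D^{-1/2}$ has eigenvector $v$, then $u = D^{-1/2} v$ is an eigenvector of the random-walk Laplacian $D^{-1} (D - W)$ with the same eigenvalue. A minimal sketch (reusing the unnormalized W and K from the script above):

import numpy as np
import scipy.sparse.linalg

d = W.sum(axis=1)
D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
L_sym = D_inv_sqrt @ (np.diag(d) - W) @ D_inv_sqrt

# L_sym is symmetric, so eigsh applies directly (same sigma=0 trick as above)
eigenvalues, eigenvectors_sym = scipy.sparse.linalg.eigsh(L_sym, k=K, sigma=0, which="LM")

# Rescale to recover the random-walk eigenvectors plotted earlier
eigenvectors_rw = D_inv_sqrt @ eigenvectors_sym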