Closed pzSuen closed 2 years ago
Hi, sorry for my late reply!
Given a list of patches (as torch.Tensor) and a list with the corresponding labels (as int), you should be able to reproduce the results with the following code:
# given: patches (list with patches as torch.Tensors) and labels (list with domain labels)
# task: compute image features and plot corresponding umap from the paper
import torch
import umap
import matplotlib.pyplot as plt
from skimage.color import rgb2hsv, rgb2gray, rgb2lab, rgb2hed
def extract_features(patch):
    """Build a colour-statistics feature vector for one image patch.

    The patch is averaged over its last two axes as-is, then converted with
    scikit-image to HSV, LAB, HED and grayscale; each converted image is
    averaged the same way (grayscale yields a single value). All means are
    concatenated into one 1-D tensor in the order RGB, HSV, LAB, HED, gray.

    Args:
        patch: a ``torch.Tensor`` with the channel axis first.
            NOTE(review): presumably a (3, H, W) RGB image, which would give
            a 13-element result - confirm shape, dtype and value range
            against the caller.

    Returns:
        A 1-D ``torch.Tensor`` holding the concatenated mean values.
    """
    # The skimage converters receive the patch with its first axis moved last.
    channels_last = patch.permute(1, 2, 0)

    # Means of the untouched input come first in the feature vector.
    feature_parts = [torch.mean(patch, dim=(1, 2))]

    # Multi-channel colour spaces: convert, move channels back to the front,
    # then average over the two remaining axes.
    for convert in (rgb2hsv, rgb2lab, rgb2hed):
        converted = torch.from_numpy(convert(channels_last)).permute(2, 0, 1)
        feature_parts.append(torch.mean(converted, dim=(1, 2)))

    # Grayscale has no channel axis; unsqueeze so the scalar mean can be
    # concatenated with the per-channel vectors.
    gray_image = torch.from_numpy(rgb2gray(channels_last))
    feature_parts.append(torch.mean(gray_image).unsqueeze(0))

    return torch.cat(feature_parts)
# Compute one feature vector per patch and stack them, one row per patch.
features = torch.stack([extract_features(patch) for patch in patches])

# Compute the UMAP coordinates (default UMAP settings).
reducer = umap.UMAP()
u = reducer.fit_transform(features)

# Plot the UMAP dimension reduction, coloured by the domain labels.
plt.scatter(
    u[:, 0],
    u[:, 1],
    c=labels,
    alpha=0.5,
)
plt.show()
Hi,
I have several questions,
Thank you!