sophiajw / HistAuGAN


Questions about data #2

Closed: pzSuen closed this issue 2 years ago

pzSuen commented 2 years ago

Hi,

I have several questions:

  1. The number of tumor patches varies hugely across the centers of Camelyon17. For example, center UMCU has about as many tumor patches as the other four centers combined, so the amount of data per domain differs a lot. How do you deal with this?
  2. In the paper, you plot UMAP figures and compute the mLD. I would like to build on your work, but the code for this does not seem to be shared. Would you be willing to share it?

Thank you!

sophiajw commented 2 years ago

Hi, sorry for my late reply!

  1. On Camelyon17, we balanced the number of patches per domain. However, I later also trained the model on other, imbalanced datasets, and it worked fine there as well.
  2. Unfortunately, my collaborators created the figure for the paper, so I can't share the original code with you. Given a list of patches (as torch.Tensor) and a list with the corresponding domain labels (as int), you should be able to reproduce the results with the following code:
```python
# given: patches (list of RGB patches as torch.Tensors of shape (3, H, W))
#        and labels (list of int domain labels)
# task: compute image features and plot the corresponding UMAP from the paper

import torch
import umap
import matplotlib.pyplot as plt

from skimage.color import rgb2hsv, rgb2gray, rgb2lab, rgb2hed


def extract_features(patch):
    # work on a float64 CPU copy so all features share one dtype
    patch = patch.detach().cpu().double()
    # skimage expects channels-last NumPy arrays; float inputs should lie in [0, 1]
    # (rescale first if your patches are normalized to [-1, 1])
    img = patch.permute(1, 2, 0).numpy()

    # mean of each channel in RGB, HSV, Lab and HED space
    rgb_mean = torch.mean(patch, dim=(1, 2))

    hsv = torch.from_numpy(rgb2hsv(img)).permute(2, 0, 1)
    hsv_mean = torch.mean(hsv, dim=(1, 2))

    lab = torch.from_numpy(rgb2lab(img)).permute(2, 0, 1)
    lab_mean = torch.mean(lab, dim=(1, 2))

    hed = torch.from_numpy(rgb2hed(img)).permute(2, 0, 1)
    hed_mean = torch.mean(hed, dim=(1, 2))

    # mean gray-level intensity as a 1-element tensor
    gray = torch.from_numpy(rgb2gray(img))
    gray_mean = torch.mean(gray).unsqueeze(0)

    # 3 + 3 + 3 + 3 + 1 = 13 features per patch
    return torch.cat((rgb_mean, hsv_mean, lab_mean, hed_mean, gray_mean))


features = torch.stack([extract_features(patch) for patch in patches])

# compute the UMAP coordinates (umap-learn works on NumPy arrays)
reducer = umap.UMAP()
u = reducer.fit_transform(features.numpy())

# plot the UMAP dimension reduction, colored by domain label
plt.scatter(u[:, 0], u[:, 1], c=labels, alpha=0.5)
plt.show()
```
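Regarding question 1, the reply does not say exactly how the patches were balanced per domain. A minimal sketch, assuming per-domain undersampling to the size of the smallest center (the helper `balance_domains` is hypothetical and not part of HistAuGAN), could look like this:

```python
import random
from collections import defaultdict


def balance_domains(patches, labels, seed=0):
    """Hypothetical helper: undersample every domain to the size of the smallest one.

    Assumes `patches` and `labels` are parallel lists, where labels[i] is the
    int domain (center) index of patches[i].
    """
    # group sample indices by domain label
    by_domain = defaultdict(list)
    for idx, label in enumerate(labels):
        by_domain[label].append(idx)

    # target size: the number of patches in the smallest domain
    n_min = min(len(idxs) for idxs in by_domain.values())
    rng = random.Random(seed)  # seeded for reproducible subsets

    kept = []
    for label in sorted(by_domain):
        # draw n_min indices without replacement from this domain
        kept.extend(rng.sample(by_domain[label], n_min))
    kept.sort()  # preserve the original order of the retained patches

    return [patches[i] for i in kept], [labels[i] for i in kept]


# usage with toy data: 6 patches from domain 0, 2 each from domains 1 and 2
toy_patches = [f"patch_{i}" for i in range(10)]
toy_labels = [0] * 6 + [1] * 2 + [2] * 2
bal_patches, bal_labels = balance_domains(toy_patches, toy_labels)
print(sorted(bal_labels))  # [0, 0, 1, 1, 2, 2]
```

Undersampling discards patches from the large centers. If you would rather keep all of them, PyTorch's `torch.utils.data.WeightedRandomSampler` with per-sample weights inversely proportional to the domain frequency is a common alternative that balances domains per batch instead.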