ys-zong / conST

conST: an interpretable multi-modal contrastive learning framework for spatial transcriptomics
MIT License
21 stars, 4 forks

Problem encountered when extracting features from spot patches #5

Open tqwang743 opened 1 year ago

tqwang743 commented 1 year ago

Hello, I was trying to retrain the conST model without using the trained weights conST_151673.pth, but I ran into difficulties at the step that extracts features from spot patches. Could you share the code for this step? Thank you!

wenhuidu commented 1 year ago

I have the same problem

bbchond commented 5 months ago

I faced the same problem and used the following code to extract features from spot patches:

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from tqdm import tqdm
from MAEpytorch.datasets import DataAugmentationForMAE
from MAEpytorch.modeling_pretrain import pretrain_mae_base_patch16_224

# adata_h5, section_id, params and device are assumed to be defined earlier
# (AnnData object, Visium section id, argument namespace, torch device)
cudnn.benchmark = True
input_size = 224

# Load the hires image; convert to uint8 if it is stored as floats in [0, 1]
image = adata_h5.uns["spatial"][section_id]['images']['hires']
if image.dtype == np.float32 or image.dtype == np.float64:
    image = (image * 255).astype(np.uint8)

# obsm['spatial'] holds full-resolution (x, y) = (column, row) pixel coordinates;
# scale them so they index into the hires image loaded above
scale_factor = adata_h5.uns['spatial'][section_id]['scalefactors']['tissue_hires_scalef']
image_coord = adata_h5.obsm['spatial'] * scale_factor
crop_size = input_size // 2  # half the side length of each square patch

# Load the MAE model (ViT-Base, 16 x 16 patches) and its pretrained weights
model = pretrain_mae_base_patch16_224()
patch_size = model.encoder.patch_embed.patch_size
print("Patch size = %s" % str(patch_size))
window_size = (input_size // patch_size[0], input_size // patch_size[1])

checkpoint = torch.load('./MAEpytorch/pretrain_mae_vit_base_mask_0.75_400e.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.to(device)
model.eval()

# DataAugmentationForMAE reads its settings (e.g. window_size) from params
params.window_size = window_size
transforms = DataAugmentationForMAE(params)
image_pillow = Image.fromarray(image)
features_all = []

# Crop a patch around every spot and extract its morphological features
with tqdm(total=len(adata_h5), desc='image cropping and feature extracting...', bar_format='{l_bar}{bar} [ time left: {remaining} ]') as pbar:
    for image_col, image_row in image_coord:
        patch = image_pillow.crop((image_col - crop_size, image_row - crop_size, image_col + crop_size, image_row + crop_size))
        # resize() returns a new image, so the result has to be assigned
        patch = patch.resize((input_size, input_size), Image.LANCZOS)

        img, _ = transforms(patch)
        # All-False mask: the encoder sees every patch token. The random mask
        # returned by DataAugmentationForMAE would change the features on every run.
        bool_masked_pos = torch.zeros((1, window_size[0] * window_size[1]), dtype=torch.bool, device=device)

        with torch.no_grad():
            img = img[None, :].to(device, non_blocking=True)
            features = model.encoder(img, bool_masked_pos)  # (1, num_tokens, embed_dim)
            features_pool = features.mean(dim=1)  # average over tokens -> (1, embed_dim)
            features_all.append(features_pool.cpu().numpy())
        pbar.update(1)
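For reference, the cropping step boils down to scaling each spot's full-resolution coordinate into the hires image and taking a fixed-size square around it. Below is a NumPy-only sketch of that; the helper name spot_crop_boxes is hypothetical, and it assumes obsm['spatial'] stores (x, y) = (column, row) pixels, as scanpy's Visium reader does. A plain PIL crop fills any part of the box outside the image with black, so this sketch instead shifts edge boxes back inside the image, keeping every patch the same size.

```python
import numpy as np


def spot_crop_boxes(spatial_fullres, scalef, crop_size, image_hw):
    """Square crop boxes (left, upper, right, lower) in hires-image pixels.

    spatial_fullres: (n_spots, 2) full-resolution (x, y) = (column, row) coordinates.
    scalef: hires scale factor, e.g. 'tissue_hires_scalef'.
    crop_size: side length of each patch; assumed <= image height and width.
    image_hw: (height, width) of the hires image.
    """
    coords = np.asarray(spatial_fullres, dtype=float) * scalef  # full-res -> hires pixels
    cols = np.rint(coords[:, 0]).astype(int)
    rows = np.rint(coords[:, 1]).astype(int)
    half = crop_size // 2
    left, upper = cols - half, rows - half
    boxes = np.stack([left, upper, left + crop_size, upper + crop_size], axis=1)

    # Shift boxes that stick out of the image back inside, keeping their size fixed
    h, w = image_hw
    dx = np.clip(-boxes[:, 0], 0, None) - np.clip(boxes[:, 2] - w, 0, None)
    dy = np.clip(-boxes[:, 1], 0, None) - np.clip(boxes[:, 3] - h, 0, None)
    boxes[:, [0, 2]] += dx[:, None]
    boxes[:, [1, 3]] += dy[:, None]
    return boxes


# Toy example: three spots, scale factor 0.1, a 450 x 420 (height x width) image
spots = np.array([[1000, 2000], [10, 10], [4000, 4000]])
boxes = spot_crop_boxes(spots, 0.1, 224, (450, 420))
```

Here the first spot maps to hires pixel (100, 200) and gets the box [0, 88, 224, 312] after being shifted right by 12 pixels. Each row can be passed to image_pillow.crop(tuple(int(v) for v in box)). The trade-off: shifting keeps patches free of black padding but moves edge spots off-centre, while padding keeps the spot centred.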

You can then cluster the extracted morphological features with the PhenoGraph Python package and plot the resulting domain identification:

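import numpy as np
import phenograph
import scanpy as sc
from sklearn.decomposition import PCA

# Stack the pooled per-spot features into an (n_spots, n_features) matrix
params.cell_num = adata_h5.shape[0]
img_transformed = np.asarray(features_all).reshape(params.cell_num, -1)
# Rescale to the global mean/std of the expression matrix
# (adata_X is assumed to be the preprocessed expression matrix defined earlier)
img_transformed = (img_transformed - img_transformed.mean()) / img_transformed.std() * adata_X.std() + adata_X.mean()

# Reduce to 50 principal components
pca = PCA(n_components=50, random_state=42)
img_feat_pca = pca.fit_transform(img_transformed)

# Cluster with PhenoGraph and plot the labels on the tissue
adata_h5.obsm['image_feat_pca'] = img_feat_pca
graph_label, _, _ = phenograph.cluster(adata_h5.obsm['image_feat_pca'])
adata_h5.obs['graph_label'] = graph_label
adata_h5.obs['graph_label'] = adata_h5.obs['graph_label'].astype('category')
sc.pl.spatial(adata_h5, color='graph_label')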
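A note on the img_transformed rescaling line: it uses a single global mean and standard deviation, so it matches the overall scale of the expression matrix rather than standardizing each feature. Below is a NumPy-only sketch of that rescaling followed by PCA computed from an SVD of the centred data (a stand-in for sklearn's PCA, whose component signs may differ). The helper names and the random toy matrices are made up for illustration, and the PhenoGraph step is not reproduced.

```python
import numpy as np


def match_global_scale(features, reference):
    """Rescale features so their overall mean and std equal those of reference.

    One mean and one std are taken over the whole matrix, not per column.
    """
    features = np.asarray(features, dtype=float)
    reference = np.asarray(reference, dtype=float)
    z = (features - features.mean()) / features.std()
    return z * reference.std() + reference.mean()


def pca_scores(x, n_components):
    """Principal-component scores of the rows of x via an SVD of the centred data."""
    x_centered = x - x.mean(axis=0)
    u, s, _ = np.linalg.svd(x_centered, full_matrices=False)
    return u[:, :n_components] * s[:n_components]


rng = np.random.default_rng(42)
image_feats = rng.normal(loc=5.0, scale=3.0, size=(100, 768))  # toy stand-in for pooled features
expression = rng.normal(loc=0.2, scale=0.5, size=(100, 300))   # toy stand-in for adata_X
scaled = match_global_scale(image_feats, expression)
embedding = pca_scores(scaled, 50)
print(embedding.shape)  # (100, 50)
```

Because the rescaling is one affine map applied to every entry, the PCA directions do not change and the scores are only multiplied by a constant. This sketch assumes dense arrays.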

However, some bugs still need to be fixed: the dimensions of the network layers do not match the latent dimensions, and the final domain identification sometimes looks strange. Hope it helps.
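On the "sometimes strange" results: DataAugmentationForMAE returns a random patch mask together with the transformed image. If that mask is passed to model.encoder, only the unmasked patches are encoded, so the pooled feature for the same spot changes from run to run. As far as I can tell, its image transform also applies a random resized crop, which adds more variation. The toy NumPy sketch below illustrates only the pooling side of the masking effect; mae_style_mask is a hypothetical helper modelled on MAE-style random masking (True = hidden), and the token matrix is random data, not real encoder output. In the real ViT the visible tokens also attend only to each other, which this toy ignores.

```python
import numpy as np


def mae_style_mask(num_patches, mask_ratio, rng):
    """Random boolean patch mask: int(mask_ratio * num_patches) entries are True (hidden)."""
    num_mask = int(mask_ratio * num_patches)
    mask = np.zeros(num_patches, dtype=bool)
    mask[rng.choice(num_patches, size=num_mask, replace=False)] = True
    return mask


def pooled_visible_tokens(tokens, mask):
    """Average the token vectors that are not masked, like features.mean(dim=1)."""
    return tokens[~mask].mean(axis=0)


rng = np.random.default_rng(0)
num_patches = 14 * 14  # 224 / 16 = 14 patches per side
tokens = rng.normal(size=(num_patches, 8))  # toy per-patch features, not real encoder output

# Two random masks at ratio 0.75: 147 patches hidden, 49 visible, different each time
run_a = pooled_visible_tokens(tokens, mae_style_mask(num_patches, 0.75, rng))
run_b = pooled_visible_tokens(tokens, mae_style_mask(num_patches, 0.75, rng))
print(np.allclose(run_a, run_b))  # False

# Ratio 0.0 (nothing hidden): every patch is used and the result is reproducible
full = pooled_visible_tokens(tokens, mae_style_mask(num_patches, 0.0, rng))
print(np.allclose(full, tokens.mean(axis=0)))  # True
```

Passing an all-False mask to the encoder, e.g. torch.zeros((1, 196), dtype=torch.bool, device=device) for a 224 x 224 input with 16 x 16 patches, gives one reproducible feature vector per spot.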