Open tqwang743 opened 1 year ago
I have the same problem
I ran into the same problem. I use the following code to extract features from the spot patches:
from tqdm import tqdm
from MAEpytorch.modeling_pretrain import pretrain_mae_base_patch16_224
from PIL import Image
import torch.backends.cudnn as cudnn
from MAEpytorch.datasets import DataAugmentationForMAE
# Extract one pooled MAE-encoder feature vector per spot from the hires tissue image.
# Relies on names defined outside this fragment: adata_h5, adata, section_id,
# params, device, np and torch.
cudnn.benchmark = True
input_size = 224

# load image, and then crop image into multiple 224 * 224 size
image = adata_h5.uns["spatial"][section_id]['images']['hires']
if image.dtype == np.float32 or image.dtype == np.float64:
    # Image.fromarray below is fed uint8 data, so float images are rescaled first
    # (assumes float images are in [0, 1] -- TODO confirm).
    image = (image * 255).astype(np.uint8)
scale_factor = adata_h5.uns['spatial'][section_id]['scalefactors']['tissue_hires_scalef']
# NOTE(review): image_coord is computed but not used below; the loop reads
# adata.obs['image_row'] / adata.obs['image_col'] instead.
image_coord = adata_h5.obsm['spatial'] * scale_factor
# patches = []
# Half-width of the crop window, so each patch spans input_size pixels per side.
crop_size = input_size / 2

# load MAE model
# model = get_model(args)
model = pretrain_mae_base_patch16_224()
patch_size = model.encoder.patch_embed.patch_size
print("Patch size = %s" % str(patch_size))
window_size = (input_size // patch_size[0], input_size // patch_size[1])
model.to(device)
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load('./MAEpytorch/pretrain_mae_vit_base_mask_0.75_400e.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()
params.window_size = (input_size // patch_size[0], input_size // patch_size[1])

# extract morphological information from image patches
transforms = DataAugmentationForMAE(params)
image_pillow = Image.fromarray(image)
features_all = []
# NOTE(review): the progress total comes from adata_h5 while the loop iterates
# adata.obs -- confirm both describe the same set of spots.
with tqdm(total=len(adata_h5), desc='image cropping and feature extracting...', bar_format='{l_bar}{bar} [ time left: {remaining} ]') as pbar:
    for image_row, image_col in zip(adata.obs['image_row'], adata.obs['image_col']):
        # Crop a square window centred on the spot.
        patch = image_pillow.crop((image_col - crop_size, image_row - crop_size, image_col + crop_size, image_row + crop_size))
        patch.thumbnail((input_size, input_size), Image.LANCZOS)
        # Image.resize returns a new image instead of modifying in place, so the
        # result has to be assigned; the original call discarded it.
        patch = patch.resize((input_size, input_size))
        # patches.append(patch)
        img, bool_masked_pos = transforms(patch)
        bool_masked_pos = torch.from_numpy(bool_masked_pos)
        with torch.no_grad():
            # Add a leading batch dimension of size 1 to both inputs.
            img = img[None, :]
            bool_masked_pos = bool_masked_pos[None, :]
            img = img.to(device, non_blocking=True)
            bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)
            features = model.encoder.forward(img, bool_masked_pos)
            # Average over dim 1 to get a single vector per patch
            # (presumably the token axis -- TODO confirm).
            features_pool = features.mean(dim = 1)
            features_all.append(features_pool.detach().cpu().numpy())
        pbar.update(1)
You can then use the PhenoGraph Python package to cluster the extracted morphological features and visualize the resulting domain identification:
from sklearn.decomposition import PCA
import phenograph
# Cluster the spots on their image features and plot the labels on the tissue.
# Relies on names defined outside this fragment: adata_h5, adata_X, params,
# features_all, np and sc.
params.cell_num = adata_h5.shape[0]

# Stack the per-spot features into one row per spot.
feature_stack = np.asarray(features_all)
img_transformed = feature_stack.reshape(params.cell_num, -1)

# Standardise with the global mean/std of the image features, then rescale to
# the global std/mean of adata_X.
feature_mean = img_transformed.mean()
feature_std = img_transformed.std()
img_transformed = (img_transformed - feature_mean) / feature_std * adata_X.std() + adata_X.mean()

# Reduce to 50 principal components and store them on the AnnData object.
pca = PCA(n_components=50, random_state=42)
img_feat_pca = pca.fit_transform(img_transformed)
adata_h5.obsm['image_feat_pca'] = img_feat_pca

# Only the first of the three values returned by phenograph.cluster is kept.
cluster_output = phenograph.cluster(adata_h5.obsm['image_feat_pca'])
graph_label, _, _ = cluster_output

# Store the labels as a categorical column, then plot them.
adata_h5.obs['graph_label'] = graph_label
label_column = adata_h5.obs['graph_label']
adata_h5.obs['graph_label'] = label_column.astype('category')
sc.pl.spatial(adata_h5, color='graph_label')
However, some bugs still need to be fixed: the dimensions of the network layers do not match the latent dimensions, and the final domain identification result is sometimes strange. I hope this helps.
Hello, I was trying to retrain the conST model without using the trained weights conST_151673.pth, but I encountered difficulties while performing the following step. Could you share the code for this step? Thank you!