Closed: jieruyao49 closed this issue 1 year ago.
pt文件就是一些存储下来的特征向量，类似 .npy 文件，供参考：
(The .pt files are simply saved feature vectors, similar to .npy files. For reference:)
import os
import glob
import cv2
import timm
import torch
import numpy as np
import tqdm
import gc
# ImageNet-pretrained ResNet-50 used as a frozen patch feature extractor.
# Replacing the classification head with an identity makes a forward pass
# return the globally pooled feature vector instead of class logits.
model = timm.create_model('resnet50', pretrained=True)
model.fc = torch.nn.Identity()
# Use the GPU when present; fall back to CPU instead of failing at start-up
# on machines without CUDA.
model.to('cuda' if torch.cuda.is_available() else 'cpu')
# Inference mode: freeze BatchNorm statistics and disable dropout.
model.eval()
# Per-channel statistics (RGB order) expected by ImageNet-pretrained
# torchvision/timm ResNet weights, applied after scaling pixels to [0, 1].
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def _split_bag_path(bag_path):
    """Return the last three path components of *bag_path*.

    Both '\\' and '/' are treated as separators, so Windows paths (as produced
    by the original glob pattern) and POSIX paths split the same way.
    Raises IndexError when the path has fewer than three components.
    """
    parts = os.path.normpath(bag_path).replace('\\', '/').split('/')
    return parts[-3], parts[-2], parts[-1]


def _load_patch(path, device, normalize):
    """Read one image as a (1, 3, H, W) float32 tensor on *device*.

    Returns None when OpenCV cannot decode the file (cv2.imread returns None),
    e.g. for a stray non-image file inside the bag directory.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        return None
    # OpenCV decodes to BGR; the pretrained weights expect RGB. cvtColor
    # returns a new array, so the result must be kept.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
    if normalize:
        rgb = (rgb / 255.0 - _IMAGENET_MEAN) / _IMAGENET_STD
    # HWC -> 1xCxHxW.
    tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    return tensor.contiguous().to(device)


def extract_feas(model, bag_path, save_dir='', *, normalize=True):
    """Embed every patch image in one bag directory and save them as .npy.

    Args:
        model: feature extractor, run once per patch. Its output is assumed to
            be (1, D) per patch (e.g. a ResNet whose ``fc`` is an Identity);
            TODO confirm for other backbones.
        bag_path: directory laid out as
            ``.../<normal_or_abnormal>/<st_name>/<bag_name>`` containing the
            patch image files.
        save_dir: root under which ``<normal_or_abnormal>/<st_name>/<bag_name>/``
            is created.
        normalize: if True (default), scale pixels to [0, 1] and apply ImageNet
            mean/std as the pretrained weights expect. False feeds raw 0-255
            pixel values, as earlier versions of this script did.

    Side effects:
        Writes ``<save_dir>/<normal_or_abnormal>/<st_name>/<bag_name>/<bag_name>.npy``
        holding one row per readable patch, in sorted file-name order.
        Non-image files are skipped with a message.
    """
    print(bag_path)
    normal_or_abnormal, st_name, bag_name = _split_bag_path(bag_path)
    # Follow the model's device rather than assuming CUDA.
    device = next(model.parameters()).device

    embed = []
    # Inference only: no_grad avoids building an autograd graph per patch.
    with torch.no_grad():
        # Sorted so the row order of the saved matrix is reproducible.
        for instance in tqdm.tqdm(sorted(os.listdir(bag_path))):
            patch = _load_patch(os.path.join(bag_path, instance), device, normalize)
            if patch is None:
                print(f'skipping unreadable file: {instance}')
                continue
            f = model(patch)
            embed.append(f.cpu().numpy().squeeze(0))
    # Release per-bag buffers before the caller moves on to the next bag.
    gc.collect()

    feas = np.array(embed)
    print(feas.shape)
    save_path = os.path.join(save_dir, normal_or_abnormal, st_name, bag_name)
    os.makedirs(save_path, exist_ok=True)
    print(save_path)
    np.save(os.path.join(save_path, bag_name + '.npy'), feas)
if __name__ == "__main__":
    # Bags are expected at Patch/img/<normal_or_abnormal>/<st_name>/<bag>/.
    # os.path.join keeps the pattern portable; the original backslash-only
    # literal matched nothing on POSIX.
    pattern = os.path.join('Patch', 'img', '*', '*', '*')
    # extract_feas lists each bag's contents, so keep directories only, and
    # sort them for a reproducible processing order.
    all_bags = sorted(p for p in glob.glob(pattern) if os.path.isdir(p))
    for bag_path in all_bags:
        extract_feas(model, bag_path)
We use CLAM to process the WSIs, which produces one .pt file per WSI containing the features of all of its patches.
My downloaded CAMELYON16 dataset has the directory structure shown below. Where can I download the corresponding .pt files?