Status: Open — issue opened by yangjun19950118 6 months ago.
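*(Screenshot `1715390632499.jpg` was attached, but the upload never completed, so the image is missing.)* 改成 vit_h 或 vit_l (i.e. switch the model type to `vit_h` or `vit_l`).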
# Checkpoint file for each SAM backbone, keyed by the same names used for
# ``sam_model_registry`` lookups below. The paths are relative ('./...'),
# so they resolve against the current working directory, not this file.
models = dict(
    vit_b='./checkpoints/sam_vit_b_01ec64.pth',
    vit_l='./checkpoints/sam_vit_l_0b3195.pth',
    vit_h='./checkpoints/sam_vit_h_4b8939.pth',
)
def get_sam_predictor(model_type='vit_b', device='cuda'):
    """Build a ``SamPredictor`` for the requested SAM backbone.

    Args:
        model_type: Backbone name; must be a key of ``models`` (``'vit_b'``,
            ``'vit_l'`` or ``'vit_h'``). Pass ``'vit_l'`` or ``'vit_h'`` to
            use a larger backbone than the default.
        device: Target device passed to ``sam.to`` (defaults to ``'cuda'``).

    Returns:
        A ``SamPredictor`` wrapping the loaded model.

    Raises:
        ValueError: If ``model_type`` has no checkpoint entry in ``models``.
    """
    # Fail early with a readable message instead of a bare KeyError from
    # one of the two dict lookups below.
    if model_type not in models:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(models)}"
        )
    # Pick the builder for this backbone and load its weights from the
    # matching checkpoint path.
    sam = sam_model_registry[model_type](checkpoint=models[model_type])
    sam = sam.to(device)
    return SamPredictor(sam)
def get_mask_generator(model_type='vit_b', device='cuda'):
    """Build a ``SamAutomaticMaskGenerator`` for the requested SAM backbone.

    Args:
        model_type: Backbone name; must be a key of ``models`` (``'vit_b'``,
            ``'vit_l'`` or ``'vit_h'``). Pass ``'vit_l'`` or ``'vit_h'`` to
            use a larger backbone than the default.
        device: Target device passed to ``sam.to`` (defaults to ``'cuda'``).

    Returns:
        A ``SamAutomaticMaskGenerator`` built on the loaded model. Only
        ``model`` is passed, so every other generator option keeps the
        library's own default.

    Raises:
        ValueError: If ``model_type`` has no checkpoint entry in ``models``.
    """
    # Fail early with a readable message instead of a bare KeyError from
    # one of the two dict lookups below.
    if model_type not in models:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(models)}"
        )
    # Pick the builder for this backbone and load its weights from the
    # matching checkpoint path.
    sam = sam_model_registry[model_type](checkpoint=models[model_type])
    sam = sam.to(device)
    mask_generator = SamAutomaticMaskGenerator(model=sam)
    return mask_generator