The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0 · 47.87k stars · 5.66k forks · source link
SamAutomaticMaskGenerator becomes very slow with more than 2 parallel processes #791
I'm using torch.multiprocessing to split the images across parallel processes. I found that once I raise the number of parallel processes above two, SamAutomaticMaskGenerator's generation speed drops. It starts out at a normal speed, about two or three seconds per image, but after a while it slows to roughly 60-200 seconds. Does anyone know why that is? The running environment is an H800.
The code is as follows.
import math
import multiprocessing
import os
import time
import torch
import numpy as np
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.multiprocessing as mp

multiprocessing.set_start_method('spawn', force=True)

sam_checkpoint =
model_type =

def load_ImageNet(ImageNet_PATH, batch_size=64, workers=3, pin_memory=True, batch_range=None):
    traindir = os.path.join(ImageNet_PATH, 'train')
    valdir = os.path.join(ImageNet_PATH, 'val')

# Function to create a model and mask generator inside each process
def create_model_and_generator(device):
    # Load model in each process to ensure independent models
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    mask_generator = SamAutomaticMaskGenerator(sam)
    return sam, mask_generator

# Function to process each batch with a separate model in each process
def process_batch_parallel(samples, device, idx):
    sam, mask_generator = create_model_and_generator(device)  # Create model and generator per process
    batchsize, _, h, w = samples.shape
    batch_masks = []
    for i, img in enumerate(samples):
        start = time.time()
        img = img.to(device)  # Move the image to the correct GPU
        img_np = img.permute(1, 2, 0).cpu().numpy()  # CPU conversion for mask generation
        img_np = (img_np * 255).astype(np.uint8)
        # Generate segmentation masks
        masks = mask_generator.generate(img_np)
        combined_mask = np.zeros((h, w), dtype=np.int16)
        # # Sort masks by area in descending order
        # masks = sorted(masks, key=lambda x: x['area'], reverse=True)
        for j, mask in enumerate(masks):
            segmentation = mask['segmentation']
            combined_mask[segmentation] = j + 1
        batch_masks.append(torch.tensor(combined_mask, dtype=torch.int16).cpu())
        end = time.time()
        print(f"Processed sample {i}/{len(samples)} in batch {idx}, time {end - start}s")
    return torch.stack(batch_masks)
def process_batch_parallel2(rank, train_loader, devices, lefts, rights):
    device = devices[rank]
    left = lefts[rank]
    right = rights[rank]
    print(f"left to right is {left} to {right}, device is {device}, rank is {rank}")
    sam, mask_generator = create_model_and_generator(device)  # Create model and generator per process
    for idx, (samples, targets) in enumerate(train_loader):
        batch_masks = []
        if idx > right:
            return left, right
        if idx < left:
            continue
        start_time = time.time()
        for i, img in enumerate(samples):
            batchsize, _, h, w = samples.shape
            start = time.time()
            img_np = img.permute(1, 2, 0).cpu().numpy()
            img_np = (img_np * 255).astype(np.uint8)
            masks = mask_generator.generate(img_np)
            combined_mask = np.zeros((h, w), dtype=np.int16)
            for j, mask in enumerate(masks):
                segmentation = mask['segmentation']
                combined_mask[segmentation] = j + 1
            batch_masks.append(torch.tensor(combined_mask, dtype=torch.int16).cpu())
            end = time.time()
            print(f"Processed sample {i}/{len(samples)} in batch {idx}, time {end - start}s")
        batch = torch.stack(batch_masks)
        torch.save(batch, f"sam_output/sam_batch{idx}.pth")
        end_time = time.time()
        print(f"Processed batch {idx-left+1}/{right - left + 1}, time {end_time - start_time}s")
    return left, right
# Use multiprocessing Pool, but now each process will handle a batch
def process_in_parallel(train_loader, max_tasks=4):
    length = math.ceil(len(train_loader) / max_tasks)
    devices = [f"cuda:{i+4}" for i in range(max_tasks)]
    lefts = [(length * i) for i in range(max_tasks)]
    rights = [(length * (i + 1) - 1) for i in range(max_tasks)]
    mp.spawn(process_batch_parallel2, nprocs=max_tasks, args=(train_loader, devices, lefts, rights))

if __name__ == '__main__':
    max_tasks = 3
    torch.set_num_threads(8)
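A plausible culprit, though nothing in the issue confirms it: torch.set_num_threads(8) in the __main__ guard only configures the parent process. Workers started by mp.spawn are fresh interpreters that skip that guard, so each one falls back to PyTorch's default of one intra-op thread per CPU core, and several SamAutomaticMaskGenerator workers doing CPU-side post-processing at once can oversubscribe the machine, which would match a slowdown that appears only after the first batches. A minimal sketch of capping the thread count inside each spawned worker (worker and the division by n_workers are illustrative choices, not part of the original code):

import os
import torch

def worker(rank, *args):
    n_workers = 3  # matches max_tasks above
    # Cap this process's intra-op pool so the pools of all spawned
    # workers together stay within the machine's CPU cores.
    torch.set_num_threads(max(1, (os.cpu_count() or n_workers) // n_workers))
    # ... create the model and run SamAutomaticMaskGenerator as before ...

Setting OMP_NUM_THREADS in the environment before launching has a similar effect, since child processes inherit environment variables.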
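Separately, and worth noting even if it isn't the root cause: every spawned rank iterates the entire train_loader and merely continues past batches outside its [left, right] window, so image decoding and augmentation run once per process instead of once in total. A sketch of handing each rank only its own contiguous slice, assuming a map-style dataset (make_rank_loader and its parameters are hypothetical names introduced here for illustration):

import math
from torch.utils.data import DataLoader, Subset

def make_rank_loader(dataset, rank, n_ranks, batch_size=64, workers=3):
    # Contiguous per-rank split: each worker loads only its own samples.
    per_rank = math.ceil(len(dataset) / n_ranks)
    lo = rank * per_rank
    hi = min(lo + per_rank, len(dataset))
    return DataLoader(Subset(dataset, list(range(lo, hi))),
                      batch_size=batch_size, num_workers=workers,
                      pin_memory=True)

With a split like this, the left/right bookkeeping in process_batch_parallel2 becomes unnecessary.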