facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.

SamAutomaticMaskGenerator becomes very slow with more than 2 parallel processes #791


0930mcx commented 2 weeks ago

I'm using torch.multiprocessing to run image segmentation in parallel. I found that once I increase the number of parallel processes beyond two, SamAutomaticMaskGenerator slows down. Each image starts at a normal speed, about two or three seconds, but after a while it takes roughly 60-200 seconds. Does anyone know why that is? The running environment is an H800.
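One note on how I measure first: CUDA kernels launch asynchronously, so wall-clock timings around `mask_generator.generate` can mix queueing and compute time. This is a small synchronized-timing sketch I would use to double-check where the time actually goes (the helper `timed_generate` is my own name, not part of the SAM API):

```python
import time
import torch

def timed_generate(mask_generator, img_np, device):
    # Wait for previously queued GPU work so it is not counted, then time
    # the generate() call; a final synchronize catches any async tail work.
    torch.cuda.synchronize(device)
    start = time.time()
    masks = mask_generator.generate(img_np)
    torch.cuda.synchronize(device)
    return masks, time.time() - start
```

With that caveat, the full script is below.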

```python
import math
import multiprocessing
import os
import time

import numpy as np
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator

multiprocessing.set_start_method('spawn', force=True)

sam_checkpoint = ""  # checkpoint path (left blank in the original post)
model_type = ""      # registry key such as "vit_h" (left blank in the original post)


def load_ImageNet(ImageNet_PATH, batch_size=64, workers=3, pin_memory=True, batch_range=None):
    traindir = os.path.join(ImageNet_PATH, 'train')
    valdir = os.path.join(ImageNet_PATH, 'val')

    normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])

    # Training dataset
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([transforms.Resize((192, 192)),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            normalizer])
    )

    # Validation dataset
    val_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose([transforms.Resize((192, 192)),
                            transforms.ToTensor(),
                            normalizer])
    )

    # Training data loader
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=pin_memory
    )

    # Validation data loader
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=pin_memory
    )

    # If batch_range is given, truncate the loaders to that range of batches
    if batch_range is not None:
        # Total number of full batches in each split
        max_train_batches = len(train_dataset) // batch_size
        max_val_batches = len(val_dataset) // batch_size
        print(len(train_loader))

        train_loader = iter(train_loader)
        val_loader = iter(val_loader)

        # Make sure the batch range does not exceed the number of batches
        start_batch, end_batch = batch_range
        start_batch = min(start_batch, max_train_batches)
        end_batch = min(end_batch, max_train_batches)

        # Take (end_batch - start_batch) batches; note that this consumes
        # batches from the beginning rather than skipping to start_batch
        train_loader = [next(train_loader) for _ in range(start_batch, end_batch)]
        val_loader = [next(val_loader) for _ in range(start_batch, end_batch)]

    return train_loader, val_loader

# Create a model and mask generator inside each process
def create_model_and_generator(device):
    # Load the model in each process to ensure independent models
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device)

    mask_generator = SamAutomaticMaskGenerator(model=sam,
                                               points_per_side=32,
                                               pred_iou_thresh=0.86,
                                               stability_score_thresh=0.92,
                                               crop_n_layers=1,
                                               crop_n_points_downscale_factor=2,
                                               min_mask_region_area=100)
    return sam, mask_generator

# Process one batch with a separate model in each process
def process_batch_parallel(samples, device, idx):
    sam, mask_generator = create_model_and_generator(device)  # Create model and generator per process
    batchsize, _, h, w = samples.shape
    batch_masks = []

    for i, img in enumerate(samples):
        start = time.time()
        img = img.to(device)  # Move the image to the correct GPU
        img_np = img.permute(1, 2, 0).cpu().numpy()  # CPU conversion for mask generation
        img_np = (img_np * 255).astype(np.uint8)

        # Generate segmentation masks
        masks = mask_generator.generate(img_np)
        combined_mask = np.zeros((h, w), dtype=np.int16)

        # # Sort masks by area in descending order
        # masks = sorted(masks, key=lambda x: x['area'], reverse=True)
        for j, mask in enumerate(masks):
            segmentation = mask['segmentation']
            combined_mask[segmentation] = j + 1

        batch_masks.append(torch.tensor(combined_mask, dtype=torch.int16).cpu())
        end = time.time()
        print(f"Processed sample {i}/{len(samples)} in batch {idx}, time {end - start}s")

    return torch.stack(batch_masks)

def process_batch_parallel2(rank, train_loader, devices, lefts, rights):
    device = devices[rank]
    left = lefts[rank]
    right = rights[rank]
    print(f"left to right is {left} to {right}, device is {device}, rank is {rank}")
    sam, mask_generator = create_model_and_generator(device)  # Create model and generator per process
    for idx, (samples, targets) in enumerate(train_loader):
        batch_masks = []
        if idx > right:
            return left, right
        if idx < left:
            continue
        start_time = time.time()
        for i, img in enumerate(samples):
            batchsize, _, h, w = samples.shape
            start = time.time()
            img_np = img.permute(1, 2, 0).cpu().numpy()
            img_np = (img_np * 255).astype(np.uint8)
            masks = mask_generator.generate(img_np)
            combined_mask = np.zeros((h, w), dtype=np.int16)
            for j, mask in enumerate(masks):
                segmentation = mask['segmentation']
                combined_mask[segmentation] = j + 1
            batch_masks.append(torch.tensor(combined_mask, dtype=torch.int16).cpu())
            end = time.time()
            print(f"Processed sample {i}/{len(samples)} in batch {idx}, time {end - start}s")

        batch = torch.stack(batch_masks)
        torch.save(batch, f"sam_output/sam_batch{idx}.pth")
        end_time = time.time()
        print(f"Processed batch {idx-left+1}/{right - left + 1}, time {end_time - start_time}s")
    return left, right

def process_in_parallel(train_loader, max_tasks=4):
    length = math.ceil(len(train_loader) / max_tasks)

    # Spawn one process per task; each process handles its own slice of batches
    devices = [f"cuda:{i+4}" for i in range(max_tasks)]
    lefts = [(length * i) for i in range(max_tasks)]
    rights = [(length * (i + 1) - 1) for i in range(max_tasks)]
    mp.spawn(process_batch_parallel2, nprocs=max_tasks, args=(train_loader, devices, lefts, rights))

if __name__ == '__main__':
    max_tasks = 3

    torch.set_num_threads(8)

    # Load dataset (the ImageNet path is left blank in the original post)
    train_loader, val_loader = load_ImageNet(
        "", 4, 128, True, batch_range=(0, 30))
    print(len(train_loader))
    print("Start processing batches")
    start = time.time()
    # Start parallel processing with a limit of max_tasks simultaneous tasks
    process_in_parallel(train_loader, max_tasks=max_tasks)
    end = time.time()
    print(f"Total processing time: {end - start}s")
```