tianrun-chen / SAM-Adapter-PyTorch

Adapting Meta AI's Segment Anything to Downstream Tasks with Adapters and Prompts
MIT License

How can I run this on a single GPU? #85

Open shidizai-swpu opened 1 month ago

shidizai-swpu commented 1 month ago

I get the error torch.distributed.elastic.multiprocessing.errors.ChildFailedError. How can I run this on a single GPU?

chooooock commented 3 weeks ago

Hi, have you solved this problem yet? I'd like to know as well.

li-pengcheng commented 3 weeks ago

Hi, have you solved this problem yet? I'd like to know as well.


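The version below is a single-GPU rewrite of train.py: the distributed launcher and DDP-specific code are dropped, the loaders are plain shuffled DataLoaders, and tensors are simply moved to one device. It keeps the repo's config keys (epoch_max, eval_type, sam_checkpoint, ...), so it is meant to be started directly with python train.py --config <your yaml> rather than through torch.distributed.launch / torchrun; the multi-process elastic launcher is what raises the ChildFailedError.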
import argparse
import os

import yaml
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

import datasets
import models
import utils
from statistics import mean
import torch
import torch.distributed as dist

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def make_data_loader(spec, tag=''):
    if spec is None:
        return None

    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})

    if __name__ == '__main__':
        log('{} dataset: size={}'.format(tag, len(dataset)))
        for k, v in dataset[0].items():
            if torch.is_tensor(v):
                log('  {}: shape={}'.format(k, tuple(v.shape)))

    loader = DataLoader(dataset, batch_size=spec['batch_size'],
                        shuffle=True, num_workers=8, pin_memory=True, drop_last=True)
    return loader

def make_data_loaders():
    train_loader = make_data_loader(config.get('train_dataset'), tag='train')
    val_loader = make_data_loader(config.get('val_dataset'), tag='val')
    return train_loader, val_loader

def eval_psnr(loader, model, eval_type=None):
    model.eval()

    if eval_type == 'f1':
        metric_fn = utils.calc_f1
        metric1, metric2, metric3, metric4 = 'f1', 'auc', 'none', 'none'
    elif eval_type == 'fmeasure':
        metric_fn = utils.calc_fmeasure
        metric1, metric2, metric3, metric4 = 'f_mea', 'mae', 'none', 'none'
    elif eval_type == 'ber':
        metric_fn = utils.calc_ber
        metric1, metric2, metric3, metric4 = 'shadow', 'non_shadow', 'ber', 'none'
    elif eval_type == 'cod':
        metric_fn = utils.calc_cod
        metric1, metric2, metric3, metric4 = 'sm', 'em', 'wfm', 'mae'
    elif eval_type == 'kvasir':
        metric_fn = utils.calc_kvasir
        metric1, metric2, metric3, metric4 = 'dice', 'iou', 'none', 'none'

    pbar = tqdm(total=len(loader), leave=False, desc='val')

    val_metric1 = 0
    val_metric2 = 0
    val_metric3 = 0
    val_metric4 = 0
    cnt = 0

    for batch in loader:
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.cuda()

        inp = batch['inp']

        pred = torch.sigmoid(model.infer(inp))

        result1, result2, result3, result4 = metric_fn(pred, batch['gt'])
        val_metric1 += (result1 * pred.shape[0])
        val_metric2 += (result2 * pred.shape[0])
        val_metric3 += (result3 * pred.shape[0])
        val_metric4 += (result4 * pred.shape[0])
        cnt += pred.shape[0]
        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.close()

    return val_metric1 / cnt, val_metric2 / cnt, val_metric3 / cnt, val_metric4 / cnt, metric1, metric2, metric3, metric4

def prepare_training():
    if config.get('resume') is not None:
        model = models.make(config['model']).cuda()
        optimizer = utils.make_optimizer(
            model.parameters(), config['optimizer'])
        epoch_start = config.get('resume') + 1
    else:
        model = models.make(config['model']).cuda()
        optimizer = utils.make_optimizer(
            model.parameters(), config['optimizer'])
        epoch_start = 1
    max_epoch = config.get('epoch_max')
    lr_scheduler = CosineAnnealingLR(optimizer, max_epoch, eta_min=config.get('lr_min'))
    log('model: #params={}'.format(utils.compute_num_params(model, text=True)))
    return model, optimizer, epoch_start, lr_scheduler

def train(train_loader, model):
    model.train()

    pbar = tqdm(total=len(train_loader), leave=False, desc='train')

    loss_list = []
    for batch in train_loader:
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.to(device)
        inp = batch['inp']
        gt = batch['gt']
        model.set_input(inp, gt)
        model.optimize_parameters()
        batch_loss = model.loss_G.item()
        loss_list.append(batch_loss)
        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.close()

    loss = mean(loss_list)
    return loss

def main(config_, save_path):
    global config, log, writer, log_info
    config = config_
    log, writer = utils.set_save_path(save_path, remove=False)
    with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
        yaml.dump(config, f, sort_keys=False)

    train_loader, val_loader = make_data_loaders()

    if config.get('data_norm') is None:
        config['data_norm'] = {
            'inp': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }

    model, optimizer, epoch_start, lr_scheduler = prepare_training()
    model.optimizer = optimizer

    if config.get('sam_checkpoint') is not None:
        sam_checkpoint = torch.load(config['sam_checkpoint'])
        model.load_state_dict(sam_checkpoint['model'], strict=False)

    for name, para in model.named_parameters():
        if "image_encoder" in name and "prompt_generator" not in name:
            para.requires_grad_(False)

    model_total_params = sum(p.numel() for p in model.parameters())
    model_grad_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('model_grad_params:', model_grad_params, '\nmodel_total_params:', model_total_params)
    log('model_grad_params: {}'.format(model_grad_params))

    epoch_max = config['epoch_max']
    epoch_val = config.get('epoch_val')
    max_val_v = -1e18 if config['eval_type'] != 'ber' else 1e8
    timer = utils.Timer()
    best_iou = 0
    for epoch in range(epoch_start, epoch_max + 1):
        t_epoch_start = timer.t()
        train_loss_G = train(train_loader, model)
        lr_scheduler.step()

        log_info = ['epoch {}/{}'.format(epoch, epoch_max)]
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        log_info.append('train G: loss={:.4f}'.format(train_loss_G))
        writer.add_scalars('loss', {'train G': train_loss_G}, epoch)

        model_spec = config['model']
        model_spec['sd'] = model.state_dict()
        optimizer_spec = config['optimizer']
        optimizer_spec['sd'] = optimizer.state_dict()

        save(config, model, save_path, 'last')

        if (epoch_val is not None) and (epoch % epoch_val == 0):
            result1, result2, result3, result4, metric1, metric2, metric3, metric4 = eval_psnr(val_loader, model,
                                                                                               eval_type=config.get(
                                                                                                   'eval_type'))
            if result2 > best_iou:
                best_iou = result2
                print("current best iou is {} in epoch {}".format(best_iou, epoch))

            log_info.append('val: {}={:.4f}'.format(metric1, result1))
            writer.add_scalars(metric1, {'val': result1}, epoch)
            log_info.append('val: {}={:.4f}'.format(metric2, result2))
            writer.add_scalars(metric2, {'val': result2}, epoch)
            log_info.append('val: {}={:.4f}'.format(metric3, result3))
            writer.add_scalars(metric3, {'val': result3}, epoch)
            log_info.append('val: {}={:.4f}'.format(metric4, result4))
            writer.add_scalars(metric4, {'val': result4}, epoch)

            if config['eval_type'] != 'ber':
                if result1 > max_val_v:
                    max_val_v = result1
                    save(config, model, save_path, 'best')
            else:
                if result2 < max_val_v:
                    max_val_v = result2
                    save(config, model, save_path, 'best')

            t = timer.t()
            prog = (epoch - epoch_start + 1) / (epoch_max - epoch_start + 1)
            t_epoch = utils.time_text(t - t_epoch_start)
            t_elapsed, t_all = utils.time_text(t), utils.time_text(t / prog)
            log_info.append('{} {}/{}'.format(t_epoch, t_elapsed, t_all))

            log(', '.join(log_info))
            writer.flush()

def save(config, model, save_path, name):
    if config['model']['name'] == 'segformer' or config['model']['name'] == 'setr':
        if config['model']['args']['encoder_mode']['name'] == 'evp':
            prompt_generator = model.encoder.backbone.prompt_generator.state_dict()
            decode_head = model.encoder.decode_head.state_dict()
            torch.save({"prompt": prompt_generator, "decode_head": decode_head},
                       os.path.join(save_path, f"prompt_epoch_{name}.pth"))
        else:
            torch.save(model.state_dict(), os.path.join(save_path, f"model_epoch_{name}.pth"))
    else:
        torch.save(model.state_dict(), os.path.join(save_path, f"model_epoch_{name}.pth"))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default="configs/train/setr/train_setr_evp_cod.yaml")
    parser.add_argument('--name', default=None)
    parser.add_argument('--tag', default=None)
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        print('config loaded.')

    save_name = args.name
    if save_name is None:
        save_name = '_' + args.config.split('/')[-1][:-len('.yaml')]
    if args.tag is not None:
        save_name += '_' + args.tag
    save_path = os.path.join('./save', save_name)

    main(config, save_path)

I made a quick modification, give it a try.
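If you would rather patch the repo's original train.py yourself instead of replacing the whole file, the change boils down to removing the distributed pieces (process-group initialization, DistributedSampler, the DistributedDataParallel wrapper) and moving data and model to a single device, which is all the rewrite above does. Here is a minimal sketch of that single-GPU pattern with a toy model and dataset (placeholders, not this repo's classes):

# Toy single-GPU training loop: no dist.init_process_group, no
# DistributedSampler, no DistributedDataParallel wrapper. Model and
# dataset below are placeholders, not the classes from this repo.
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = TensorDataset(torch.randn(16, 4), torch.randn(16, 1))
loader = DataLoader(dataset, batch_size=4, shuffle=True)   # plain shuffle, no sampler

model = torch.nn.Linear(4, 1).to(device)                   # used directly, not wrapped in DDP
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for inp, gt in loader:
    inp, gt = inp.to(device), gt.to(device)                # move each batch to the single device
    loss = torch.nn.functional.mse_loss(model(inp), gt)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Started with plain python instead of the distributed launcher, nothing here touches torch.distributed, so the elastic launcher's ChildFailedError cannot come up.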
chooooock commented 2 weeks ago


Thank you very much!