open-mmlab / mmgeneration

MMGeneration is a powerful toolkit for generative models, based on PyTorch and MMCV.
https://mmgeneration.readthedocs.io/en/latest/
Apache License 2.0

Training using Python API #337

Closed · mvidela31 closed this issue 2 years ago

mvidela31 commented 2 years ago

Hi everybody,

I need to train an image generation model from a Python script, similar to this tutorial from the MMSegmentation library. I tried to replicate that example using the code in the train.py file, but it fails when running the train_model function with the following error: 'MMDataParallel' object has no attribute 'reducer'. In issue #26 the devs recommend executing dist_train.sh instead of the train.py file, but that solution relies on a bash command, and I need to use only Python code.

Is there a way to train an MMGeneration model using just a Python script? It would also be great to have a tutorial example on Google Colaboratory like the one from the MMSegmentation library.

Many thanks.

LeoXing1996 commented 2 years ago

This URL is a Google Colab tutorial for MMGeneration; we will merge it into the repo later.

To train a model with the Python API, you can use the following code:

import os

import torch
import torch.distributed as dist
from mmcv import Config

from mmgen.core.ddp_wrapper import DistributedDataParallelWrapper
from mmgen.core.optimizer import build_optimizers
from mmgen.datasets import build_dataloader, build_dataset
from mmgen.models import build_model

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize a single-process group: we only need the DDP machinery
    # for mmgen's training step, not actual multi-GPU communication
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def main():
    config_path = ('./configs/sagan/sagan_32_wReLUinplace_'
                   'lr-2e-4_ndisc5_cifar10_b64x1.py')

    cfg = Config.fromfile(config_path)
    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    find_unused_parameters = cfg.get('find_unused_parameters', False)
    # build the optimizers (one per submodule, e.g. generator/discriminator)
    optimizer = build_optimizers(model, cfg.optimizer)

    # wrap with mmgen's DDP wrapper; plain MMDataParallel lacks the
    # `reducer` attribute that the training step relies on
    model_ddp = DistributedDataParallelWrapper(
        model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False,
        find_unused_parameters=find_unused_parameters)

    datasets = [build_dataset(cfg.data.train)]
    loader_cfg = dict(
        samples_per_gpu=cfg.data.samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        # cfg.gpus will be ignored if distributed
        num_gpus=1,
        dist=True,
        persistent_workers=cfg.data.get('persistent_workers', False),
        seed=42)
    train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
    data_loaders = [
        build_dataloader(ds, **train_loader_cfg) for ds in datasets
    ]
    train_loader = data_loaders[0]

    # run a few training steps as a smoke test
    for idx, data_batch in enumerate(train_loader):
        outputs = model_ddp.train_step(data_batch, optimizer)
        if idx == 5:
            break

if __name__ == '__main__':
    setup(0, 1)
    main()
    cleanup()
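
If you later want more than one GPU while still launching everything from Python (rather than dist_train.sh), a minimal sketch using torch.multiprocessing.spawn, reusing the setup/main/cleanup above, could look like the following. The per-rank device binding and the nccl suggestion are assumptions on top of this thread, and main() would also need per-rank seed/loader adjustments:

import torch
import torch.multiprocessing as mp

def worker(rank, world_size):
    # each spawned process joins the same process group; for real
    # multi-GPU runs, 'nccl' is usually preferred over 'gloo' in setup()
    setup(rank, world_size)
    torch.cuda.set_device(rank)  # bind this process to one GPU
    main()
    cleanup()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)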
mvidela31 commented 2 years ago

Hi @LeoXing1996, thanks for the quick answer!

I tried your solution, but it seems that the model artifacts (checkpoint and log files) are not automatically saved to the cfg.work_dir directory. I would like to train the model as follows:

import wget
import torch
import mlflow  # mlflow must be installed for MlflowLoggerHook
import os.path as osp
import mmcv
from mmcv import Config
from mmgen.apis import train_model
from mmgen.datasets import build_dataset
from mmgen.models import build_model

# Dir paths
dataset_dir = '/content/dataset'
work_dir = '/content/work_dir'
ckpt_path = '/content/checkpoint'

# Update configuration file
config_file = '/mmgen/configs/styleganv2/stylegan2_c2_lsun-car_384x512_b4x8.py'
cfg = Config.fromfile(config_file)
cfg.log_config.hooks = [dict(type='TextLoggerHook', by_epoch=False), 
                        dict(type='MlflowLoggerHook', log_model=False)]
aug_kwargs = dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, 
                  xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, 
                  saturation=1) # BGC augmentation pipeline
cfg.model.discriminator = dict(type='ADAStyleGAN2Discriminator',
                               in_size=512,
                               data_aug=dict(type='ADAAug', 
                                             aug_pipeline=aug_kwargs, 
                                             ada_kimg=100))
cfg.data.train.dataset.imgs_root = dataset_dir
cfg.data.val.imgs_root = dataset_dir
cfg.work_dir = work_dir
cfg.gpu_id = 0
cfg.gpu_ids = range(1)
cfg.seed = 2021
if torch.cuda.is_available():
    cfg.device = 'cuda:0'
else:
    cfg.device = 'cpu'

# Download a pre-trained model
mmcv.mkdir_or_exist(osp.abspath(ckpt_path))
checkpoint_url = 'https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth'
cfg.load_from = wget.download(checkpoint_url, out=ckpt_path)

# Build the model
model = build_model(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

# Build the dataset
datasets = [build_dataset(cfg.data.train)]

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))

# Train the model (this call raises: 'MMDataParallel' object has no attribute 'reducer')
train_model(model, datasets, cfg, distributed=False, validate=False, meta=dict())

Could you please show me how to get these artifacts saved to cfg.work_dir when using the Python API?

LeoXing1996 commented 2 years ago

You should add a CheckpointHook and set work_dir in cfg at the start of your code:

checkpoint_config = dict(interval=5000, by_epoch=False, max_keep_ckpts=10)
cfg.checkpoint_config = checkpoint_config

cfg.work_dir = 'my-work-space'
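
For reference, combining the two snippets, a sketch of a Python-only run that also writes artifacts could look like this. The distributed=True call after a single-process init_process_group is an assumption based on the first snippet (so that train_model wraps the model with the DDP wrapper instead of MMDataParallel), not something confirmed in this thread:

import os
import torch.distributed as dist

# single-process "distributed" init, as in the first snippet
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '12355')
dist.init_process_group('gloo', rank=0, world_size=1)

# save a checkpoint every 5000 iterations, keeping only the 10 newest;
# with by_epoch=False, mmcv's CheckpointHook names files iter_<N>.pth
cfg.checkpoint_config = dict(interval=5000, by_epoch=False, max_keep_ckpts=10)
cfg.work_dir = work_dir  # checkpoints and text logs are written here

train_model(model, datasets, cfg,
            distributed=True, validate=False, meta=dict())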
mvidela31 commented 2 years ago

Thanks for replying!