Closed mvidela31 closed 2 years ago
This URL is a Google Colab tutorial for MMGeneration, and we will merge it into our repo later.
To train a model with the Python API, you can use the following code:
import torch
from mmcv import Config
from mmgen.core.ddp_wrapper import DistributedDataParallelWrapper
from mmgen.core.optimizer import build_optimizers
from mmgen.models import build_model
from mmgen.datasets import build_dataloader, build_dataset
import os
import torch
import torch.distributed as dist
def setup(rank, world_size):
    """Join the default process group as ``rank`` out of ``world_size``.

    Points the rendezvous at ``localhost:12355`` through the ``MASTER_ADDR``
    and ``MASTER_PORT`` environment variables, then initialises
    ``torch.distributed`` with the ``gloo`` backend.

    Args:
        rank: Rank of the calling process.
        world_size: Total number of processes in the group.
    """
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='12355')
    # Bring up the process group for this rank.
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
def cleanup():
    """Destroy the ``torch.distributed`` process group."""
    dist.destroy_process_group()
def main():
    """Run a short training loop on the SAGAN CIFAR-10 config.

    Loads the config, builds the model and its optimizers, wraps the model
    with ``DistributedDataParallelWrapper`` on the current CUDA device, and
    passes at most six training batches through ``train_step``.
    """
    cfg = Config.fromfile(
        './configs/sagan/sagan_32_wReLUinplace_'
        'lr-2e-4_ndisc5_cifar10_b64x1.py')

    model = build_model(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    # for static training: optimizers are built from the unwrapped model
    optimizer = build_optimizers(model, cfg.optimizer)

    model_ddp = DistributedDataParallelWrapper(
        model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False,
        find_unused_parameters=cfg.get('find_unused_parameters', False))

    train_set = build_dataset(cfg.data.train)

    # Base loader options first; any `train_dataloader` entries in the
    # config then override them.
    train_loader_cfg = dict(
        samples_per_gpu=cfg.data.samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        # cfg.gpus will be ignored if distributed
        num_gpus=1,
        dist=True,
        persistent_workers=cfg.data.get('persistent_workers', False),
        seed=42)
    train_loader_cfg.update(cfg.data.get('train_dataloader', {}))

    train_loader = build_dataloader(train_set, **train_loader_cfg)

    # Stop once the batch with index 5 has been processed.
    for step, batch in enumerate(train_loader):
        model_ddp.train_step(batch, optimizer)
        if step >= 5:
            break
if __name__ == '__main__':
    # Single-process run: rank 0 in a world of size 1.
    setup(0, 1)
    try:
        main()
    finally:
        # Release the process group even when training raises.
        cleanup()
Hi @LeoXing1996, thanks for the quick answer!
I tried your solution, but it seems that the model artifacts (checkpoint and log files) are not automatically saved in the `cfg.work_dir` directory. I would like to train the model as follows:
import wget
import torch
import mlflow
import os.path as osp
import mmcv
from mmcv import Config
from mmgen import __version__
from mmgen.apis import train_model
from mmgen.datasets import build_dataset
from mmgen.models import build_model
# Locations used by this run
dataset_dir = '/content/dataset'
work_dir = '/content/work_dir'
ckpt_path = '/content/checkpoint'

# Load the base config, then override it for this run
config_file = '/mmgen/configs/styleganv2/stylegan2_c2_lsun-car_384x512_b4x8.py'
cfg = Config.fromfile(config_file)

# Logger hooks: plain text plus MLflow
cfg.log_config.hooks = [
    {'type': 'TextLoggerHook', 'by_epoch': False},
    {'type': 'MlflowLoggerHook', 'log_model': False},
]

# BGC augmentation pipeline: every listed augmentation is set to 1
aug_kwargs = dict.fromkeys(
    ('xflip', 'rotate90', 'xint', 'scale', 'rotate', 'aniso',
     'xfrac', 'brightness', 'contrast', 'lumaflip', 'hue',
     'saturation'),
    1)

# Replace the discriminator with the ADA variant using that pipeline
ada_aug = dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)
cfg.model.discriminator = dict(
    type='ADAStyleGAN2Discriminator',
    in_size=512,
    data_aug=ada_aug)

# Point both the training and validation data at the local dataset
cfg.data.train.dataset.imgs_root = dataset_dir
cfg.data.val.imgs_root = dataset_dir

cfg.work_dir = work_dir
cfg.gpu_id = 0
cfg.gpu_ids = range(1)
cfg.seed = 2021
cfg.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Download a pre-trained model into the checkpoint directory
mmcv.mkdir_or_exist(osp.abspath(ckpt_path))
checkpoint_url = 'https://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth'
cfg.load_from = wget.download(checkpoint_url, out=ckpt_path)

# Build the model
model = build_model(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

# Build the dataset
datasets = [build_dataset(cfg.data.train)]

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))

# Train the model.
# NOTE: this call is reported to fail with
# "'MMDataParallel' object has no attribute 'reducer'".
train_model(
    model, datasets, cfg, distributed=False, validate=False, meta=dict())
Could you please show me how to get the mentioned artifacts (saved at `cfg.work_dir`) using the Python API?
You should add a `CheckpointHook` config and a `work_dir` to `cfg` at the start of your code.
# Checkpoint settings, counted in iterations (by_epoch=False)
checkpoint_config = {
    'interval': 5000,
    'by_epoch': False,
    'max_keep_ckpts': 10,
}
cfg.checkpoint_config = checkpoint_config
# Working directory for the run
cfg.work_dir = 'my-work-space'
Thanks for replying!
Hi everybody,
I need to train an image generation model using a Python script similar to this tutorial from the MMSegmentation library. I tried to replicate that example using the code in the `train.py` file, but it fails when trying to run the `train_model` function with the following error: `'MMDataParallel' object has no attribute 'reducer'`. In issue #26, the devs recommend executing `dist_train.sh` instead of the `train.py` file, but this solution uses a `bash` command, and I need to use only Python code.
Is there a way to train an MMGeneration model using just a Python script? Besides, it would be awesome to have a tutorial example on Google Colaboratory like the one from the MMSegmentation library.
Many thanks.