victoresque / pytorch-template

PyTorch deep learning projects made easy.

Some features I have implemented. #88

Open deeperlearner opened 3 years ago

deeperlearner commented 3 years ago

I have been using this template since October 2020, and I found some features that could be added to it. I implemented them in my repo. Here are some notable features I have added:

Overview of config.json

{
    "n_gpu": 1,
    "root_dir": "./",
    "save_dir": "saved/",
    "name": "dataset_model",

    "datasets": {
        "train": {
            "data": {
                "module": ".data_loader",
                "type": "MyDataset",
                "kwargs": {
                    "data_dir": "./data",
                    "label_path": null,
                    "mode": "train"
                }
            }
        },
        "valid": {
        },
        "test": {
            "data": {
                "module": ".data_loader",
                "type": "MyDataset",
                "kwargs": {
                    "data_dir": "./data",
                    "label_path": null,
                    "mode": "test"
                }
            }
        }
    },
    "data_loaders": {
        "train": {
            "data": {
                "module": ".data_loader",
                "type": "BaseDataLoader",
                "kwargs": {
                    "validation_split": 0.2,
                    "DataLoader_kwargs": {
                        "batch_size": 64,
                        "shuffle": true,
                        "num_workers": 4
                    },
                    "do_transform": true
                }
            }
        },
        "valid": {
        },
        "test": {
            "data": {
                "module": ".data_loader",
                "type": "DataLoader",
                "kwargs": {
                    "batch_size": 64,
                    "shuffle": false,
                    "num_workers": 4
                },
                "do_transform": true
            }
        }
    },
    "models": {
        "model": {
            "module": ".model",
            "type": "MyModel"
        }
    },
    "losses": {
        "loss": {
            "type": "nll_loss"
        }
    },
    "metrics": {
        "per_iteration": ["accuracy"],
        "per_epoch": ["AUROC", "AUPRC"]
    },
    "optimizers": {
        "model": {
            "type": "Adam",
            "kwargs": {
                "lr": 0.001
            }
        }
    },
    "lr_schedulers": {
        "model": {
            "type": "StepLR",
            "kwargs": {
                "step_size": 50,
                "gamma": 0.1
            }
        }
    },
    "trainer": {
        "module": ".trainer",
        "type": "Trainer",
        "k_fold": 5,
        "fold_idx": 0,
        "kwargs": {
            "finetune": false,
            "epochs": 2,
            "len_epoch": null,

            "save_period": 5,
            "save_the_best": true,
            "verbosity": 2,

            "monitor": "max val_accuracy",
            "early_stop": 0,

            "tensorboard": false
        }
    }
}

Enable multiple instances in datasets, data_loaders, models, losses, optimizers, lr_schedulers

Multiple datasets are useful, for example, in domain adaptation training, which uses a source dataset and a target dataset; the same goes for data_loaders. Multiple models appear in GANs (generator and discriminator). Multiple losses, optimizers, and lr_schedulers can be found in many ML papers. A config sketch with two models is shown below.
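As an illustration only, a hypothetical "models" section with two entries might look like this (the key names "generator"/"discriminator" and the class names are assumptions, not part of the template):

"models": {
    "generator": {
        "module": ".model",
        "type": "Generator"
    },
    "discriminator": {
        "module": ".model",
        "type": "Discriminator"
    }
}

Each named entry is then instantiated independently, and the same pattern applies to losses, optimizers, and lr_schedulers.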

train/valid/test

If separate train/valid/test paths are already given, their contents can be put directly into the corresponding sections of datasets and data_loaders; otherwise the valid section is left empty and a validation split is taken from the training set (see validation_split above). A sketch of a filled-in valid entry follows.
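As a sketch, assuming the dataset class takes a mode argument as in the train entry above (the value "valid" is an assumption), a valid section with its own data could look like:

"valid": {
    "data": {
        "module": ".data_loader",
        "type": "MyDataset",
        "kwargs": {
            "data_dir": "./data",
            "label_path": null,
            "mode": "valid"
        }
    }
}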

module/type

When there is more than one module to load objects from, the module field specifies which file the type class is defined in, so each config entry can be instantiated dynamically; see the sketch below.
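A minimal sketch of how such a module/type pair could be resolved (the function name init_obj and the package layout are assumptions, not necessarily how the repo implements it):

import importlib

def init_obj(cfg, package="data_loader", *args, **extra_kwargs):
    # cfg is one config entry, e.g.
    # {"module": ".data_loader", "type": "MyDataset", "kwargs": {...}}
    # "package" anchors relative module names such as ".data_loader" (assumed layout).
    module = importlib.import_module(cfg["module"], package=package)
    cls = getattr(module, cfg["type"])
    kwargs = dict(cfg.get("kwargs", {}))
    kwargs.update(extra_kwargs)
    return cls(*args, **kwargs)

For example, init_obj(config["datasets"]["train"]["data"]) would build MyDataset with the kwargs taken from the config.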

AUROC/AUPRC

In the metric part, I add two commonly used metrics, AUROC and AUPRC. These two metrics need to be computed over the whole epoch, so their computation is different from accuracy, which can be averaged batch by batch. A sketch is given below.
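A minimal sketch of the difference, using scikit-learn for the epoch-level metrics (the function names and the binary-classification assumption are mine, not necessarily identical to model/metric.py):

import torch
from sklearn.metrics import roc_auc_score, average_precision_score

def accuracy(output, target):
    # per-iteration metric: can be averaged batch by batch
    pred = torch.argmax(output, dim=1)
    return torch.sum(pred == target).item() / len(target)

def AUROC(outputs, targets):
    # per-epoch metric: needs the scores of every sample in the epoch
    # assumes binary classification with two-column logits
    scores = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
    return roc_auc_score(targets.cpu().numpy(), scores)

def AUPRC(outputs, targets):
    scores = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
    return average_precision_score(targets.cpu().numpy(), scores)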

MetricTracker

Continuing from AUROC/AUPRC, I revise the MetricTracker, which is moved to model/metric.py. The MetricTracker can record both accuracy-like metrics (metric_iter) and AUROC-like metrics (metric_epoch). A rough sketch of the idea follows.
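A rough sketch of a tracker that handles both kinds of metrics (the class layout and method names are assumptions, not the actual implementation):

import torch

class MetricTracker:
    def __init__(self, metric_iter, metric_epoch):
        # metric_iter: dict name -> fn(output, target), averaged per batch
        # metric_epoch: dict name -> fn(outputs, targets), computed once per epoch
        self.metric_iter = metric_iter
        self.metric_epoch = metric_epoch
        self.reset()

    def reset(self):
        self._sums = {name: 0.0 for name in self.metric_iter}
        self._counts = {name: 0 for name in self.metric_iter}
        self._outputs, self._targets = [], []

    def iter_update(self, output, target):
        for name, fn in self.metric_iter.items():
            self._sums[name] += fn(output, target) * len(target)
            self._counts[name] += len(target)
        # buffer outputs so epoch-level metrics can be computed later
        self._outputs.append(output.detach().cpu())
        self._targets.append(target.detach().cpu())

    def epoch_result(self):
        result = {name: self._sums[name] / self._counts[name]
                  for name in self.metric_iter}
        outputs = torch.cat(self._outputs)
        targets = torch.cat(self._targets)
        for name, fn in self.metric_epoch.items():
            result[name] = fn(outputs, targets)
        return result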

Cross validation

Cross validation is supported: the class Cross_Valid in base/base_dataloader.py records the results of each fold (all metrics in the MetricTracker). The model of each fold is saved, and test.py can ensemble the k-fold results. A sketch of the fold loop is shown below.
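As a rough sketch of the k-fold idea, using sklearn.model_selection.KFold (the helper names and the averaging ensemble are assumptions, not the repo's exact code):

import numpy as np
from sklearn.model_selection import KFold

def iter_folds(dataset_size, k_fold=5, seed=0):
    # yield (fold_idx, train_idx, valid_idx) for each of the k folds
    kf = KFold(n_splits=k_fold, shuffle=True, random_state=seed)
    for fold_idx, (train_idx, valid_idx) in enumerate(kf.split(np.arange(dataset_size))):
        yield fold_idx, train_idx, valid_idx

def ensemble(fold_probs):
    # average the per-fold predicted probabilities at test time
    return np.mean(np.stack(fold_probs, axis=0), axis=0)

Each fold trains its own model checkpoint; at test time the saved models are loaded and their predictions averaged.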

Examples

I added some example code showing how to use the above features.

I would appreciate it if anyone has comments on my work.

Thanks, Pei-Ying Liu