Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Refactor Trainer.add_argparse_args #7020

Closed tayden closed 3 years ago

tayden commented 3 years ago

🚀 Feature

The add_argparse_args method in the Trainer class creates a new parser using the parent_parser that is passed as a parameter. I propose that, instead, the Trainer arguments be appended to the existing parser under a "trainer args" group. This is the same approach recommended in the hyperparameter docs about best practices when working with argparse args.

Motivation

The problem with the current approach, which creates a new parser with the existing parser as its parent, is that it doesn't work when a subparser is passed to Trainer.add_argparse_args. In fact, the Trainer args aren't properly appended at all.

For example, I would expect the following to work.

from argparse import ArgumentParser

import pytorch_lightning as pl

from models.deeplabv3 import DeepLabv3 as model

def cli_main():
    parser = ArgumentParser()
    subparsers = parser.add_subparsers()

    # Define args when module called with first param is 'train'
    parser_train = subparsers.add_parser(name='train', help="train the model")
    parser_train.add_argument('data_dir', type=str)
    parser_train.add_argument('checkpoint_dir', type=str)

    # WORKS, follows best practices in lightning docs
    parser_train = model.add_argparse_args(parser_train)
    # DOESN'T WORK
    parser_train = pl.Trainer.add_argparse_args(parser_train)

    # Store train as the function to dispatch to after parsing
    parser_train.set_defaults(func=train)

    # Define args when module called with first param is 'pred'
    parser_pred = subparsers.add_parser(name='pred', help='predict on an image')
    parser_pred.add_argument('image_path', type=str)
    parser_pred.add_argument('weights_path', type=str)

    # Store pred as the function to dispatch to after parsing
    parser_pred.set_defaults(func=pred)

    args = parser.parse_args()
    args.func(args)

def pred(args):
    # ...do something with prediction specific args
    pass

def train(args):
    # ...do something with training specific args
    pass

if __name__ == '__main__':
    cli_main()

Unfortunately, none of the Trainer args make it into parser_train.

Pitch

Update Trainer.add_argparse_args to add args under a new group instead of creating a new parser.

In pytorch-lightning/blob/master/pytorch_lightning/utilities/argparse.py

def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
    parser = ArgumentParser(
        parents=[parent_parser],
        add_help=False,
    )
    #...
    return parser

becomes

def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
    group = parent_parser.add_argument_group('Trainer')
    # ... add args to group
    return parent_parser
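To illustrate why the group-based version fixes the subparser case, here is a minimal, self-contained sketch. The two helper functions and the --max_epochs flag are simplified stand-ins for add_argparse_args, not Lightning's actual implementation:

```python
from argparse import ArgumentParser

def add_args_via_parents(parent_parser):
    # Current behaviour: build and return a brand-new parser.
    new_parser = ArgumentParser(parents=[parent_parser], add_help=False)
    new_parser.add_argument('--max_epochs', type=int, default=1)
    return new_parser

def add_args_via_group(parent_parser):
    # Proposed behaviour: mutate the passed-in parser via a group.
    group = parent_parser.add_argument_group('Trainer')
    group.add_argument('--max_epochs', type=int, default=1)
    return parent_parser

parser = ArgumentParser()
subparsers = parser.add_subparsers()

# Group-based: the object registered as the 'train' subcommand is the
# same object the argument was added to, so parsing succeeds.
parser_train = add_args_via_group(subparsers.add_parser('train'))
args = parser.parse_args(['train', '--max_epochs', '5'])

# Parents-based: 'stale' has the argument, but the top-level parser
# still dispatches to the original 'train2' object, which does not,
# so --max_epochs is left unrecognized.
stale = add_args_via_parents(subparsers.add_parser('train2'))
_, unknown = parser.parse_known_args(['train2', '--max_epochs', '5'])
```

The key point is that subparsers.add_parser registers a specific parser object with the top-level parser; returning a different object from add_argparse_args leaves that registered object untouched.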

Additional context

I realize it's possible to work around this with parser.parse_known_args(), but the approach above is much nicer and produces better CLI help messages.
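For reference, the parse_known_args workaround looks roughly like the following sketch. The trainer_parser helper and the --max_epochs flag are hypothetical stand-ins used to keep the example self-contained:

```python
from argparse import ArgumentParser

def trainer_parser():
    # Stand-in for a parser that knows the Trainer flags.
    p = ArgumentParser(add_help=False)
    p.add_argument('--max_epochs', type=int, default=1)
    return p

parser = ArgumentParser()
subparsers = parser.add_subparsers(dest='command')
parser_train = subparsers.add_parser('train')
parser_train.add_argument('data_dir', type=str)

# First pass: collect everything the subparser understands; the
# Trainer flags are left over in `extra`.
args, extra = parser.parse_known_args(['train', '/data', '--max_epochs', '3'])

# Second pass: hand the leftovers to the Trainer parser.
trainer_args = trainer_parser().parse_args(extra)
```

This works, but any typo in a Trainer flag only fails in the second pass, and the subcommand's --help output never mentions the Trainer args, which is why the group-based approach is preferable.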

I'm happy to open a PR for this if it seems acceptable.

tayden commented 3 years ago

I can see this is already implemented on the Master branch, so never mind! Thanks!