🚀 Feature
The `add_argparse_args` method in the `Trainer` class creates a new parser using the `parent_parser` that is passed as a parameter. I propose that instead, the Trainer arguments are appended to the existing parser under a "trainer args" group. This is the same approach recommended in the hyperparameter docs as best practice when working with argparse args.
Motivation
The problem with the current approach, which creates a new parser with the existing parser as its parent, is that it doesn't work well when a subparser is passed to `Trainer.add_argparse_args`. In fact, the Trainer args aren't properly appended at all.
For example, I would expect the following to work.
```python
from argparse import ArgumentParser

import pytorch_lightning as pl

from models.deeplabv3 import DeepLabv3 as model


def cli_main():
    parser = ArgumentParser()
    subparsers = parser.add_subparsers()

    # Define args for when the module is called with 'train' as the first param
    parser_train = subparsers.add_parser(name='train', help="train the model")
    parser_train.add_argument('data_dir', type=str)
    parser_train.add_argument('checkpoint_dir', type=str)
    # WORKS, follows best practices in the lightning docs
    parser_train = model.add_argparse_args(parser_train)
    # DOESN'T WORK
    parser_train = pl.Trainer.add_argparse_args(parser_train)
    # Dispatch to train(args) after parsing
    parser_train.set_defaults(func=train)

    # Define args for when the module is called with 'pred' as the first param
    parser_pred = subparsers.add_parser(name='pred', help='predict on an image')
    parser_pred.add_argument('image_path', type=str)
    parser_pred.add_argument('weights_path', type=str)
    # Dispatch to pred(args) after parsing
    parser_pred.set_defaults(func=pred)

    args = parser.parse_args()
    args.func(args)


def pred(args):
    # ...do something with prediction-specific args
    pass


def train(args):
    # ...do something with training-specific args
    pass


if __name__ == '__main__':
    cli_main()
```
Unfortunately, none of the Trainer args make it into `parser_train`.
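The failure can be reproduced with plain argparse, without Lightning at all. This is a minimal sketch (my assumption is that it mirrors what `Trainer.add_argparse_args` effectively does today, and `--max_epochs` stands in for the real Trainer args): building a new parser with the subparser as a parent returns a *different* parser object, so args added to it never reach the subparser that is registered on the original parser.

```python
from argparse import ArgumentParser


def add_args_via_parent(parent_parser):
    # A NEW parser is created; the parent's actions are copied into it,
    # but the parent itself is never mutated.
    new_parser = ArgumentParser(parents=[parent_parser], add_help=False)
    new_parser.add_argument('--max_epochs', type=int, default=1000)
    return new_parser


parser = ArgumentParser()
subparsers = parser.add_subparsers()
parser_train = subparsers.add_parser('train')

returned = add_args_via_parent(parser_train)

# The returned parser knows about --max_epochs (inspecting argparse internals)...
print('--max_epochs' in returned._option_string_actions)      # True
# ...but the subparser wired into the original parser does not.
print('--max_epochs' in parser_train._option_string_actions)  # False
```

Unless the caller swaps `parser_train` for the returned parser everywhere (which is impossible here, since `subparsers` already holds a reference to the original), the new args are silently lost.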
Pitch
Update Trainer.add_argparse_args to add args under a new group instead of creating a new parser.
In `pytorch-lightning/blob/master/pytorch_lightning/utilities/argparse.py`, `add_argparse_args` becomes:

```python
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
    group = parent_parser.add_argument_group('Trainer')
    # ... add args to group
    return parent_parser
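With this approach the passed-in parser is mutated in place, so it works on subparsers too. A minimal sketch of the proposed behavior (again, `--max_epochs` is just an illustrative stand-in for the real Trainer args):

```python
from argparse import ArgumentParser


def add_argparse_args(parent_parser):
    # Mutate the parser that was passed in: add a named group so the args
    # appear under a "Trainer" section in --help output.
    group = parent_parser.add_argument_group('Trainer')
    group.add_argument('--max_epochs', type=int, default=1000)
    return parent_parser


parser = ArgumentParser()
subparsers = parser.add_subparsers()
parser_train = subparsers.add_parser('train')
add_argparse_args(parser_train)

# The Trainer args now parse through the subparser as expected.
args = parser.parse_args(['train', '--max_epochs', '5'])
print(args.max_epochs)  # 5
```

Because `add_argument_group` only affects help formatting, this is a drop-in change for callers that use the returned parser directly.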
Additional context
I realize it's possible to work around this with `parser.parse_known_args()`, but I think the approach above is much nicer and produces better CLI help messages.
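For completeness, a sketch of that workaround (assuming the Trainer args end up on a separate parser, as with the current behavior): the subcommand parser collects what it knows and leaves the rest for a second parsing pass, at the cost of the Trainer args being invisible in the subcommand's `--help`.

```python
from argparse import ArgumentParser

# The subcommand parser only knows its own args.
parser = ArgumentParser()
parser.add_argument('data_dir', type=str)

# parse_known_args keeps unrecognized args instead of erroring out,
# so a second (Trainer) parser could consume them afterwards.
args, unknown = parser.parse_known_args(['some/dir', '--max_epochs', '5'])
print(args.data_dir)  # 'some/dir'
print(unknown)        # ['--max_epochs', '5']
```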
I'm happy to open a PR for this if it seems acceptable.