Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Set `enable_progress_bar=False` in standalone tests #11083

Closed: awaelchli closed this issue 2 years ago

awaelchli commented 2 years ago

Proposed refactor

The output of the standalone tests is unfiltered and runs to thousands of lines because the progress bar is enabled. The progress bar is not needed in 99-100% of these tests. For regular tests this is not an issue, since they are batched and their output is not printed in verbose mode.

Motivation

Test output logs would be easier to read and scroll through.

Pitch

Set `Trainer(enable_progress_bar=False)` in all tests marked with `@RunIf(standalone=True)`. We could also turn off the model summary.
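One way to avoid repeating these flags in every standalone test is a small helper that supplies them as defaults while still letting an individual test opt back in. This is a minimal sketch; `quiet_trainer_kwargs` and `QUIET_TRAINER_DEFAULTS` are hypothetical names, not part of Lightning or its test suite. Only the flag names come from this issue.

```python
from typing import Any, Dict

# Hypothetical defaults for standalone tests; the keys are the Trainer flags discussed in this issue.
QUIET_TRAINER_DEFAULTS: Dict[str, Any] = {
    "enable_progress_bar": False,
    "enable_model_summary": False,
}


def quiet_trainer_kwargs(**overrides: Any) -> Dict[str, Any]:
    """Return Trainer kwargs with the quiet defaults applied; explicit overrides win."""
    kwargs = dict(QUIET_TRAINER_DEFAULTS)  # copy so the shared defaults are never mutated
    kwargs.update(overrides)
    return kwargs


# Usage inside a standalone test (needs Lightning installed, so shown as a comment):
# trainer = Trainer(**quiet_trainer_kwargs(max_epochs=1, enable_model_summary=True))
```

A test that actually checks the progress bar or summary output passes the flag explicitly, so the defaults never hide the behavior under test.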

Additional context

Happy holidays

cc @borda @justusschock @awaelchli @akihironitta

ananthsub commented 2 years ago

I assume many tests could also set these flags for faster runs and less verbose logs: `Trainer(enable_checkpointing=False, logger=False, enable_model_summary=False)`.

carmocca commented 2 years ago

We could write an AST parser that finds `Trainer` in our tests and adds these arguments automatically; pre-commit would fix the formatting afterwards. Even a simple Python script looking for `Trainer(...)` calls would work.

Then we would remove the arguments again from any tests that fail because they need different behavior.

carmocca commented 2 years ago

Here's a working script:

from typing import Any, Dict, List, Tuple, Union

def find_trainers(code: str) -> List[Tuple[int, int]]:
    trainer = "Trainer("
    t = len(trainer)
    i = 0
    arg_start = None
    parentheses = 0
    n = len(code)
    found = []

    # scan to the very end so a closing parenthesis in the last few characters is not skipped
    while i < n:
        if code[i : i + t] == trainer:
            # Trainer found
            i += t
            arg_start = i
            parentheses += 1
        elif arg_start is not None:
            # keep a stack of parentheses used
            if code[i] == "(":
                parentheses += 1
            elif code[i] == ")":
                parentheses -= 1
                # closed Trainer instantiation
                if parentheses == 0:
                    found.append((arg_start, i))
                    arg_start = None
            i += 1
        else:
            i += 1
            parentheses = 0

    return found

def add_arguments(code: str, indices: List[Tuple[int, int]], changes: Dict[str, Any]) -> str:
    # loop in reverse because additions will mess up the indices
    for start, end in reversed(indices):
        # parsing comments would be too difficult. assume there will be a trailing comma if it has a comment
        if "#" not in code[start:end]:
            # add a trailing comma if necessary
            for i in range(end - 1, start - 1, -1):
                if not code[i].isspace():
                    if code[i] != ",":
                        code = code[:end] + "," + code[end:]
                        end += 1
                    break

        args = code[start:end]

        if not (
            "limit_train_batches" in args
            or "limit_val_batches" in args
            or "limit_test_batches" in args
            or "limit_predict_batches" in args
            or "max_epochs" in args
            or "max_steps" in args
            or "max_time" in args
        ):
            # heuristic: filter out trainers that will not run. a better solution would be to check that this trainer
            # instance will run `.fit`, `.test`, ... but that'd be much more complex
            continue

        # add the arguments
        for arg, value in changes.items():
            if arg in args:
                # already set
                continue
            if arg == "enable_checkpointing" and ("ModelCheckpoint" in args or "checkpoint" in args or "ckpt" in args):
                # heuristic: ModelCheckpoint was passed
                continue
            if arg == "enable_model_summary" and "ModelSummary" in args:
                # heuristic: ModelSummary was passed
                continue
            if arg == "enable_progress_bar" and ("ProgressBar" in args or "progress_bar" in args or "pbar" in args):
                # heuristic: a progress bar callback was passed
                continue

            code = code[:end] + f"{arg}={value}," + code[end:]

    return code

def format(code: str) -> str:
    import black

    try:
        return black.format_str(code, mode=black.Mode(line_length=120, magic_trailing_comma=False))
    except black.parsing.InvalidInput:
        print(code)
        raise

def main(code: str) -> str:
    indices = find_trainers(code)
    code = add_arguments(
        code,
        indices,
        {"logger": False, "enable_checkpointing": False, "enable_model_summary": False, "enable_progress_bar": False},
    )
    code = format(code)
    return code

def update_file(filepath: str) -> None:
    with open(filepath) as fp:
        code = fp.read()

    try:
        code = main(code)
    except Exception:
        print("Problem parsing", filepath)
        raise

    with open(filepath, "w") as fp:
        fp.write(code)

if __name__ == "__main__":
    import os

    import jsonargparse
    from jsonargparse.typing import Path_drw, Path_dw

    parser = jsonargparse.ArgumentParser()
    parser.add_argument("path", type=Union[Path_dw, Path_drw], default="tests")
    args = parser.parse_args()

    for root, dirs, files in os.walk(args.path):
        for file in files:
            if not file.endswith(".py"):
                continue
            update_file(os.path.join(root, file))

carmocca commented 2 years ago

When I compare the runtime of the CPU Conda tests for #11113 with one of master's commits, I don't see any substantial speedup. To measure the real magnitude, we would need to compare N runs of each and average them.
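A sketch of that comparison: collect the wall-clock times of N CI runs for each commit, compare the means, and use Welch's t statistic to judge whether the difference exceeds run-to-run noise. `compare_runtimes` is a hypothetical helper, and the choice of Welch's test (which does not assume equal variances) is an assumption on my part; the thread only asks for averages over N runs.

```python
import math
import statistics
from typing import Sequence, Tuple


def compare_runtimes(baseline: Sequence[float], candidate: Sequence[float]) -> Tuple[float, float]:
    """Return (mean speedup in seconds, Welch's t statistic) for two sets of CI runtimes.

    A positive speedup means the candidate runs were faster on average.
    """
    if len(baseline) < 2 or len(candidate) < 2:
        raise ValueError("need at least two runs per side to estimate variance")
    speedup = statistics.mean(baseline) - statistics.mean(candidate)
    # standard error of the difference in means, without assuming equal variances
    se = math.sqrt(
        statistics.variance(baseline) / len(baseline) + statistics.variance(candidate) / len(candidate)
    )
    if se == 0:
        # constant timings on both sides: no spread to compare against
        return speedup, 0.0 if speedup == 0 else math.copysign(math.inf, speedup)
    return speedup, speedup / se
```

As a rough rule of thumb, with only a handful of runs per side, |t| should be well above 2 before the difference is read as a real speedup rather than CI noise.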