Closed awaelchli closed 2 years ago
I'm assuming many tests could also disable these flags for faster runs and less verbose logs: Trainer(enable_checkpointing=False, logger=False, enable_model_summary=False).
We could write an AST parser that looks for Trainer in our tests and adds these arguments automatically; pre-commit will fix the formatting later. Or even a simple Python script looking for Trainer() would work.
Then we remove the arguments for the failing tests that require different behavior.
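For example, this is the kind of rewrite the idea boils down to (hypothetical test code, only to illustrate; the flags are the ones proposed above):

from pytorch_lightning import Trainer

# before: a typical short test run
trainer = Trainer(max_epochs=1, limit_train_batches=2)

# after the automated rewrite (pre-commit/black cleans up the formatting later)
trainer = Trainer(
    max_epochs=1,
    limit_train_batches=2,
    logger=False,
    enable_checkpointing=False,
    enable_model_summary=False,
    enable_progress_bar=False,
)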
Here's a working script:
from typing import Any, Dict, List, Tuple, Union


def find_trainers(code: str) -> List[Tuple[int, int]]:
    trainer = "Trainer("
    t = len(trainer)
    i = 0
    arg_start = None
    parentheses = 0
    n = len(code)
    found = []
    while i < n:
        if code[i : i + t] == trainer:
            # Trainer found
            i += t
            arg_start = i
            parentheses += 1
        elif arg_start is not None:
            # keep a stack of parentheses used
            if code[i] == "(":
                parentheses += 1
            elif code[i] == ")":
                parentheses -= 1
                # closed Trainer instantiation
                if parentheses == 0:
                    found.append((arg_start, i))
                    arg_start = None
            i += 1
        else:
            i += 1
            parentheses = 0
    return found


def add_arguments(code: str, indices: List[Tuple[int, int]], changes: Dict[str, Any]) -> str:
    # loop in reverse because additions will mess up the indices
    for start, end in reversed(indices):
        # parsing comments would be too difficult. assume there will be a trailing comma if it has a comment
        if "#" not in code[start:end]:
            # add a trailing comma if necessary
            for i in range(end - 1, start - 1, -1):
                if not code[i].isspace():
                    if code[i] != ",":
                        code = code[:end] + "," + code[end:]
                        end += 1
                    break
        args = code[start:end]
        if not (
            "limit_train_batches" in args
            or "limit_val_batches" in args
            or "limit_test_batches" in args
            or "limit_predict_batches" in args
            or "max_epochs" in args
            or "max_steps" in args
            or "max_time" in args
        ):
            # heuristic: filter out trainers that will not run. a better solution would be to check that this trainer
            # instance will run `.fit`, `.test`, ... but that'd be much more complex
            continue
        # add the arguments
        for arg, value in changes.items():
            if arg in args:
                # already set
                continue
            if arg == "enable_checkpointing" and ("ModelCheckpoint" in args or "checkpoint" in args or "ckpt" in args):
                # heuristic: ModelCheckpoint was passed
                continue
            if arg == "enable_model_summary" and "ModelSummary" in args:
                # heuristic: ModelSummary was passed
                continue
            if arg == "enable_progress_bar" and ("ProgressBar" in args or "progress_bar" in args or "pbar" in args):
                # heuristic: a progress bar callback was passed
                continue
            code = code[:end] + f"{arg}={value}," + code[end:]
    return code


def format(code: str) -> str:
    import black

    try:
        return black.format_str(code, mode=black.Mode(line_length=120, magic_trailing_comma=False))
    except black.parsing.InvalidInput:
        print(code)
        raise


def main(code: str) -> str:
    indices = find_trainers(code)
    code = add_arguments(
        code,
        indices,
        {"logger": False, "enable_checkpointing": False, "enable_model_summary": False, "enable_progress_bar": False},
    )
    code = format(code)
    return code


def update_file(filepath: str) -> None:
    with open(filepath) as fp:
        code = fp.read()
    try:
        code = main(code)
    except Exception:
        print("Problem parsing", filepath)
        raise
    with open(filepath, "w") as fp:
        fp.write(code)


if __name__ == "__main__":
    import os

    import jsonargparse
    from jsonargparse.typing import Path_drw, Path_dw

    parser = jsonargparse.ArgumentParser()
    parser.add_argument("path", type=Union[Path_dw, Path_drw], default="tests")
    args = parser.parse_args()
    for root, dirs, files in os.walk(args.path):
        for file in files:
            if not file.endswith(".py"):
                continue
            update_file(os.path.join(root, file))
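As a quick sanity check (my own toy snippet, not a file from the test suite), running main on a small piece of code appends the flags and reformats the call:

# toy input; assumes the functions above are available in the same module
snippet = "trainer = Trainer(limit_train_batches=2, max_epochs=1)\n"
print(main(snippet))
# prints the same Trainer call with logger=False, enable_checkpointing=False,
# enable_model_summary=False and enable_progress_bar=False appended, formatted by black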
Comparing the runtime of the CPU Conda tests for #11113 and one of master's commits, it doesn't seem like this provides any substantial speedup at all. We would need to compare N runs and take averages to properly see the magnitude.
Proposed refactor
The output produced by the standalone tests is unfiltered and runs to thousands of lines because the progress bar is turned on. This is not needed for 99-100% of the tests. For regular tests this is not an issue, since the tests get batched and the output is not printed in verbose mode.
Motivation
Easier to read and scroll through the test output logs.
Pitch
Set Trainer(enable_progress_bar=False) for all tests that have @RunIf(standalone=True) as the marker. Additionally, we could also turn off the model summary.
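A rough sketch of what such a test could look like afterwards (the test name, model, and import paths are placeholders, not actual files from the suite):

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # placeholder import path
from tests.helpers.runif import RunIf  # placeholder import path


@RunIf(min_gpus=2, standalone=True)
def test_example_standalone(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        gpus=2,
        strategy="ddp",
        max_epochs=1,
        limit_train_batches=2,
        enable_progress_bar=False,  # avoid thousands of progress bar lines in the standalone output
        enable_model_summary=False,  # optionally silence the model summary as well
    )
    trainer.fit(BoringModel())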
Additional context
Happy holidays
cc @borda @justusschock @awaelchli @akihironitta