pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License
1.01k stars, 124 forks

Direct calls to logging are suppressed inside benchmarks/dynamo (and possibly other situations) #1981

Closed: ezyang closed this issue 1 year ago

ezyang commented 1 year ago

🐛 Describe the bug

Steps to reproduce:

Somewhere in the dynamo codebase, add a log call like logging.debug("blah blah") and set the dynamo log level to DEBUG in torch._dynamo.config. In my case, I was testing with this:

diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 048e0d3fa3..8d38faa995 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -665,6 +670,7 @@ class OutputGraph(fx.Tracer):

         for node, arg in list(zip(self.graph.nodes, expanded_graphargs)):
             if arg.uses == 0:
+                logging.debug(f"REMOVE UNUSED GRAPHARG {arg.source.name()}")
                 if "example_value" in node.meta:
                     del node.meta["example_value"]
                 self.remove_node(node)

and running TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 python benchmarks/dynamo/torchbench.py --accuracy --backend aot_eager --training --only hf_Reformer

but I'm pretty sure this reproduces in simpler situations too.

What you will find is that the log message never shows up anywhere. This is because you aren't supposed to call logging directly; instead, you have to use the module's own logger, stored in the module-level variable log. This is very easy to get wrong. At the very least we should have a lint for this.

I'd also love to know, from someone who actually understands how Python logging works, how this is actually supposed to work.

Error logs

No response

Minified repro

No response

jansel commented 1 year ago

Could you try logging.getLogger(__name__).debug(f"REMOVE UNUSED GRAPHARG {arg.source.name()}")?

logging.debug() goes to Python's global root logger. Setting the dynamo log level intentionally doesn't touch the root logger, because that could interfere with user code that also uses logging.
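A small self-contained sketch of the behavior described above, assuming a hypothetical library named mylib that stands in for torch._dynamo (the names mylib, mylib.output_graph and configure_library_logging are illustrative, not PyTorch code). Only the library's top-level logger is set to DEBUG, so a child logger obtained with getLogger inherits that level, while logging.debug() goes to the root logger, which stays at its default WARNING level and drops the message.

```python
import io
import logging

# Hypothetical package name standing in for torch._dynamo.
LIB = "mylib"

def configure_library_logging(stream, level=logging.DEBUG):
    # Configure only the library's top-level logger; the root logger
    # (and therefore any logging setup done by user code) is left alone.
    lib_logger = logging.getLogger(LIB)
    lib_logger.setLevel(level)
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter("%(name)s: %(message)s"))
    lib_logger.addHandler(handler)
    # Keep this demo's records off any root handlers.
    lib_logger.propagate = False
    return lib_logger

def library_code():
    # A module logger such as getLogger(__name__) is a child of "mylib",
    # so it inherits the DEBUG level and reaches the "mylib" handler.
    log = logging.getLogger(f"{LIB}.output_graph")
    log.debug("visible: sent to the mylib.output_graph logger")
    # logging.debug() targets the root logger (default level WARNING),
    # so this record is discarded however "mylib" is configured.
    logging.debug("invisible: sent to the root logger")

if __name__ == "__main__":
    buf = io.StringIO()
    configure_library_logging(buf)
    library_code()
    print(buf.getvalue(), end="")
    # mylib.output_graph: visible: sent to the mylib.output_graph logger
```

The design choice mirrors the comment: the library owns the "mylib" subtree of the logger hierarchy and never changes the root logger's level, which is exactly why a stray logging.debug() in library code is silent.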

ezyang commented 1 year ago

Yeah, this is how I fixed it. Based on https://docs.python.org/3/howto/logging.html#configuring-logging-for-a-library it sounds like we need a lint rule that warns when you call the logging module's functions directly.
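One way such a lint could work is an AST walk that flags calls spelled logging.debug(...), logging.info(...) and so on. The sketch below is a minimal illustration built on the standard ast module; the function name find_root_logger_calls is hypothetical, and this is not necessarily how the eventual fix was implemented. It deliberately ignores aliases such as import logging as lg or from logging import debug.

```python
import ast

# Module-level convenience functions that log through the root logger.
ROOT_LOGGING_FUNCS = {"debug", "info", "warning", "warn", "error",
                      "exception", "critical", "fatal", "log"}

def find_root_logger_calls(source, filename="<string>"):
    """Return sorted (lineno, name) pairs for calls like logging.debug(...)."""
    tree = ast.parse(source, filename)
    problems = []
    for node in ast.walk(tree):
        # Match Call nodes whose callee is exactly `logging.<func>`.
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "logging"
            and node.func.attr in ROOT_LOGGING_FUNCS
        ):
            problems.append((node.lineno, f"logging.{node.func.attr}"))
    # ast.walk is breadth-first, so sort to report in source order.
    return sorted(problems)

SAMPLE = '''
import logging
log = logging.getLogger(__name__)
log.debug("fine: module logger")
logging.debug("flagged: root logger")
'''

if __name__ == "__main__":
    for lineno, name in find_root_logger_calls(SAMPLE):
        print(f"line {lineno}: use a module logger instead of {name}()")
    # line 5: use a module logger instead of logging.debug()
```

Note that logging.getLogger(__name__).debug(...) is not flagged, because its callee is an attribute of a call result rather than of the name logging.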

ezyang commented 1 year ago

Actually, it would also be nice to raise a runtime error when our library code uses the root logger, because a lint isn't going to help much if you insert logging.debug temporarily for debugging and then scratch your head over why it didn't show up.
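A runtime check cannot rely on a filter attached to the root logger, because logging.debug() returns before creating a record when the root level is WARNING. The sketch below instead wraps the logging module's functions and inspects the caller's module name. It is a hypothetical illustration: forbid_root_logging and the mylib prefix are invented names, sys._getframe is a CPython implementation detail, and calls bound earlier via from logging import debug bypass the guard.

```python
import functools
import logging
import sys

_ROOT_FUNCS = ("debug", "info", "warning", "error", "exception", "critical", "log")

def forbid_root_logging(package_prefix):
    """Make logging.<func>() raise when called from inside package_prefix."""
    for name in _ROOT_FUNCS:
        original = getattr(logging, name)
        if getattr(original, "_root_guard", False):
            continue  # already wrapped; keep the patch idempotent

        @functools.wraps(original)
        def guarded(*args, _original=original, _name=name, **kwargs):
            # Module name of the code that called logging.<func>() directly.
            caller = sys._getframe(1).f_globals.get("__name__", "")
            if caller == package_prefix or caller.startswith(package_prefix + "."):
                raise RuntimeError(
                    f"{caller} called logging.{_name}(); "
                    "use logging.getLogger(__name__) instead"
                )
            return _original(*args, **kwargs)

        guarded._root_guard = True
        setattr(logging, name, guarded)

if __name__ == "__main__":
    forbid_root_logging("mylib")
    # Simulate a submodule of the hypothetical package by running code
    # with __name__ set accordingly.
    try:
        exec("logging.debug('oops')",
             {"__name__": "mylib.output_graph", "logging": logging})
    except RuntimeError as err:
        print(err)
    # mylib.output_graph called logging.debug(); use logging.getLogger(__name__) instead
```

Because the check keys on the caller's module, user code outside the package can keep calling logging.debug() unchanged, which fits the goal of not touching the global logging namespace for users.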

jansel commented 1 year ago

Both ideas seem reasonable; we should not pollute the global logging namespace.

williamwen42 commented 1 year ago

Done by https://github.com/pytorch/pytorch/pull/90907