google / python-fire

Python Fire is a library for automatically generating command line interfaces (CLIs) from absolutely any Python object.

Unexpected behaviour when using wrapped / decorated functions: can't supply arguments added by the wrapper #451

Closed: jamesowers-roo closed this issue 1 year ago

jamesowers-roo commented 1 year ago

Here's a simple example: say we want a decorator that lets the caller set the logging level of a function. We want to apply this wrapper to every function within a module to remove the boilerplate.

test_script.py:

import functools
import logging

import fire

logging.basicConfig()
LOGGER = logging.getLogger(__name__)

def set_logging_level(func):
    """
    A decorator that sets the logging level for the decorated function.
    The level should be a string (e.g. 'DEBUG', 'INFO', etc.).
    """

    @functools.wraps(func)
    def wrapper(*args, logging_level="INFO", **kwargs):
        LOGGER.info(f"Setting logging level to {logging_level}")
        LOGGER.setLevel(logging_level)
        return func(*args, **kwargs)

    return wrapper

@set_logging_level
def test_function(a=1):
    """Docstring"""
    LOGGER.info("inside test_function")
    LOGGER.debug(f"{a=}")

if __name__ == "__main__":
    fire.Fire(test_function)

Called directly, test_function behaves as expected, i.e. it's correctly wrapped:

>>> test_function.__doc__
'Docstring'

>>> test_function()
INFO:__main__:Setting logging level to INFO
INFO:__main__:inside test_function

>>> test_function(logging_level="DEBUG")
INFO:__main__:Setting logging level to DEBUG
INFO:__main__:inside test_function
DEBUG:__main__:a=1

But calling it through python-fire fails:

python test_script.py --logging-level DEBUG
INFO:__main__:inside test_function
ERROR: Could not consume arg: --logging-level
Usage: test_script.py -

For detailed information on this command, run:
  test_script.py - --help

Am I misunderstanding how python-fire should behave with wrapped functions?

jamesowers-roo commented 1 year ago

I am very happy to provide a PR to fix this. If someone could help get me started and point to the right files to edit, I'd be very grateful.

jamesowers-roo commented 1 year ago

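This is not an issue with fire but with the decorator: it changes the function's call signature without advertising the new one. functools.wraps sets wrapper.__wrapped__ = func, and inspect.signature follows __wrapped__, so introspection reports the original signature (a=1). fire therefore correctly complains that logging_level isn't an argument of test_function.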
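A minimal sketch of that mechanism, using a stripped-down copy of the decorator above (the logging calls are removed, and test_function returns a so the call can be checked). inspect.signature unwraps via __wrapped__ by default; passing follow_wrapped=False shows what the wrapper itself accepts.

```python
import functools
import inspect


def set_logging_level(func):
    # Same shape as the decorator in the issue, without the logging calls.
    @functools.wraps(func)  # copies __name__/__doc__ and sets __wrapped__
    def wrapper(*args, logging_level="INFO", **kwargs):
        return func(*args, **kwargs)

    return wrapper


@set_logging_level
def test_function(a=1):
    """Docstring"""
    return a


# Default behaviour: inspect follows __wrapped__ to the original function,
# so logging_level does not appear in the reported signature.
print(inspect.signature(test_function))  # (a=1)

# Opting out of unwrapping reveals the wrapper's real parameters.
print(inspect.signature(test_function, follow_wrapped=False))
# (*args, logging_level='INFO', **kwargs)
```

Calling the wrapper directly still works, which is why the REPL examples succeed while signature-based tools only see (a=1).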

To fix this, the decorator also has to set the wrapper's __signature__ to the original signature plus the new parameter, for example:

    wrapper_signature = inspect.signature(func)
    parameters = list(wrapper_signature.parameters.values())
    parameters.append(
        inspect.Parameter(
            "logging_level",
            inspect.Parameter.KEYWORD_ONLY,
            default=DEFAULT_LOGGING_LEVEL,
        )
    )
    wrapper_signature = wrapper_signature.replace(parameters=parameters)
    wrapper.__signature__ = wrapper_signature

So a full solution that works as expected is:

import functools
import inspect
import logging

import fire

logging.basicConfig()
LOGGER = logging.getLogger(__name__)

def set_logging_level(func):
    """
    A decorator that sets the logging level for the decorated function.
    The level should be a string (e.g. 'DEBUG', 'INFO', etc.).
    """
    DEFAULT_LOGGING_LEVEL = "INFO"
    @functools.wraps(func)
    def wrapper(*args, logging_level=DEFAULT_LOGGING_LEVEL, **kwargs):
        LOGGER.info(f"Setting logging level to {logging_level}")
        LOGGER.setLevel(logging_level)
        return func(*args, **kwargs)

    wrapper_signature = inspect.signature(func)
    parameters = list(wrapper_signature.parameters.values())
    parameters.append(
        inspect.Parameter(
            "logging_level",
            inspect.Parameter.KEYWORD_ONLY,
            default=DEFAULT_LOGGING_LEVEL,
        )
    )
    wrapper_signature = wrapper_signature.replace(parameters=parameters)
    # @functools.wraps already set __wrapped__; inspect.signature stops
    # unwrapping at an object that defines __signature__, so this wins.
    wrapper.__signature__ = wrapper_signature

    return wrapper

@set_logging_level
def test_function(a=1):
    """Docstring"""
    LOGGER.info("inside test_function")
    LOGGER.debug(f"{a=}")

if __name__ == "__main__":
    fire.Fire(test_function)
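The parameters.append(...) approach assumes the decorated function has no **kwargs parameter. If it does, Signature.replace raises ValueError ("wrong parameter order"), because keyword-only parameters must come before VAR_KEYWORD. The sketch below uses a hypothetical helper, _add_keyword_only, that inserts the new parameter just before **kwargs when one is present; the helper name and the module-level DEFAULT_LOGGING_LEVEL are illustrative assumptions, not part of fire.

```python
import functools
import inspect
import logging

LOGGER = logging.getLogger(__name__)
DEFAULT_LOGGING_LEVEL = "INFO"


def _add_keyword_only(sig, name, default):
    """Hypothetical helper: return sig with a keyword-only parameter added.

    Keyword-only parameters must precede a **kwargs parameter, so the new
    one is inserted just before VAR_KEYWORD when it exists.
    """
    params = list(sig.parameters.values())
    new_param = inspect.Parameter(
        name, inspect.Parameter.KEYWORD_ONLY, default=default
    )
    if params and params[-1].kind is inspect.Parameter.VAR_KEYWORD:
        params.insert(len(params) - 1, new_param)
    else:
        params.append(new_param)
    return sig.replace(parameters=params)


def set_logging_level(func):
    @functools.wraps(func)
    def wrapper(*args, logging_level=DEFAULT_LOGGING_LEVEL, **kwargs):
        # Apply the requested level, then forward everything else.
        LOGGER.setLevel(logging_level)
        return func(*args, **kwargs)

    # Advertise the combined signature to inspect-based tools.
    wrapper.__signature__ = _add_keyword_only(
        inspect.signature(func), "logging_level", DEFAULT_LOGGING_LEVEL
    )
    return wrapper


@set_logging_level
def plain(a=1):
    return a


@set_logging_level
def with_kwargs(a=1, **options):
    return a, options


print(inspect.signature(plain))        # (a=1, *, logging_level='INFO')
print(inspect.signature(with_kwargs))  # (a=1, *, logging_level='INFO', **options)
```

Keeping the parameter keyword-only means it can never swallow a positional argument intended for the wrapped function.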