Borda opened this issue 2 years ago
That would be great! I use choices quite a lot.
Please help.
Great idea. We don't currently use type annotations in fire to impose restrictions (but we could in a future version, though no one is actively working toward it atm).
Side note: One alternative that works today is to use a decorator, roughly like this:
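import fire
from fire.core import FireError

def restrict_choices(choices):
    def decorator(f):
        # Note: Fire sees this wrapper's parameter name (x), not f's
        def new_f(x):
            # Reject values outside the allowed set before calling f
            if x not in choices:
                raise FireError(f"Invalid choice {x!r}; valid choices are: {choices}")
            return f(x)
        return new_f
    return decorator

@restrict_choices(['left', 'right'])
def main(move):
    print(f"Moving in given direction: {move}")

if __name__ == "__main__":
    fire.Fire(main)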
See also SetParseFns in https://github.com/google/python-fire/blob/master/fire/decorators.py
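Following up on the SetParseFns pointer, here is a minimal sketch of a reusable parse function that only accepts a fixed set of strings. The helper name choice is hypothetical (it is not part of Fire), and the FireError import fallback exists only so the sketch runs without python-fire installed. Raising FireError rather than ValueError should let Fire print its usual error message instead of a traceback.

```python
try:
    from fire.core import FireError  # Fire reports FireError as a usage error
except ImportError:  # fallback so this sketch runs without python-fire
    class FireError(Exception):
        pass


def choice(*options):
    """Return a parse function (hypothetical helper) accepting only `options`."""
    def parse(value):
        # Fire hands parse functions the raw command-line string.
        if value not in options:
            raise FireError(
                f"Invalid choice {value!r}; valid choices are: {', '.join(options)}")
        return value
    return parse
```

It would then be attached with from fire.decorators import SetParseFns and @SetParseFns(move=choice('left', 'right')) above def main(move):, followed by fire.Fire(main). Because the parse function is registered by name, it should apply whether move is given positionally or as --move.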
You might also find the HfArgumentParser relevant: https://github.com/huggingface/transformers/blob/514de24abfd4416aeba6a6455ad5920f57f3567d/src/transformers/hf_argparser.py#L109
Not really if you have to install full HF package for it...
The alternative below doesn't need the HF package. It is simple and readable, but it creates the Config object twice.
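import fire
from pydantic import BaseModel

class Config(BaseModel):
    ...  # field definitions go here

def main(**kwargs):
    # Build a default Config, then override its fields with the CLI kwargs
    config = Config().model_copy(update=kwargs)

if __name__ == "__main__":
    fire.Fire(main)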
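One caveat with the Pydantic version: Pydantic v2 documents that model_copy(update=...) does not validate the update values, so a field typed as Literal['left', 'right'] would not actually be enforced that way. Constructing Config(**kwargs) (or Config.model_validate(kwargs)) does validate, and it also builds the object only once. To keep the sketch dependency-free, below is a stdlib analogue using a dataclass whose Literal-typed fields are checked in __post_init__; the fields direction and speed are hypothetical.

```python
import dataclasses
from typing import Literal, get_args, get_origin, get_type_hints


@dataclasses.dataclass
class Config:
    # Hypothetical fields; `direction` is restricted to two choices.
    direction: Literal["left", "right"] = "left"
    speed: int = 1

    def __post_init__(self):
        # Check every Literal-typed field against its allowed values.
        hints = get_type_hints(type(self))
        for field in dataclasses.fields(self):
            hint = hints[field.name]
            if get_origin(hint) is Literal:
                allowed = get_args(hint)
                value = getattr(self, field.name)
                if value not in allowed:
                    raise ValueError(
                        f"Invalid choice {value!r} for {field.name}; "
                        f"valid choices are: {allowed}")


def main(**kwargs):
    # Unknown keys raise TypeError; invalid choices raise ValueError.
    config = Config(**kwargs)
    return config
```

Passed to fire.Fire(main), a call like --direction=up would then fail inside Config(...) instead of silently producing an invalid config.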
Thanks! Here is a more generalized version, written with the help of GPT. The problem is that if sig.bind raises a TypeError (for example on a missing or unexpected argument), Fire shows a raw traceback instead of its usual error message.
from typing import Union, List, Any
import inspect
from fire.core import FireError

def restrict_choices(arg_name_or_position: Union[int, str], choices: List[Any]):
    def decorator(f):
        sig = inspect.signature(f)  # Get the function signature

        def new_f(*args, **kwargs):
            # Map arguments by position and name
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()  # Handle any default arguments

            # Determine if we're restricting by name or position
            if isinstance(arg_name_or_position, str):
                # Restrict by argument name
                if arg_name_or_position in bound_args.arguments:
                    restricted_arg = bound_args.arguments[arg_name_or_position]
                    arg_identifier = f"argument '{arg_name_or_position}'"
                else:
                    raise FireError(
                        f"Argument '{arg_name_or_position}' not found")
            elif isinstance(arg_name_or_position, int):
                # Restrict by argument position
                if arg_name_or_position < len(bound_args.args):
                    restricted_arg = bound_args.args[arg_name_or_position]
                    arg_identifier = f"position {arg_name_or_position}"
                else:
                    raise FireError(
                        f"Argument position {arg_name_or_position} is out of range")
            else:
                raise FireError(
                    "Invalid argument specifier, must be a name (str) or position (int)")

            # Check if the restricted argument is in the allowed choices
            if restricted_arg not in choices:
                raise FireError(
                    f"Invalid choice '{restricted_arg}' for {arg_identifier}. "
                    f"Valid choices are: {choices}")

            # Call the original function if the check passes
            return f(*args, **kwargs)
        return new_f
    return decorator
Examples:

@restrict_choices('direction', ['left', 'right'])
def move(direction, speed):
    print(f"Moving {direction} at speed {speed}")

@restrict_choices(1, ['left', 'right'])
def move(speed, direction):
    print(f"Moving {direction} at speed {speed}")
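For the sig.bind problem, one option is to catch the TypeError and re-raise it as FireError, assuming a Fire-style usage error is the desired behavior. The sketch below is a hypothetical variant that restricts by keyword name only (restrict_choices(direction=[...])), dropping the positional form for brevity. It also adds functools.wraps, which should let Fire's help show the wrapped function's real parameter names. The FireError fallback is only there so the sketch runs without python-fire.

```python
import functools
import inspect

try:
    from fire.core import FireError
except ImportError:  # fallback so this sketch runs without python-fire
    class FireError(Exception):
        pass


def restrict_choices(**allowed):
    """Validate named arguments against allowed values (hypothetical variant)."""
    def decorator(f):
        sig = inspect.signature(f)
        unknown = set(allowed) - set(sig.parameters)
        if unknown:
            # Fail once at decoration time rather than on every call.
            raise ValueError(f"Unknown parameter(s): {sorted(unknown)}")

        @functools.wraps(f)  # keep f's name, docstring and signature
        def new_f(*args, **kwargs):
            try:
                bound = sig.bind(*args, **kwargs)
            except TypeError as e:
                # Turn missing/unexpected arguments into a Fire usage error.
                raise FireError(str(e)) from e
            bound.apply_defaults()
            for name, choices in allowed.items():
                value = bound.arguments[name]
                if value not in choices:
                    raise FireError(
                        f"Invalid choice {value!r} for argument {name!r}. "
                        f"Valid choices are: {list(choices)}")
            return f(*bound.args, **bound.kwargs)
        return new_f
    return decorator


@restrict_choices(direction=["left", "right"])
def move(direction, speed=1):
    return f"Moving {direction} at speed {speed}"
```

Because functools.wraps exposes the real signature, Fire may also catch some argument errors itself before the wrapper is even called; the try/except covers the cases that reach sig.bind.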
Hello, and thank you for this great CLI! Recently I ran into a situation where I wanted to restrict the options for a given argument, similar to what the built-in argparse does with its choices option (see docs: https://docs.python.org/3/library/argparse.html#choices). I checked the Fire docs but could not find anything similar... While looking at alternative CLI packages, I found an approach that is quite simple but still elegant and would fit the Fire style well. It leverages the Python Enum class:
For clarification, the example above is borrowed and adjusted from Typer/enum.
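A rough sketch of the Enum idea for Fire, given that Fire does not use type annotations to impose restrictions today: a hypothetical decorator enum_args reads the wrapped function's annotations and converts any Enum-annotated argument from its string value, raising FireError with the valid values when the lookup fails. The names enum_args and Direction are illustrative, and the FireError fallback only lets the sketch run without python-fire.

```python
import enum
import functools
import inspect

try:
    from fire.core import FireError
except ImportError:  # fallback so this sketch runs without python-fire
    class FireError(Exception):
        pass


class Direction(str, enum.Enum):
    LEFT = "left"
    RIGHT = "right"


def enum_args(f):
    """Convert Enum-annotated arguments from their values (hypothetical helper)."""
    sig = inspect.signature(f)

    @functools.wraps(f)
    def new_f(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in list(bound.arguments.items()):
            annotation = sig.parameters[name].annotation
            if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
                try:
                    # Enum lookup by value, e.g. Direction("left").
                    bound.arguments[name] = annotation(value)
                except ValueError:
                    valid = [m.value for m in annotation]
                    raise FireError(
                        f"Invalid choice {value!r} for {name!r}. "
                        f"Valid choices are: {valid}") from None
        return f(*bound.args, **bound.kwargs)
    return new_f


@enum_args
def main(move: Direction):
    return f"Moving in given direction: {move.value}"
```

Run through fire.Fire(main), an input of left should reach main as Direction.LEFT, while up should end in a Fire error listing the valid values. Annotations written under from __future__ import annotations are strings, so that case would need typing.get_type_hints instead of the raw annotation.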