Closed JesseFarebro closed 1 year ago
Thanks for giving the library a try!
For generating BoundArguments, I think something like this should work:
import functools
import inspect
import tyro

def func(a: int, b: int):
    print(a + b)

bind_args = functools.wraps(func)(
    lambda *args, **kwargs: inspect.signature(func).bind(*args, **kwargs)
)
args = tyro.cli(bind_args)
print(args)
Ideally we could just call tyro.cli() on signature(func).bind, but unfortunately this doesn't return a function with an inspectable signature. Calling wraps() without the redundant-looking lambda function also results in an error.
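To see why the lambda plus wraps() combination works, here is a minimal standard-library sketch with no tyro involved. It assumes a parser only needs inspect.signature() to report func's parameters. The direct keyword call stands in for the values a CLI parser would supply, and func returns its sum here only so the result is easy to check.

```python
import functools
import inspect


def func(a: int, b: int) -> int:
    return a + b


# Same wrapper as above: wraps() sets __wrapped__, which inspect.signature() follows.
bind_args = functools.wraps(func)(
    lambda *args, **kwargs: inspect.signature(func).bind(*args, **kwargs)
)

# The wrapper reports func's parameters rather than (*args, **kwargs).
wrapper_params = list(inspect.signature(bind_args).parameters)

# Stand-in for a parser calling the wrapper with parsed values.
bound = bind_args(a=1, b=2)

# The real function can be invoked later, whenever convenient.
result = func(*bound.args, **bound.kwargs)
```

The BoundArguments object keeps positional and keyword values separate, so the deferred call can unpack both.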
Sure, I'm open to adding this. To match argparse naming, maybe this could be a return_unknown_args: bool flag in tyro.cli() that results in an extra output? We may also want to consider a flag for turning off or renaming the --help flag. PRs would be appreciated here if you have time.
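For reference, this is the argparse behaviour such a flag would mirror. It is a standard-library sketch only: return_unknown_args is a proposal in this thread, not an existing tyro option, and --jax_enable_x64 is just an illustrative leftover flag.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--a", type=int)

# parse_known_args() returns the parsed namespace plus any tokens it did not recognize.
known, unknown = parser.parse_known_args(["--a", "1", "--jax_enable_x64", "true"])
```

The leftover list could then be handed to a second parser, such as absl.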
In the meantime, for JAX config stuff it seems like just adding an extra field to the callable that tyro takes and then calling the usual from jax.config import config; config.update(...) would result in cleaner error messages.
This could even be a dict:
import tyro

def train(
    # train_config: YourTrainConfigDataclassCanAlsoGoHere,
    jax_config: dict = {
        "enable_x64": False,
        "debug_nans": False,
        "disable_jit": False,
    }
):
    from jax.config import config
    for k, v in jax_config.items():
        config.update("jax_" + k, v)

tyro.cli(train)
Or, to change the flags from --jax-config.enable-x64 to --jax.enable-x64:
from typing_extensions import Annotated
import tyro

def train(
    # train_config: YourTrainConfigDataclassCanAlsoGoHere,
    jax_config: Annotated[dict, tyro.conf.arg(name="jax")] = {
        "enable_x64": False,
        "debug_nans": False,
        "disable_jit": False,
    }
):
    from jax.config import config
    for k, v in jax_config.items():
        config.update("jax_" + k, v)

tyro.cli(train)
I've thought about this, and getting #33 merged should help. The main reason why this may not be feasible is that argument parsing depends heavily on rules for converting strings from the command line to instances of annotated types. Unless we constrain the types supported by tyro and handwrite rules, there will be cases where we can't invert this conversion, and I haven't yet thought of a use-case that's compelling enough to motivate the development / maintenance effort.
Does that all make sense?
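To make the "constrain the types and handwrite rules" point concrete, here is a rough standard-library sketch. Everything in it is hypothetical: to_cli_args is not a tyro function, Config is a made-up dataclass, and the --field-name flag spelling is an assumption modelled on the hyphenated flags shown above. It only handles int, float, and str, and refuses anything else.

```python
import dataclasses
from typing import Any, List


@dataclasses.dataclass
class Config:
    learning_rate: float = 1e-3
    batch_size: int = 32
    run_name: str = "baseline"


def to_cli_args(instance: Any) -> List[str]:
    """Hypothetical helper: emit flags only for fields that differ from their defaults."""
    args: List[str] = []
    for field in dataclasses.fields(instance):
        value = getattr(instance, field.name)
        # Only types with an obvious string form are supported; others have no inverse rule here.
        if type(value) not in (int, float, str):
            raise TypeError(f"no handwritten rule for {type(value).__name__}")
        if field.default is not dataclasses.MISSING and value == field.default:
            continue
        args += ["--" + field.name.replace("_", "-"), str(value)]
    return args


sweep_args = to_cli_args(Config(batch_size=64))
```

Unions, nested structures, and custom classes are where a sketch like this stops working, since each would need its own inverse rule.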
Interesting, when I wrote my original post I didn't fully understand tyro.cli but I see it's more flexible than I once thought.
Makes sense, I can probably circle back and make a PR for this. The JAX config was a hypothetical; the reasons for needing this go beyond that, so it'd be nice to have this built in.
Yeah, that makes sense. There's probably a variation of this that's already possible, e.g., serialize the dataclass to YAML then use tyro.cli(..., default=...).
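A minimal sketch of that round trip, using JSON from the standard library in place of YAML so that it runs without extra dependencies. Config is a made-up dataclass, and the final tyro call is shown only as a comment.

```python
import dataclasses
import json


@dataclasses.dataclass
class Config:
    learning_rate: float = 1e-3
    batch_size: int = 32


original = Config(batch_size=64)

# Serialize the instance to text, e.g. to store it alongside a sweep definition.
serialized = json.dumps(dataclasses.asdict(original))

# Later, rebuild the instance and use it as the baseline for CLI parsing.
restored = Config(**json.loads(serialized))
# e.g. tyro.cli(Config, default=restored)
```

This simple version only works for flat dataclasses whose fields are JSON-compatible; nested dataclasses would need their own reconstruction step.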
Thanks for the quick response, I'll submit a PR soon for (2).
Hi! Just wanted to start off by saying that the library looks great 🎉. I'm trying to assess the feasibility of Tyro for my use-case, and there are a couple suggestions that maybe you'd consider:
1. It would be great to have more flexibility to get the parsed arguments without calling the function. I was thinking potentially having a function that returns something like inspect.BoundArguments after parsing is complete. This way we can have more flexibility on how we call the function.
   Another possibility might be an API where you don't provide a function, but just a dataclass that'll get hydrated from the CLI.
   EDIT: I see now that you can parse a dataclass using tyro.cli.
2. It would be great if there was a flag like strict that controls whether argparse attempts to parse all arguments or only known arguments. For example, you could parse known arguments with Tyro then default back to absl to parse Jax config flags. To accomplish this, you'd need some way of knowing what Tyro wasn't able to parse. This is somewhat possible already by doing:
   - _, unknown_args = tyro.extras.get_parser(f).parse_known_args()
   - Filter unknown_args from sys.argv and parse these separately.
   - Call tyro.cli(f, args=...) with the filtered args.
   It would be nice if there was an easier way to perform this. Perhaps this could be specifically added to (1) where if you specify strict=False it'll return a tuple (BoundArguments, List[str]) where the second element is the unknown arguments.
3. A crazy feature that might be useful to others is if there was a way to go from a dataclass instance (or annotation of a dataclass) to the (minimal?) CLI arguments needed to generate that dataclass, e.g., something like tyro.extras.to_cli_args. The use-case here is to keep all your configs / sweeps in Python; this would be specifically useful for sweeps. You could have a generator over (annotated?) dataclasses and then convert those to CLI arguments when launching a job.