Closed tungalbert99 closed 8 months ago
This comes with the unfortunate side effect that I can't pass "hi" or "bye" directly to the run function, hm.
Hi Albert!
Thanks for trying some bleeding-edge stuff... just to confirm, is this the behavior you're looking for?
Code: https://gist.github.com/brentyi/91ae652bd0a986ccc8b182bfffa7cae7
Yup that looks good to me!
I'm also wondering if there's a way to inject "hi" into what gets placed into run() ?
Sorry, do you mind rephrasing or giving an example? I'm not fully following what "injecting 'hi'" or "what gets placed into run()" refers to.
Yep! Essentially I need to package arguments from run_protocol into an endpoint that hits /{protocol_name}/{config}. The way that run_protocol is set up:
def run_protocol(config: DynamicUnion):
    print(config)  # this only prints out the config from the dict commands which is dict(protocol_name: config)
    # ideally I want to also access protocol_name
I can't really do a reverse search on the dict either because multiple protocol names can share the same type of config. So I need a way to propagate the protocol_name down to run_protocol
I see, so as an alternative example you might have something like:
from typing import Annotated, Union

from pydantic import BaseModel
import tyro


class Config(BaseModel):
    payload: str


@tyro.conf.configure(tyro.conf.OmitSubcommandPrefixes)
def run(
    config: Union[
        Annotated[
            Config,
            tyro.conf.subcommand(name="hi", default=Config(payload="hello world")),
        ],
        Annotated[
            Config,
            tyro.conf.subcommand(name="bye", default=Config(payload="hello world")),
        ],
    ],
) -> None:
    print(config)


tyro.cli(run)
which exposes two subcommands, hi and bye, which you want to be able to distinguish between. And an isinstance() check doesn't work because (as in this snippet) you want to use the same Config class/type for both subcommands, is that right?
If you can't make a new type for each endpoint, maybe tyro.conf.Suppress[] could be useful for sneaking in information that's not exposed to the CLI. This will have the same CLI interface as above, but add an additional .endpoint field which is visible to the code:
from typing import Annotated, Union

from pydantic import BaseModel
import tyro


class Config(BaseModel):
    payload: str
    endpoint: tyro.conf.Suppress[str]


@tyro.conf.configure(tyro.conf.OmitSubcommandPrefixes)
def run(
    config: Union[
        Annotated[
            Config,
            tyro.conf.subcommand(
                name="hi", default=Config(payload="hello world", endpoint="hi")
            ),
        ],
        Annotated[
            Config,
            tyro.conf.subcommand(
                name="bye", default=Config(payload="hello world", endpoint="bye")
            ),
        ],
    ],
) -> None:
    print(config.endpoint)


tyro.cli(run)
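Since each name appears twice in that snippet (once as name= and once as endpoint=), the defaults could instead be stamped from a single dict. Below is a minimal, stdlib-only sketch of that idea. It uses a plain dataclass as a stand-in for the pydantic model, and COMMANDS and stamp_endpoints are hypothetical names, not part of tyro.

```python
from dataclasses import dataclass, replace
from typing import Dict


@dataclass(frozen=True)
class Config:
    payload: str
    # Stand-in for the hidden field; in the tyro snippet above this
    # would be annotated as tyro.conf.Suppress[str].
    endpoint: str = ""


def stamp_endpoints(commands: Dict[str, Config]) -> Dict[str, Config]:
    """Return a copy of `commands` where each config records its own key."""
    return {name: replace(cfg, endpoint=name) for name, cfg in commands.items()}


# Hypothetical protocol table: two names that share one config type.
COMMANDS = {
    "hi": Config(payload="hello world"),
    "bye": Config(payload="hello world"),
}
STAMPED = stamp_endpoints(COMMANDS)

for name, cfg in STAMPED.items():
    print(name, cfg.endpoint)
```

Each stamped value could then be passed as default= to tyro.conf.subcommand(name=name, default=cfg), so the name only has to be written once.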
That works! I have to do the crazy dynamic union but it works for now :)
I was able to get this bit of code to work using some unholy magic of dynamic Unions, but wasn't able to get anything working with subcommand_from_dict since I did some nesting. Is there a better way to do this in Tyro?
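For readers unfamiliar with the "dynamic Union" approach mentioned here, the sketch below shows one way such a type could be assembled at runtime from a dict of defaults. It is only a guess at the shape of the original code: build_union and make_marker are hypothetical names, and a plain tuple stands in for the tyro.conf.subcommand(...) marker so the example runs with the standard library alone.

```python
from dataclasses import dataclass
from typing import Annotated, Any, Callable, Dict, Union, get_args


@dataclass(frozen=True)
class Config:
    payload: str
    endpoint: str = ""


def build_union(
    defaults: Dict[str, Any], make_marker: Callable[[str, Any], Any]
) -> Any:
    """Build Union[Annotated[type(cfg), marker], ...], one member per name.

    Assumes `defaults` is non-empty; Union of zero types raises TypeError.
    """
    members = tuple(
        Annotated[type(cfg), make_marker(name, cfg)]
        for name, cfg in defaults.items()
    )
    return Union[members]


defaults = {
    "hi": Config(payload="hello world", endpoint="hi"),
    "bye": Config(payload="hello world", endpoint="bye"),
}

# With tyro, make_marker would presumably be something like
#   lambda name, cfg: tyro.conf.subcommand(name=name, default=cfg)
# Here a tuple stands in for that marker.
DynamicUnion = build_union(defaults, lambda name, cfg: ("subcommand", name, cfg))

# Recover the subcommand names from the Annotated metadata.
names = [get_args(member)[1][1] for member in get_args(DynamicUnion)]
print(names)
```

Both members wrap the same Config type, but they stay distinct in the Union because their Annotated metadata differs.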