brentyi / tyro

CLI interfaces & config objects, from types
https://brentyi.github.io/tyro
MIT License

question: Clarity on dynamic subcommands #112

Closed: tungalbert99 closed this issue 8 months ago

tungalbert99 commented 8 months ago

I was able to get this code working with some unholy magic involving dynamic Unions, but I couldn't get anything working with subcommand_from_dict because my subcommands are nested. Is there a better way to do this in Tyro?

```python
from typing import Annotated, Any, Union

import tyro
from pydantic import BaseModel

class Hi(BaseModel):
    test_hi: str = "hi"

class Bye(BaseModel):
    test_bye: str = "bye"

commands = {
    "hi": Hi,
    "bye": Bye,
    # Add more commands here...
}

# Create a list of Annotated types, one per subcommand
annotated_types = [
    Annotated[cls, tyro.conf.subcommand(name=name, constructor=cls)]
    for name, cls in commands.items()
]

# Create a dynamic Union type. Union[tuple(...)] is equivalent to
# Union[*annotated_types], which is only valid syntax on Python 3.11+.
DynamicUnion = Union[tuple(annotated_types)]  # type: ignore

def run(choice: DynamicUnion):
    return choice

def abort():
    print("Abort")

def clear():
    print("Clear")

def status():
    print("Status")

def main(
    protocol: Union[
        Annotated[
            Any,
            tyro.conf.subcommand(name="run", constructor=run),
            tyro.conf.OmitArgPrefixes,
            tyro.conf.OmitSubcommandPrefixes,
        ],
        Annotated[
            Any,
            tyro.conf.subcommand(name="abort", constructor=abort),
        ],
        Annotated[
            Any,
            tyro.conf.subcommand(name="clear", constructor=clear),
        ],
        Annotated[
            Any,
            tyro.conf.subcommand(name="status", constructor=status),
        ],
    ]
):
    print(protocol)

if __name__ == "__main__":
    tyro.cli(main)
```

tungalbert99 commented 8 months ago

This comes with the unfortunate side effect that I can't pass "hi" or "bye" directly to the run function, hm.

brentyi commented 8 months ago

Hi Albert!

Thanks for trying some bleeding-edge stuff! Just to confirm, is this the behavior you're looking for?


Code: https://gist.github.com/brentyi/91ae652bd0a986ccc8b182bfffa7cae7

tungalbert99 commented 8 months ago

Yup, that looks good to me!

Is there also a way to inject "hi" into what gets passed to run()?

brentyi commented 8 months ago

Sorry, do you mind rephrasing or giving an example? I'm not fully following what "injecting 'hi'" or "what gets placed into run()" refers to.

tungalbert99 commented 8 months ago

Yep! Essentially, I need to package the arguments from run_protocol into a request that hits /{protocol_name}/{config}. Here's how run_protocol is set up:

```python
def run_protocol(config: DynamicUnion):  # DynamicUnion as built in the first snippet
    # This only prints the config value from the `commands` dict,
    # which maps protocol_name -> config.
    print(config)
    # Ideally I'd also be able to access protocol_name here.
```

I can't do a reverse lookup on the dict either, because multiple protocol names can share the same config type. So I need a way to propagate protocol_name down to run_protocol.
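
The reverse-lookup problem can be reproduced in a few lines of plain Python. This is a minimal sketch assuming a hypothetical `commands` table in which two protocol names map to the same config class; a stdlib dataclass stands in for the pydantic model. Inverting the dict by type keeps only the last name for each type, so the type of the parsed config cannot tell you which protocol was chosen.

```python
from dataclasses import dataclass


@dataclass
class Config:
    payload: str


# Hypothetical protocol table: two protocol names share one config type.
commands = {
    "hi": Config,
    "bye": Config,
}

# Inverting by type: a later name overwrites an earlier one with the same key.
by_type = {cls: name for name, cls in commands.items()}

# Suppose "hi" was chosen; the parsed value is still just a Config instance.
config = Config(payload="hello world")
print(len(by_type))           # -> 1
print(by_type[type(config)])  # -> bye (the "hi" entry was lost)
```

The same limitation applies to an `isinstance()` check: both protocols produce a `Config`, so the name has to travel with the value itself.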

brentyi commented 8 months ago

I see, so as an alternative example you might have something like:

```python
from typing import Annotated, Union

from pydantic import BaseModel

import tyro

class Config(BaseModel):
    payload: str

@tyro.conf.configure(tyro.conf.OmitSubcommandPrefixes)
def run(
    config: Union[
        Annotated[
            Config,
            tyro.conf.subcommand(name="hi", default=Config(payload="hello world")),
        ],
        Annotated[
            Config,
            tyro.conf.subcommand(name="bye", default=Config(payload="hello world")),
        ],
    ],
) -> None:
    print(config)

tyro.cli(run)
```

This exposes two subcommands, hi and bye, which you want to be able to tell apart. An isinstance() check doesn't work because (as in this snippet) you want to use the same Config class for both subcommands. Is that right?

If you can't make a new type for each endpoint, tyro.conf.Suppress[] might be useful for sneaking in information that isn't exposed to the CLI. This has the same CLI interface as above, but adds an .endpoint field that is visible to the code:

```python
from typing import Annotated, Union

from pydantic import BaseModel

import tyro

class Config(BaseModel):
    payload: str
    endpoint: tyro.conf.Suppress[str]

@tyro.conf.configure(tyro.conf.OmitSubcommandPrefixes)
def run(
    config: Union[
        Annotated[
            Config,
            tyro.conf.subcommand(
                name="hi", default=Config(payload="hello world", endpoint="hi")
            ),
        ],
        Annotated[
            Config,
            tyro.conf.subcommand(
                name="bye", default=Config(payload="hello world", endpoint="bye")
            ),
        ],
    ],
) -> None:
    print(config.endpoint)

tyro.cli(run)
```

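
To avoid writing one `Annotated[...]` member per protocol by hand, the Union in this pattern could also be generated from a dict, injecting each name into the hidden field. The sketch below uses only the standard library so it runs without tyro or pydantic. A frozen dataclass stands in for the pydantic `Config`, and `SubcommandMarker` and `make_protocol_union` are hypothetical names. In tyro code, the marker slot would hold `tyro.conf.subcommand(name=name, default=Config(payload=..., endpoint=name))` as in the snippet above, and `endpoint` would be declared as `tyro.conf.Suppress[str]`.

```python
from dataclasses import dataclass
from typing import Annotated, Any, Union, get_args


@dataclass(frozen=True)
class Config:
    payload: str
    # Stand-in for `endpoint: tyro.conf.Suppress[str]` in the pydantic model.
    endpoint: str


@dataclass(frozen=True)
class SubcommandMarker:
    """Hypothetical stand-in for tyro.conf.subcommand(name=..., default=...)."""

    name: str
    default: Config


def make_protocol_union(payloads: dict[str, str]) -> Any:
    """Build Union[Annotated[Config, marker], ...] with one member per name."""
    members = tuple(
        Annotated[
            Config,
            SubcommandMarker(name=name, default=Config(payload=payload, endpoint=name)),
        ]
        for name, payload in payloads.items()
    )
    # Union[(A, B)] is equivalent to Union[A, B] and works before Python 3.11.
    return Union[members]


# Hypothetical protocol table: both protocols share the same Config type.
ProtocolUnion = make_protocol_union({"hi": "hello world", "bye": "hello world"})

for member in get_args(ProtocolUnion):
    base, marker = get_args(member)
    print(marker.name, base.__name__, marker.default.endpoint)
```

Both members share the base type `Config` but carry different metadata, so the Union keeps them as two distinct members instead of collapsing them. The markers are frozen (hashable) because `typing.Union` deduplicates its members by hash on some Python versions.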
tungalbert99 commented 8 months ago

That works! I still have to do the crazy dynamic union, but it works for now :)