brentyi / tyro

CLI interfaces & config objects, from types
https://brentyi.github.io/tyro
MIT License

`choices_factory` for PrimitiveConstructorSpec #308

Open mirceamironenco opened 1 month ago

mirceamironenco commented 1 month ago

There's already a way to do this, so this is just a feature request. I'm curious if you think it's worth adding/makes sense for the constructor spec API. Consider the following setting:

from typing import Annotated

import torch.nn as nn
import tyro

model_registry: dict[str, nn.Module] = {}

def register_model(name: str, model: nn.Module) -> None:
    if name not in model_registry:
        model_registry[name] = model

def registered_models() -> list[str]:
    return list(model_registry.keys())

def is_registered(name: str) -> bool:
    return name in model_registry

class Model1(nn.Module): ...

register_model("model1", Model1())

# Some other file

ModelName = Annotated[
    str,
    tyro.constructors.PrimitiveConstructorSpec(
        nargs=1,
        metavar="{" + ",".join(registered_models()[:3]) + ",...}",
        instance_from_str=lambda args: args[0],
        is_instance=lambda instance: isinstance(instance, str)
        and is_registered(instance),
        str_from_instance=lambda instance: [instance],
        choices=tuple(registered_models()),
    ),
    tyro.conf.arg(
        help_behavior_hint=lambda df: f"(default: {df}, run entry.py model_registry)"
    ),
]

# User defines a new model after ModelName type has been defined
class Model2(nn.Module): ...

register_model("model2", Model2())

# model2 will not show up as a choice
def main(model: ModelName) -> None:
    print(model)

if __name__ == "__main__":
    tyro.cli(main)

If we have a registry system that builds the set of choices, and we also want to let users add to those choices, PrimitiveConstructorSpec has a limitation: choices= is evaluated once, when the type is defined. In the example above, model2 is therefore never offered as a choice.
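
The limitation described here comes down to ordinary Python evaluation order rather than anything specific to tyro: `tuple(registered_models())` runs once, at the moment the `Annotated[...]` expression is evaluated. A minimal stdlib-only sketch of the difference between an eager snapshot and a deferred lookup follows. As an assumption to keep it runnable without torch, the registry stores placeholder strings instead of `nn.Module` instances.

```python
# Stand-in registry: values are placeholder strings, not real nn.Module objects.
model_registry: dict[str, str] = {}


def register_model(name: str, model: str) -> None:
    if name not in model_registry:
        model_registry[name] = model


def registered_models() -> list[str]:
    return list(model_registry.keys())


register_model("model1", "Model1 placeholder")

# Evaluated immediately: the same thing `choices=tuple(registered_models())`
# does when the type alias is defined.
eager_choices = tuple(registered_models())


# Evaluated on demand: a callable defers the registry lookup until it is called.
def lazy_choices() -> tuple[str, ...]:
    return tuple(registered_models())


register_model("model2", "Model2 placeholder")

print(eager_choices)   # ('model1',)
print(lazy_choices())  # ('model1', 'model2')
```

Any approach that wraps the spec in a callable, such as a `constructor_factory`, relies on the second pattern.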

I can already accomplish this with a constructor_factory:

from typing import Annotated

import torch.nn as nn
import tyro

model_registry: dict[str, nn.Module] = {}

def register_model(name: str, model: nn.Module) -> None:
    if name not in model_registry:
        model_registry[name] = model

def registered_models() -> list[str]:
    return list(model_registry.keys())

def is_registered(name: str) -> bool:
    return name in model_registry

class Model1(nn.Module): ...

register_model("model1", Model1())

# Some other file

def build_registry_literal() -> type[str]:
    return tyro.extras.literal_type_from_choices(registered_models())

ModelName = Annotated[str, tyro.conf.arg(constructor_factory=build_registry_literal)]

# User defines a new model
class Model2(nn.Module): ...

register_model("model2", Model2())

def main(model: ModelName) -> None:
    print(model)

if __name__ == "__main__":
    tyro.cli(main)

I'm thinking there is room for a choices_factory: Callable[..., tuple[str, ...]] | None = None option which would make the spec compatible with this use-case. Is this compatible with the purpose of the API or is the static nature intentional?
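
To make the proposal concrete, here is a rough sketch of how a spec might accept either a static `choices` tuple or a `choices_factory`. `SketchSpec` and `resolve_choices` are hypothetical names invented for this illustration. Neither is part of tyro, and nothing here reflects how `PrimitiveConstructorSpec` is actually implemented.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class SketchSpec:
    """Hypothetical stand-in for a constructor spec; not part of tyro."""

    choices: Optional[tuple[str, ...]] = None
    choices_factory: Optional[Callable[[], tuple[str, ...]]] = None


def resolve_choices(spec: SketchSpec) -> Optional[tuple[str, ...]]:
    """Return the choices to use at parse time, preferring the factory."""
    if spec.choices_factory is not None:
        return spec.choices_factory()
    return spec.choices


registry: list[str] = ["model1"]
spec = SketchSpec(choices_factory=lambda: tuple(registry))

registry.append("model2")  # registered after the spec was created
print(resolve_choices(spec))  # ('model1', 'model2')
```

The same deferral would be needed for any other field derived from the registry, such as `metavar`.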

brentyi commented 1 month ago

Interesting! If we added choices_factory: would you also need a metavar_factory then?

It's getting a bit crazy, but I'm wondering if this makes sense to you:

ModelName = Annotated[
    str,
    tyro.conf.arg(
        help_behavior_hint=lambda df: f"(default: {df}, run entry.py model_registry)",
        constructor_factory=lambda: Annotated[  # type: ignore
            str,
            tyro.constructors.PrimitiveConstructorSpec(
                nargs=1,
                metavar="{" + ",".join(registered_models()[:3]) + ",...}",
                instance_from_str=lambda args: args[0],
                is_instance=lambda instance: isinstance(instance, str)
                and is_registered(instance),
                str_from_instance=lambda instance: [instance],
                choices=tuple(registered_models()),
            ),
        ],
    ),
]

mirceamironenco commented 1 month ago

> Interesting! If we added choices_factory: would you also need a metavar_factory then?

I see your point, it's a bit much to have 2 of each.

> It's getting a bit crazy, but I'm wondering if this makes sense to you:

I'm fine with this! I'm usually "over-using" constructor_factory anyway since it's so versatile. Truthfully my concern was mostly related to PrimitiveSpec replacing more and more of the functionality of tyro.conf.arg, and I work with a lot of "instantiate-at-the-last-moment types" as in the examples.

brentyi commented 1 month ago

Makes sense! As another option, we could conceivably add something general purpose:

ModelName = Annotated[
    str,
    tyro.LazyAnnotation(
        lambda: tyro.constructors.PrimitiveConstructorSpec(
            nargs=1,
            metavar="{" + ",".join(registered_models()[:3]) + ",...}",
            instance_from_str=lambda args: args[0],
            is_instance=lambda instance: isinstance(instance, str)
            and is_registered(instance),
            str_from_instance=lambda instance: [instance],
            choices=tuple(registered_models()),
        )
    ),
    tyro.conf.arg(
        help_behavior_hint=lambda df: f"(default: {df}, run entry.py model_registry)"
    ),
]

But I'm not sure if there'd be enough projects that benefit to justify.
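
For readers wondering what such a general-purpose wrapper could involve, the sketch below shows one possible resolution step using only the standard library. `LazyMarker` and `resolve_lazy` are hypothetical names for this illustration. They are not the proposed `tyro.LazyAnnotation` and make no claim about how tyro would implement it. The idea is simply to walk the `Annotated` metadata and replace each lazy wrapper with the result of calling its factory.

```python
from typing import Annotated, Any, Callable, get_args, get_origin


class LazyMarker:
    """Hypothetical wrapper that defers building an annotation until needed."""

    def __init__(self, factory: Callable[[], Any]) -> None:
        self.factory = factory


def resolve_lazy(tp: Any) -> Any:
    """Rebuild an Annotated type, calling the factory of every LazyMarker."""
    if get_origin(tp) is not Annotated:
        return tp
    base, *metadata = get_args(tp)
    resolved = [m.factory() if isinstance(m, LazyMarker) else m for m in metadata]
    # Subscripting with a tuple is equivalent to Annotated[base, m1, m2, ...].
    return Annotated[(base, *resolved)]


registry: list[str] = ["model1"]

# The tuple ("choices", ...) is a placeholder for a real spec object.
ModelName = Annotated[str, LazyMarker(lambda: ("choices", tuple(registry)))]

registry.append("model2")  # registered after ModelName was defined

resolved_type = resolve_lazy(ModelName)
print(get_args(resolved_type))
```

Because the factory runs inside `resolve_lazy`, the metadata reflects the registry as it is at resolution time, not at definition time.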

brentyi commented 1 month ago

@claude can you try implementing tyro.LazyAnnotation and put it in a draft PR? I likely won't merge but am curious what it'd look like.

claude[bot] commented 1 month ago

Claude finished @brentyi's task · claude/issue-308-20250529_214014


Implementation Complete ✅

I've successfully implemented tyro.LazyAnnotation as requested! This addresses the use case in issue #308 where you need dynamic choices that can change after type definition.

Usage Example:

from typing_extensions import Annotated
import tyro

# Dynamic registry that can change at runtime
def get_models():
    return ["model1", "model2", "model3"]  # Could be dynamic

ModelName = Annotated[
    str,
    tyro.LazyAnnotation(
        lambda: tyro.constructors.PrimitiveConstructorSpec(
            nargs=1,
            metavar="{" + ",".join(get_models()[:3]) + ",...}",
            instance_from_str=lambda args: args[0],
            is_instance=lambda instance: isinstance(instance, str),
            str_from_instance=lambda instance: [instance],
            choices=tuple(get_models()),  # Evaluated at parse time!
        )
    ),
]

def main(model: ModelName) -> str:
    return model

if __name__ == "__main__":
    tyro.cli(main)

mirceamironenco commented 1 month ago

Ran into a somewhat related issue, not sure if it's a bug but the behavior surprised me:

from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias

import torch
import tyro

def make_dtype(dtype: Literal["float32", "bfloat16", "float64"]) -> torch.dtype:
    torch_dtype = getattr(torch, dtype, None)

    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype

    raise ValueError(f"Expected valid torch.dtype, got {dtype} instead.")

DataType: TypeAlias = torch.dtype

if not TYPE_CHECKING:
    DataType = Annotated[
        torch.dtype,
        tyro.constructors.PrimitiveConstructorSpec(
            nargs=1,
            metavar="{bfloat16,float32,float64}",
            instance_from_str=lambda args: make_dtype(args[0]), 
            is_instance=lambda instance: isinstance(instance, torch.dtype),
            str_from_instance=lambda instance: [str(instance).split(".")[-1]],
            choices=("bfloat16", "float32", "float64"),
        ),
    ]

def main(dtype: DataType = torch.float16) -> None:
    print(dtype)

if __name__ == "__main__":
    tyro.cli(main)

Notice that the default for dtype in main is torch.float16, which is not one of the allowed choices; nevertheless, the script runs fine. With --help, the default is also shown as float16:

usage: x.py [-h] [--dtype {bfloat16,float32,float64}]

╭─ options ─────────────────────────────────────────╮
│ -h, --help        show this help message and exit │
│ --dtype {bfloat16,float32,float64}                │
│                   (default: float16)              │
╰───────────────────────────────────────────────────╯

However, if I manually specify the option py x.py --dtype=float16 we error out:

╭─ Parsing error ────────────────────────────────────────────────────────────────────────────╮
│ Argument --dtype: invalid choice: 'float16' (choose from 'bfloat16', 'float32', 'float64') │
│ ────────────────────────────────────────────────────────────────────────────────────────── │
│ For full helptext, run x.py --help                                                         │
╰────────────────────────────────────────────────────────────────────────────────────────────╯

Since I'm providing str_from_instance, tyro has everything it needs to check whether the default is valid, e.g. via str_from_instance(instance) in choices. That is what I expected, given that choices is restricted to tuple[str, ...] rather than tuple[type, ...]. If the current behavior is intended, I would also suggest changing choices: tuple[str, ...] to choices: set[str] to get an O(1) check, since I'm not sure why anyone would need to list the same choice twice.

make_dtype also doesn't seem to be called so I can't place a check there either.
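
The check described above can be sketched in a few lines. This is only an illustration of the expected behavior, not tyro code: `FakeDType` is a stand-in for `torch.dtype` (an assumption, so the sketch runs without torch), and `default_matches_choices` is a hypothetical helper name.

```python
import enum
from typing import Callable, Sequence


class FakeDType(enum.Enum):
    """Stand-in for torch.dtype (assumption: keeps the sketch torch-free)."""

    float16 = enum.auto()
    bfloat16 = enum.auto()
    float32 = enum.auto()
    float64 = enum.auto()


def dtype_to_strings(dtype: FakeDType) -> list[str]:
    """Plays the role of str_from_instance: one string token per argument."""
    return [dtype.name]


def default_matches_choices(
    default: FakeDType,
    str_from_instance: Callable[[FakeDType], Sequence[str]],
    choices: Sequence[str],
) -> bool:
    """Serialize the default and check every resulting token against choices."""
    return all(token in choices for token in str_from_instance(default))


choices = ("bfloat16", "float32", "float64")

print(default_matches_choices(FakeDType.float16, dtype_to_strings, choices))  # False
print(default_matches_choices(FakeDType.float32, dtype_to_strings, choices))  # True
```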

brentyi commented 1 month ago

To confirm I understand the problem, the problem is that tyro:

As a result:

Is that right?

If so: I can see why that's surprising. We could add the check, although for reasons related to Hyrum's law I want to be cautious about breaking people's existing code. I'm aware of in-the-wild cases where default values unfortunately don't match choices[^1]. For better or worse, our current behavior is also consistent with argparse (where default= doesn't necessarily need to be compatible with choices=).

[^1]: Concretely: "gemma_300m_lora" here is not valid given this Literal["dummy", "gemma_300m", "gemma_2b", "gemma_2b_lora"] annotation.
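
The argparse behavior referred to above is easy to reproduce with the standard library alone. The sketch reuses the option name and choices from the dtype example. The default is returned without being checked against `choices`, while the same value passed on the command line is rejected.

```python
import argparse

parser = argparse.ArgumentParser(prog="x.py")
parser.add_argument(
    "--dtype",
    choices=("bfloat16", "float32", "float64"),
    default="float16",  # deliberately not one of the choices
)

# No arguments: the default is used as-is, with no choices validation.
args = parser.parse_args([])
print(args.dtype)  # float16

# Same value passed explicitly: argparse reports an invalid choice and exits.
try:
    parser.parse_args(["--dtype", "float16"])
    rejected = False
except SystemExit:
    rejected = True
print(rejected)  # True
```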

tuple[str, ...] vs set[str] suggestion: thanks!! I often reach for tuple[T, ...] instead of set[T] because sets are (i) unordered, which can be unideal aesthetically for help/error messages, and (ii) mutable, which can be inconvenient for static typing. If we want to fix the runtime here, however, we could: (i) swap the annotation to collections.abc.Set, where tuple[] would still be a valid input (I think) or (ii) internally convert to set().
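
A small sketch of option (ii), under the assumption that the public field stays an ordered tuple and a `frozenset` is derived internally for membership tests. The variable names are illustrative only. One caveat for option (i): at runtime a tuple is not an instance of `collections.abc.Set`, which the first line of output demonstrates.

```python
import collections.abc

choices: tuple[str, ...] = ("bfloat16", "float32", "float64")

# A tuple does not satisfy the Set ABC at runtime.
print(isinstance(choices, collections.abc.Set))  # False

# Keep the tuple for stable, ordered help/error text...
metavar = "{" + ",".join(choices) + "}"

# ...and derive an immutable set once for O(1) membership checks.
choice_lookup = frozenset(choices)

print(metavar)                     # {bfloat16,float32,float64}
print("float32" in choice_lookup)  # True
print("float16" in choice_lookup)  # False
```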

mirceamironenco commented 6 hours ago

> Interesting! If we added choices_factory: would you also need a metavar_factory then?
>
> It's getting a bit crazy, but I'm wondering if this makes sense to you:
>
> ModelName = Annotated[
>     str,
>     tyro.conf.arg(
>         help_behavior_hint=lambda df: f"(default: {df}, run entry.py model_registry)",
>         constructor_factory=lambda: Annotated[  # type: ignore
>             str,
>             tyro.constructors.PrimitiveConstructorSpec(
>                 nargs=1,
>                 metavar="{" + ",".join(registered_models()[:3]) + ",...}",
>                 instance_from_str=lambda args: args[0],
>                 is_instance=lambda instance: isinstance(instance, str)
>                 and is_registered(instance),
>                 str_from_instance=lambda instance: [instance],
>                 choices=tuple(registered_models()),
>             ),
>         ],
>     ),
> ]

I have a use case going in the 'other direction', where it seems I'm forced to use a PrimitiveConstructorSpec but would prefer not to. Consider a type of this form:

from dataclasses import dataclass

@dataclass(kw_only=True, frozen=True)
class WandbConfig:
    project: str = "proj"

    run_name: str = "train"

    run_id: str | None = None

    group: str | None = None

    job_type: str | None = None

Currently to make this available to the CLI I use a subcommand:


from dataclasses import dataclass
from typing import Annotated

import tyro

@dataclass(kw_only=True, frozen=True)
class WandbConfig:
    project: str = "proj"

    run_name: str = "train"

    run_id: str | None = None

    group: str | None = None

    job_type: str | None = None

CLIWandbConfig = Annotated[WandbConfig, tyro.conf.subcommand(name="on")]

Using CLIWandbConfig works perfectly fine, but I'd like to use the (new) tyro.constructors.ConstructorRegistry so that users don't have to import two objects (WandbConfig and CLIWandbConfig) depending on the usage. Ideally, when WandbConfig is passed through my run_cli function, it is treated exactly as if CLIWandbConfig had been passed. run_cli can be thought of as something like:

def run_cli(foo: type) -> type: 
    # if type is WandbConfig, tyro treats it as CLIWandbConfig
    return tyro.cli(foo, registry=global_scope_registry)

I assumed this is the main use case for the registry. However, from the examples at https://brentyi.github.io/tyro/examples/custom_constructors/ it seems the return type of registry-decorated functions is either tyro.constructors.PrimitiveConstructorSpec | None or tyro.constructors.StructConstructorSpec | None.

Is there a way to implement something along the lines of:

# Create a custom registry, which stores constructor rules.

custom_registry = tyro.constructors.ConstructorRegistry()

# Define a rule that applies to `WandbConfig`.

@custom_registry.primitive_rule
def _(
    type_info: tyro.constructors.PrimitiveTypeInfo,
) -> TypeForm | None:
    # We return `None` if the rule does not apply.
    if type_info.type != WandbConfig:
        return None

    return Annotated[WandbConfig, tyro.conf.subcommand(name="on")]

Neither PrimitiveConstructorSpec nor StructConstructorSpec seem like the right choice here.

I have slightly more complicated types I'd like to use like this (they will always be of the form Annotated[BaseType, tyro.conf.subcommand(...)] or Annotated[BaseType, tyro.conf.arg(...)]). An example which I hope is somewhat self-evident:


def _build_optimizer_union_type() -> type[OptimizerConfig]:
    optimizers = [
        Annotated[
            SGDOptimizerConfig, tyro.conf.subcommand(name=SGDOptimizerConfig._cli_name)
        ],
        Annotated[
            AdamWOptimizerConfig,
            tyro.conf.subcommand(name=AdamWOptimizerConfig._cli_name),
        ],
        Annotated[
            AdamOptimizerConfig,
            tyro.conf.subcommand(name=AdamOptimizerConfig._cli_name),
        ],
        Annotated[
            AdamWFP8OptimizerConfig,
            tyro.conf.subcommand(name=AdamWFP8OptimizerConfig._cli_name),
        ],
    ]
    return Union[*optimizers]  # type: ignore

CLIOptimizerConfig = Annotated[
    OptimizerConfig,
    tyro.conf.arg(constructor_factory=_build_optimizer_union_type),
]

You can think of OptimizerConfig as e.g. being abstract, and the CLI registry being configured to use it as a Union[] of subclasses. So in this case we would register for OptimizerConfig to be 'redirected' to Annotated[OptimizerConfig, tyro.conf.arg(constructor_factory=_build_optimizer_union_type)]. Is this possible?
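
The 'redirect' being asked for can be pictured as a plain lookup from a base type to the annotated type the CLI should use. The sketch below is stdlib-only and says nothing about whether tyro's registry supports this. `SubcommandMarker`, `cli_redirects`, and `resolve_cli_type` are hypothetical names, with `SubcommandMarker` standing in for whatever `tyro.conf.subcommand(...)` returns.

```python
from dataclasses import dataclass
from typing import Annotated, Any, get_args


@dataclass(frozen=True)
class SubcommandMarker:
    """Hypothetical stand-in for the object returned by tyro.conf.subcommand()."""

    name: str


@dataclass(frozen=True)
class WandbConfig:
    project: str = "proj"
    run_name: str = "train"


# Hypothetical redirect table: base type -> annotated type the CLI should see.
cli_redirects: dict[Any, Any] = {
    WandbConfig: Annotated[WandbConfig, SubcommandMarker(name="on")],
}


def resolve_cli_type(tp: Any) -> Any:
    """Return the redirected type if one is registered, else the type itself."""
    return cli_redirects.get(tp, tp)


redirected = resolve_cli_type(WandbConfig)
print(get_args(redirected))   # (WandbConfig, SubcommandMarker(name='on'))
print(resolve_cli_type(int))  # <class 'int'>
```

In this picture, a run_cli helper would pass `resolve_cli_type(foo)` on to `tyro.cli`, so callers only ever import the base type.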

brentyi commented 5 hours ago

Would tyro.conf.configure do what you're looking for?

Here's a relevant example from the tests:

https://github.com/brentyi/tyro/blob/19c7af40ff978cbb444eff3abb3154bbf7a7dfd0/tests/test_conf.py#L843-L847

And the usage that it produces:

https://github.com/brentyi/tyro/blob/19c7af40ff978cbb444eff3abb3154bbf7a7dfd0/tests/test_conf.py#L878-L888