Open mirceamironenco opened 1 month ago
Interesting! If we added `choices_factory`: would you also need a `metavar_factory` then?
It's getting a bit crazy, but I'm wondering if this makes sense to you:
ModelName = Annotated[
    str,
    tyro.conf.arg(
        help_behavior_hint=lambda df: f"(default: {df}, run entry.py model_registry)",
        constructor_factory=lambda: Annotated[  # type: ignore
            str,
            tyro.constructors.PrimitiveConstructorSpec(
                nargs=1,
                metavar="{" + ",".join(registered_models()[:3]) + ",...}",
                instance_from_str=lambda args: args[0],
                is_instance=lambda instance: isinstance(instance, str)
                and is_registered(instance),
                str_from_instance=lambda instance: [instance],
                choices=tuple(registered_models()),
            ),
        ],
    ),
]
> Interesting! If we added `choices_factory`: would you also need a `metavar_factory` then?
I see your point, it's a bit much to have 2 of each.
> It's getting a bit crazy, but I'm wondering if this makes sense to you:
I'm fine with this! I'm usually "over-using" `constructor_factory` anyway since it's so versatile. Truthfully, my concern was mostly related to `PrimitiveSpec` replacing more and more of the functionality of `tyro.conf.arg`, and I work with a lot of "instantiate-at-the-last-moment types" as in the examples.
Makes sense! As another option, we could conceivably add something general purpose:
ModelName = Annotated[
    str,
    tyro.LazyAnnotation(
        lambda: tyro.constructors.PrimitiveConstructorSpec(
            nargs=1,
            metavar="{" + ",".join(registered_models()[:3]) + ",...}",
            instance_from_str=lambda args: args[0],
            is_instance=lambda instance: isinstance(instance, str)
            and is_registered(instance),
            str_from_instance=lambda instance: [instance],
            choices=tuple(registered_models()),
        )
    ),
    tyro.conf.arg(
        help_behavior_hint=lambda df: f"(default: {df}, run entry.py model_registry)"
    ),
]
But I'm not sure if there'd be enough projects that benefit to justify it.
@claude can you try implementing `tyro.LazyAnnotation` and put it in a draft PR? I likely won't merge but am curious what it'd look like.
Claude finished @brentyi's task —— View job • claude/issue-308-20250529_214014
I've successfully implemented `tyro.LazyAnnotation` as requested! This addresses the use case in issue #308 where you need dynamic choices that can change after type definition.
Available as `tyro.LazyAnnotation` and `tyro.constructors.LazyAnnotation`:
from typing_extensions import Annotated

import tyro

# Dynamic registry that can change at runtime
def get_models():
    return ["model1", "model2", "model3"]  # Could be dynamic

ModelName = Annotated[
    str,
    tyro.LazyAnnotation(
        lambda: tyro.constructors.PrimitiveConstructorSpec(
            nargs=1,
            metavar="{" + ",".join(get_models()[:3]) + ",...}",
            instance_from_str=lambda args: args[0],
            is_instance=lambda instance: isinstance(instance, str),
            str_from_instance=lambda instance: [instance],
            choices=tuple(get_models()),  # Evaluated at parse time!
        )
    ),
]

def main(model: ModelName) -> str:
    return model

if __name__ == "__main__":
    tyro.cli(main)
- `src/tyro/constructors/_primitive_spec.py`: Added LazyAnnotation class and integration
- `src/tyro/constructors/__init__.py`: Added export
- `src/tyro/__init__.py`: Added to main tyro namespace
- `tests/test_lazy_annotation.py`: Comprehensive test suite
- `test_lazy_annotation_example.py`: Working example
Ran into a somewhat related issue, not sure if it's a bug but the behavior surprised me:
from typing import TYPE_CHECKING, Annotated, Literal, TypeAlias

import torch
import tyro

def make_dtype(dtype: Literal["float32", "bfloat16", "float64"]) -> torch.dtype:
    torch_dtype = getattr(torch, dtype, None)
    if isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    raise ValueError(f"Expected valid torch.dtype, got {dtype} instead.")

DataType: TypeAlias = torch.dtype

if not TYPE_CHECKING:
    DataType = Annotated[
        torch.dtype,
        tyro.constructors.PrimitiveConstructorSpec(
            nargs=1,
            metavar="{bfloat16,float32,float64}",
            instance_from_str=lambda args: make_dtype(args[0]),
            is_instance=lambda instance: isinstance(instance, torch.dtype),
            str_from_instance=lambda instance: [str(instance).split(".")[-1]],
            choices=("bfloat16", "float32", "float64"),
        ),
    ]

def main(dtype: DataType = torch.float16) -> None:
    print(dtype)

if __name__ == "__main__":
    tyro.cli(main)
Notice that the default for `dtype` in `main` is `torch.float16`, which is an invalid choice; however, if I run this script it works fine. If I run `--help`, the default also shows `float16`:
usage: x.py [-h] [--dtype {bfloat16,float32,float64}]
╭─ options ─────────────────────────────────────────╮
│ -h, --help show this help message and exit │
│ --dtype {bfloat16,float32,float64} │
│ (default: float16) │
╰───────────────────────────────────────────────────╯
However, if I manually specify the option (`py x.py --dtype=float16`), we error out:
╭─ Parsing error ────────────────────────────────────────────────────────────────────────────╮
│ Argument --dtype: invalid choice: 'float16' (choose from 'bfloat16', 'float32', 'float64') │
│ ────────────────────────────────────────────────────────────────────────────────────────── │
│ For full helptext, run x.py --help │
╰────────────────────────────────────────────────────────────────────────────────────────────╯
Since I'm providing `str_from_instance`, tyro has all it needs to check whether the default is valid via e.g. `str_from_instance(instance) in choices`, which is what I was expecting given that `choices` is restricted to `tuple[str, ...]` rather than `tuple[type, ...]`. If this was intended behavior, I would also maybe suggest changing `choices: tuple[str, ...]` to `choices: set[str]` to have an O(1) check, since I'm not sure why anyone would need to list the same choice twice?
`make_dtype` also doesn't seem to be called, so I can't place a check there either.
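The check described above could look roughly like the following. This is only a sketch of the expectation, not tyro's behavior: `SpecSketch` and `default_is_valid` are hypothetical names, and plain strings such as `"torch.float16"` stand in for `torch.dtype` objects so the snippet runs without torch installed.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical, simplified stand-in for a primitive constructor spec."""

    str_from_instance: Callable[[Any], list[str]]
    choices: tuple[str, ...]


def default_is_valid(spec: SpecSketch, default: Any) -> bool:
    """Check that every token the default serializes to is an allowed choice."""
    return all(token in spec.choices for token in spec.str_from_instance(default))


spec = SpecSketch(
    # Same serialization as in the report: keep the part after the last dot.
    str_from_instance=lambda instance: [str(instance).split(".")[-1]],
    choices=("bfloat16", "float32", "float64"),
)

print(default_is_valid(spec, "torch.float32"))  # Serializes to "float32".
print(default_is_valid(spec, "torch.float16"))  # Serializes to "float16".
```

Since `str_from_instance` returns a list of tokens, the sketch checks each token rather than testing the list itself for membership.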
To confirm I understand the problem: `tyro` does not use `str_from_instance` to validate default values set from Python. As a result:
- `choices=` is enforced when invalid values are passed from the command line (`--dtype=float16` raises an error)
- `choices=` is not enforced when an invalid value is specified as a default (`dtype=torch.float16` is OK)
Is that right?
If so: I can see why that's surprising. We could add the check, although for reasons related to Hyrum's law I want to be cautious about breaking people's existing code. I'm aware of in-the-wild cases where default values unfortunately don't match `choices`[^1]. For better or worse, our current behavior is also consistent with argparse (where `default=` doesn't necessarily need to be compatible with `choices=`).
[^1]: Concretely: "gemma_300m_lora" here is not valid given this `Literal["dummy", "gemma_300m", "gemma_2b", "gemma_2b_lora"]` annotation.
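The argparse behavior mentioned above can be reproduced with the standard library alone. The snippet mirrors the `--dtype` example from this thread; the parser setup is illustrative and does not involve tyro.

```python
import argparse

parser = argparse.ArgumentParser(prog="x.py")
parser.add_argument(
    "--dtype",
    choices=("bfloat16", "float32", "float64"),
    default="float16",  # Deliberately not one of the choices.
)

# No flag given: the default is returned without being checked against choices.
args = parser.parse_args([])
print(args.dtype)

# The same value passed explicitly is rejected: argparse prints a usage error
# to stderr and raises SystemExit.
try:
    parser.parse_args(["--dtype", "float16"])
    rejected = False
except SystemExit:
    rejected = True
print(rejected)
```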
`tuple[str, ...]` vs `set[str]` suggestion: thanks!! I often reach for `tuple[T, ...]` instead of `set[T]` because sets are (i) unordered, which can be aesthetically unideal for help/error messages, and (ii) mutable, which can be inconvenient for static typing. If we want to fix the runtime here, however, we could: (i) swap the annotation to `collections.abc.Set`, where `tuple[]` would still be a valid input (I think), or (ii) internally convert to `set()`.
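A small sketch of the two options, using only the standard library. One caveat worth checking for option (i): at runtime a tuple is not an instance of `collections.abc.Set`, so that annotation would likely exclude existing tuple inputs. The `check` helper and `choice_lookup` name are hypothetical and only illustrate option (ii).

```python
from collections.abc import Set

choices: tuple[str, ...] = ("bfloat16", "float32", "float64")

# Option (i): a tuple does not satisfy the Set ABC, but a frozenset does.
print(isinstance(choices, Set))  # False
print(isinstance(frozenset(choices), Set))  # True

# Option (ii): keep the ordered tuple for display, and build a frozenset once
# for O(1) average-case membership checks.
choice_lookup = frozenset(choices)


def check(value: str) -> str:
    """Return value if it is an allowed choice, else raise ValueError."""
    if value not in choice_lookup:
        # The tuple preserves the order in which the choices were declared.
        raise ValueError(
            f"invalid choice: {value!r} (choose from {', '.join(choices)})"
        )
    return value


print(check("float32"))
```

Keeping both containers means help and error messages stay in declaration order while lookups avoid a linear scan.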
> Interesting! If we added `choices_factory`: would you also need a `metavar_factory` then? It's getting a bit crazy, but I'm wondering if this makes sense to you:
> ModelName = Annotated[
>     str,
>     tyro.conf.arg(
>         help_behavior_hint=lambda df: f"(default: {df}, run entry.py model_registry)",
>         constructor_factory=lambda: Annotated[  # type: ignore
>             str,
>             tyro.constructors.PrimitiveConstructorSpec(
>                 nargs=1,
>                 metavar="{" + ",".join(registered_models()[:3]) + ",...}",
>                 instance_from_str=lambda args: args[0],
>                 is_instance=lambda instance: isinstance(instance, str)
>                 and is_registered(instance),
>                 str_from_instance=lambda instance: [instance],
>                 choices=tuple(registered_models()),
>             ),
>         ],
>     ),
> ]
I have a use case trying to go in the 'other direction', where it seems I'm forced to use a `PrimitiveConstructorSpec` but would prefer not to. Consider a type of this form:
@dataclass(kw_only=True, frozen=True)
class WandbConfig:
    project: str = "proj"
    run_name: str = "train"
    run_id: str | None = None
    group: str | None = None
    job_type: str | None = None
Currently to make this available to the CLI I use a subcommand:
@dataclass(kw_only=True, frozen=True)
class WandbConfig:
    project: str = "proj"
    run_name: str = "train"
    run_id: str | None = None
    group: str | None = None
    job_type: str | None = None

CLIWandbConfig = Annotated[WandbConfig, tyro.conf.subcommand(name="on")]
Now using `CLIWandbConfig` works perfectly fine, but I'd like to use the (new) `tyro.constructors.ConstructorRegistry` so the user doesn't have to potentially import 2 objects (`WandbConfig` and `CLIWandbConfig`) depending on the usage. I'd prefer that if `WandbConfig` is passed through my `run_cli` function, it just runs as if they had passed in `CLIWandbConfig`, so `run_cli` can be thought of as something like:
def run_cli(foo: type) -> type:
    # if type is WandbConfig, tyro treats it as CLIWandbConfig
    return tyro.cli(foo, registry=global_scope_registry)
I assumed this is the main use case for the registry. However, from the examples at https://brentyi.github.io/tyro/examples/custom_constructors/ it seems the return type for the registry-decorated functions is either `tyro.constructors.PrimitiveConstructorSpec | None` or `tyro.constructors.StructConstructorSpec | None`.
Is there a way to implement something along the lines of:
# Create a custom registry, which stores constructor rules.
custom_registry = tyro.constructors.ConstructorRegistry()

# Define a rule that applies only to `WandbConfig`.
@custom_registry.primitive_rule
def _(
    type_info: tyro.constructors.PrimitiveTypeInfo,
) -> TypeForm | None:
    # We return `None` if the rule does not apply.
    if type_info.type != WandbConfig:
        return None
    return Annotated[WandbConfig, tyro.conf.subcommand(name="on")]
Neither `PrimitiveConstructorSpec` nor `StructConstructorSpec` seems like the right choice here.
I have slightly more complicated types I'd like to use like this (that will always be of the form `Annotated[BaseType, tyro.conf.subcommand(...)]` or `Annotated[BaseType, tyro.conf.arg(...)]`). An example which I hope is somewhat self-evident:
def _build_optimizer_union_type() -> type[OptimizerConfig]:
    optimizers = [
        Annotated[
            SGDOptimizerConfig, tyro.conf.subcommand(name=SGDOptimizerConfig._cli_name)
        ],
        Annotated[
            AdamWOptimizerConfig,
            tyro.conf.subcommand(name=AdamWOptimizerConfig._cli_name),
        ],
        Annotated[
            AdamOptimizerConfig,
            tyro.conf.subcommand(name=AdamOptimizerConfig._cli_name),
        ],
        Annotated[
            AdamWFP8OptimizerConfig,
            tyro.conf.subcommand(name=AdamWFP8OptimizerConfig._cli_name),
        ],
    ]
    return Union[*optimizers]  # type: ignore

CLIOptimizerConfig = Annotated[
    OptimizerConfig,
    tyro.conf.arg(constructor_factory=_build_optimizer_union_type),
]
You can think of `OptimizerConfig` as e.g. being abstract, and the CLI registry being configured to use it as a `Union[]` of subclasses. So in this case we would register for `OptimizerConfig` to be 'redirected' to `Annotated[OptimizerConfig, tyro.conf.arg(constructor_factory=_build_optimizer_union_type)]`. Is this possible?
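To illustrate the 'redirect' being asked for, independent of whether tyro's registry supports it, here is a standard-library sketch. The `REDIRECTS` table, `resolve`, `union_of_subclasses`, and the two config subclasses are hypothetical names; none of them are tyro APIs.

```python
from dataclasses import dataclass
from typing import Any, Callable, Union, get_args


@dataclass(frozen=True)
class OptimizerConfig:
    lr: float = 1e-3


@dataclass(frozen=True)
class SGDConfig(OptimizerConfig):
    momentum: float = 0.9


@dataclass(frozen=True)
class AdamWConfig(OptimizerConfig):
    weight_decay: float = 0.01


def union_of_subclasses(base: type) -> Any:
    """Build Union[Sub1, Sub2, ...] from the direct subclasses of base."""
    subclasses = tuple(base.__subclasses__())
    # Subscripting Union with a tuple is equivalent to Union[*types], without
    # needing the Python 3.11+ star-unpacking syntax inside subscripts.
    return Union[subclasses]


# Hypothetical redirect table: base type -> factory for the type to parse as.
REDIRECTS: dict[type, Callable[[], Any]] = {
    OptimizerConfig: lambda: union_of_subclasses(OptimizerConfig),
}


def resolve(tp: type) -> Any:
    """Return the redirected type if one is registered, else the type itself."""
    factory = REDIRECTS.get(tp)
    return factory() if factory is not None else tp


print(get_args(resolve(OptimizerConfig)))
print(resolve(int))
```

Because the union is built inside a factory, subclasses defined after `REDIRECTS` is populated would still be included when `resolve` is called.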
Would `tyro.conf.configure` do what you're looking for?
Here's a relevant example from the tests:
And the usage that it produces:
There's already a way to do this, so this is just a feature request. I'm curious if you think it's worth adding/makes sense for the constructor spec API. Consider the following setting:
If we have some registry system which constructs a set of choices, and would like to also allow the user to add to the existing choices, the `PrimitiveConstructorSpec` has a limitation where `choices=` has already been defined, so in the above example `model2` will not be a possible choice.
I can already accomplish this with a `constructor_factory`:
I'm thinking there is room for a `choices_factory: Callable[..., tuple[str, ...]] | None = None` option which would make the spec compatible with this use case. Is this compatible with the purpose of the API or is the static nature intentional?
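The limitation comes down to when the choices tuple is built. The following standard-library sketch contrasts an eager snapshot with a deferred callable; `_MODELS` and `register` are hypothetical stand-ins for the registry described above, and `choices_factory` here is just a local variable, not an existing tyro parameter.

```python
from typing import Callable

# Hypothetical registry with one model registered up front.
_MODELS: list[str] = ["model1"]


def register(name: str) -> None:
    """Add a model name to the hypothetical registry."""
    _MODELS.append(name)


# Eager: the tuple is a snapshot taken when this line runs.
static_choices: tuple[str, ...] = tuple(_MODELS)

# Deferred: the registry is read each time the callable is invoked.
choices_factory: Callable[[], tuple[str, ...]] = lambda: tuple(_MODELS)

register("model2")  # A user adds a model after the spec was defined.

print("model2" in static_choices)
print("model2" in choices_factory())
```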