brentyi / tyro

CLI interfaces & config objects, from types
https://brentyi.github.io/tyro
MIT License

List of dataclasses #151

Closed: samuelstevens closed this issue 1 month ago

samuelstevens commented 1 month ago

I have a config like this:

import dataclasses

import tyro

@dataclasses.dataclass
class Nested:
    a: int = 1

@dataclasses.dataclass
class Args:
    nested: list[Nested]

if __name__ == "__main__":
    args = tyro.cli(Args)
    print(args)

I would like to be able to specify a variable list of nested configs.

Something like --nested.0.a 0 --nested.1.a 1

But this script raises:

tyro._instantiators.UnsupportedTypeAnnotationError: For variable-length sequences over nested types, we need a default value to infer length from.

So I change it to nested: list[Nested] = dataclasses.field(default_factory=list).

But then --nested is a fixed argument that cannot be parsed.

What is the best way to pass a list of nested objects as command line arguments?

brentyi commented 1 month ago

Hi @samuelstevens, I appreciate the succinct example!

For the built-in behavior of sequences like list[Nested], we only support overriding values in the field default; for example:

@dataclasses.dataclass
class Args:
    nested: list[Nested] = dataclasses.field(
        default_factory=lambda: [Nested(0), Nested(1), Nested(2)]
    )

would produce a CLI with three arguments: nested.py [-h] [--nested.0.a INT] [--nested.1.a INT] [--nested.2.a INT].

Unfortunately, the CLI generation rules for variable-length sequences of complex objects aren't implemented yet. I'd like to add them, but haven't found time to work through the many edge cases, so for now we need a default value to infer a static length from. I can make that error message clearer.
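
To make the "infer a static length from the default" idea concrete, here is a rough sketch using plain argparse instead of tyro. It is not how tyro is implemented. `parse_nested` is a hypothetical helper that only mirrors the visible behavior: one `--nested.<i>.a` flag is generated per element of the default list, so an index past the end of the default is an unknown flag.

```python
import argparse
import dataclasses


@dataclasses.dataclass
class Nested:
    a: int = 1


def parse_nested(defaults: list[Nested], argv: list[str]) -> list[Nested]:
    """Hypothetical helper (not part of tyro): build one flag per default element."""
    parser = argparse.ArgumentParser(prog="nested.py")
    # The number of flags is fixed by len(defaults); nothing can extend the list.
    for i, item in enumerate(defaults):
        parser.add_argument(
            f"--nested.{i}.a", type=int, default=item.a, metavar="INT"
        )
    # argparse keeps the dots in the destination name, so read values via vars().
    parsed = vars(parser.parse_args(argv))
    return [Nested(a=parsed[f"nested.{i}.a"]) for i in range(len(defaults))]


if __name__ == "__main__":
    defaults = [Nested(0), Nested(1), Nested(2)]
    # Override only the middle element; the others keep their defaults.
    print(parse_nested(defaults, ["--nested.1.a", "5"]))
```

With an empty default list, this sketch generates no flags at all, which matches the observation above that `default_factory=list` leaves nothing to override.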

In the meantime, depending on your use case you might consider specifying a custom constructor:

from __future__ import annotations

import dataclasses
from typing import Annotated

import tyro
from tyro.conf import arg

@dataclasses.dataclass
class Nested:
    a: int = 1

    @staticmethod
    def list_constructor(a: list[int]) -> list[Nested]:
        return [Nested(i) for i in a]

@dataclasses.dataclass
class Args:
    nested: Annotated[list[Nested], arg(constructor=Nested.list_constructor)]

if __name__ == "__main__":
    args = tyro.cli(Args)
    print(args)
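
As a quick sanity check on the constructor itself, independent of any CLI parsing, it can be called directly. The snippet below is only a sketch of that: it repeats the `Nested` class from the example so that it runs on its own, and swaps the `__future__` import for a string annotation. How tyro names the resulting flag is an assumption here (presumably derived from the constructor's parameter `a`), so it is worth confirming with `--help` rather than relying on it.

```python
import dataclasses


@dataclasses.dataclass
class Nested:
    a: int = 1

    @staticmethod
    def list_constructor(a: list[int]) -> "list[Nested]":
        # One Nested per integer; the list length is decided at call time,
        # not by a default value.
        return [Nested(i) for i in a]


# Direct call, bypassing tyro entirely. The flag that tyro derives from the
# parameter `a` is assumed, not verified, in this sketch.
built = Nested.list_constructor([0, 1, 7])
print(built)  # [Nested(a=0), Nested(a=1), Nested(a=7)]
```

Because the length now comes from however many integers are passed in, this sidesteps the need for a default to infer a static length from. The cost is that each element is built from a single integer rather than from per-element flags.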

samuelstevens commented 1 month ago

Great, that's what I ended up doing. Thanks!