omni-us / jsonargparse

Implement minimal boilerplate CLIs derived from type hints and parse from command line, config files and environment variables
https://jsonargparse.readthedocs.io
MIT License

Incorrect instantiation order with nested dependency injection #632

Open jkrude opened 11 hours ago

jkrude commented 11 hours ago

🐛 Bug report

First of all, this is a really cool project, and I am enjoying it a lot 👍. This might be quite niche, to be honest, but I encountered unexpected behavior when working with nested classes. In my use case, the user can choose from multiple loss classes (Strategy in the example), which themselves have another nested dependency injection (NestedStrategy). Furthermore, some arguments of the loss are linked from another object (StrategyParameter), and the nested strategy also receives arguments provided from somewhere else (NestedParameter).

In this setting, I found that the instantiation order induced by the topological sort of the dependency graph does not enforce that NestedParameter is instantiated before Strategy. However, when the Strategy object is instantiated, all link actions that target the nested strategy are applied as well. In particular, the link

```python
parser.link_arguments(
    "nested_param",
    "root.strategy.init_args.nested.init_args.nested_param",
    apply_on="instantiate",
    compute_fn=compute_fn,
)
```

is triggered before NestedParameter has been instantiated, so compute_fn is called with a Namespace object instead of a NestedParameter instance.
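The failure mode can be illustrated without jsonargparse in a deliberately simplified model, assuming the mechanism is as described above: top-level keys are instantiated one at a time, and just before a key is built, every link whose target lies inside that key's subtree fires and reads its source in whatever state it currently is. This is an inferred model, not jsonargparse's actual code; `simulate`, `LINKS` and `builders` are hypothetical names, and a plain dict stands in for the un-instantiated Namespace.

```python
from dataclasses import dataclass


@dataclass
class NestedParameter:
    something_else: str


# Link source -> link target as dotted config keys, taken from the issue.
LINKS = {"nested_param": "root.strategy.init_args.nested.init_args.nested_param"}


def simulate(order):
    """Hypothetical model of instantiation-time links (not jsonargparse code).

    Keys in `order` are instantiated one at a time. Just before a key is
    built, every link whose target lies inside that key's subtree fires and
    reads its source as it currently is: a raw dict until the source key has
    itself been instantiated. Returns the type names that the link's
    compute_fn would have received.
    """
    state = {"nested_param": {"something_else": "Something Else"}}
    # Only nested_param is actually built; building root.strategy is omitted
    # because the model only tracks what compute_fn sees.
    builders = {"nested_param": lambda raw: NestedParameter(**raw)}
    received = []
    for key in order:
        for source, target in LINKS.items():
            if target.startswith(key + "."):
                received.append(type(state[source]).__name__)
        if key in builders:
            state[key] = builders[key](state[key])
    return received


# An order the graph permits: the nested link fires while building
# root.strategy, before nested_param exists as an object.
print(simulate(["root.strategy", "nested_param"]))  # ['dict']
# The order the report expects: nested_param is built first.
print(simulate(["nested_param", "root.strategy"]))  # ['NestedParameter']
```

In this model the outcome depends only on the relative order of `nested_param` and `root.strategy`, which is exactly the pair the dependency graph leaves unconstrained.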

I am not sure whether the graph should contain this dependency, or whether the instantiation logic should not apply the links that lead into the nested class. In ActionLink.apply_instantiation_links there is a predicate, is_nested_instantiation_link (which returns False here). Does that hint that the second option is the intended one?

This is what the dependency graph looks like, as constructed in ActionLink.instantiation_order:

```mermaid
stateDiagram-v2
    nested_param --> root.strategy.init_args.nested
    param --> root.strategy
```
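The gap in the graph can be made concrete with Python's standard `graphlib` (3.9+). This sketch mirrors the two edges in the diagram and tries one hypothetical fix corresponding to the first option raised above: attributing a dependency whose target lies inside a nested class to the top-level key that actually instantiates it. `UNITS`, `owning_unit` and `build_graph` are illustrative names, not jsonargparse internals, and treating only top-level keys as instantiated units is an assumption.

```python
from graphlib import TopologicalSorter

# Link edges from the diagram: source -> target.
EDGES = [
    ("nested_param", "root.strategy.init_args.nested"),
    ("param", "root.strategy"),
]
# Keys instantiated as a whole (assumption: the nested class is built
# inside root.strategy, so it is not a unit of its own).
UNITS = ("nested_param", "param", "root.strategy")


def owning_unit(key):
    """Longest instantiated unit that is `key` itself or a dotted prefix of it."""
    matches = [u for u in UNITS if key == u or key.startswith(u + ".")]
    return max(matches, key=len) if matches else key


def build_graph(lift_nested):
    """Map each node to the set of nodes that must be instantiated before it."""
    graph = {unit: set() for unit in UNITS}
    for source, target in EDGES:
        # Hypothetical fix: charge the dependency to the owning unit.
        node = owning_unit(target) if lift_nested else target
        graph.setdefault(node, set()).add(source)
    return graph


# As drawn, root.strategy only waits for param.
print(build_graph(lift_nested=False)["root.strategy"])  # {'param'}
# With the dependency lifted, it also waits for nested_param.
print(sorted(build_graph(lift_nested=True)["root.strategy"]))  # ['nested_param', 'param']

order = list(TopologicalSorter(build_graph(lift_nested=True)).static_order())
assert order.index("nested_param") < order.index("root.strategy")
```

Without the lifted edge, any topological order may place `root.strategy` before `nested_param`; with it, every valid order instantiates `nested_param` first.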

To reproduce

I tried to keep the example as small as possible. There are two dependency-injected classes, Strategy and NestedStrategy, and two classes used only as link sources, StrategyParameter and NestedParameter.

```python
from jsonargparse import ArgumentParser


class NestedStrategy:
    def __init__(self, nested_param: str):
        self.nested_param = nested_param


class Strategy:
    def __init__(self, param: str, nested: NestedStrategy):
        self.param = param
        self.nested = nested


class Root:
    def __init__(self, strategy: Strategy):
        self.strategy = strategy


class NestedParameter:
    def __init__(self, something_else: str):
        self._something_else = something_else


class StrategyParameter:
    def __init__(self, something: str):
        self.something = something


def compute_fn(nested_param: NestedParameter):
    # Instantiation links are expected to receive instantiated objects.
    assert isinstance(nested_param, NestedParameter), (
        f"Got wrong type. Expected NestedParameter but got {type(nested_param)}"
    )
    return nested_param


if __name__ == "__main__":
    parser = ArgumentParser()

    parser.add_class_arguments(Root, "root")
    parser.add_class_arguments(NestedParameter, "nested_param")
    parser.add_class_arguments(StrategyParameter, "param")
    # Link the instantiated NestedParameter into the nested strategy.
    parser.link_arguments(
        "nested_param",
        "root.strategy.init_args.nested.init_args.nested_param",
        apply_on="instantiate",
        compute_fn=compute_fn,
    )
    # Link the `something` attribute of the instantiated StrategyParameter.
    parser.link_arguments(
        "param.something", "root.strategy.init_args.param", apply_on="instantiate"
    )
    parser.add_argument("--config", action="config")
    # Expects config.yaml (below) in the working directory.
    cfg = parser.parse_args(["--config", "config.yaml"])
    init_cfg = parser.instantiate_classes(cfg)
```
config.yaml:

```yaml
root:
  strategy:
    class_path: Strategy
    init_args:
      nested:
        class_path: NestedStrategy
nested_param:
  something_else: "Something Else"

param:
  something: "something"
```

Output

```
ValueError: Call to compute_fn of link 'compute_fn(nested_param) --> root.strategy.init_args.nested.init_args.nested_param' with args (Namespace(something_else='Something Else')) failed: Got wrong type. Expected NestedParameter but got <class 'jsonargparse._namespace.Namespace'>
```

Expected behavior

NestedParameter should be instantiated before it is passed as an argument to compute_fn. Note that compute_fn only serves to showcase the problem; it is not necessary to reproduce the bug.

Environment

mauvilsa commented 3 hours ago

Thank you for reporting! I will look at it.