lebrice / SimpleParsing

Simple, Elegant, Typed Argument Parsing with argparse
MIT License

Allow argument dependencies #143

Open janvainer opened 2 years ago

janvainer commented 2 years ago

Is your feature request related to a problem? Please describe.

I am wondering if it's possible to tie some argument values together. For example, when I train an Encoder-Decoder architecture, encoder.output_dim should be the same as decoder.input_dim.

Describe the solution you'd like

I would like to specify only one of them in the CLI and have the other be set automatically. And if I accidentally specify both, I would like to be yelled at for doing something wrong.

How it could work:

dim: int = TiedParam(default=25)

@dataclass
class Encoder:
    output_dim: int = dim

@dataclass
class Decoder:
    input_dim: int = dim

Let me know what you think (or forgive my ignorance if it's already possible!! :)) Thanks!

lebrice commented 2 years ago

Hey there @janvainer, thanks for posting this!

Hmm, that's interesting. From what is currently implemented, you have two options to take a look at, neither of which does exactly what you want... :disappointed:

This isn't directly related to simple-parsing, it's just a little feature for dataclasses that I've made into a shareable code snippet using GitHub Gist. Here's the idea:

@dataclass
class Bob(HasConditionalFields):
    name: str = "Bob Jones"
    age: int = 26

    gamer_name: str = conditional_field(
        lambda name, age: f"xXx_{name}_{2022-age}_xXx",
    )

    email: str = conditional_field(
        lambda gamer_name: f"{gamer_name}@gmail.com"
    )

bob = Bob()
print(bob)
# Bob(name='Bob Jones', age=26, gamer_name='xXx_Bob Jones_1996_xXx',
#     email='xXx_Bob Jones_1996_xXx@gmail.com')

The problem with this, though, is that I don't think it works if the fields are on different dataclasses. You might be able to tweak the conditional_field function to achieve it though. If you do, please do let me know!
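For fields that live on different dataclasses, one plain-dataclass workaround is to resolve the dependency in the parent's `__post_init__`. The following is only a minimal sketch using the standard library; it does not use the gist's `conditional_field`, and the `None` sentinel on `Decoder.input_dim` is an assumption made for illustration, not an existing simple-parsing feature.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Encoder:
    output_dim: int = 128


@dataclass
class Decoder:
    # None means "not set explicitly; inherit from the encoder" (assumed convention).
    input_dim: Optional[int] = None


@dataclass
class ModelConfig:
    encoder: Encoder = field(default_factory=Encoder)
    decoder: Decoder = field(default_factory=Decoder)

    def __post_init__(self) -> None:
        if self.decoder.input_dim is None:
            # Only the encoder side was specified: copy the value over.
            self.decoder.input_dim = self.encoder.output_dim
        elif self.decoder.input_dim != self.encoder.output_dim:
            # Both sides were specified and they disagree.
            raise ValueError(
                f"decoder.input_dim ({self.decoder.input_dim}) must equal "
                f"encoder.output_dim ({self.encoder.output_dim})"
            )


config = ModelConfig(encoder=Encoder(output_dim=256))
print(config.decoder.input_dim)  # 256
```

The tie is enforced wherever `ModelConfig` is constructed, not only on the command line, but it has to be written by hand for every pair of tied fields.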

from argparse import Namespace
from dataclasses import dataclass
from simple_parsing import ArgumentParser, field
from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode

@dataclass
class Encoder:
    output_dim: int = field(default=128, dest="model.decoder.input_dim")

@dataclass
class Decoder:
    input_dim: int = field(default=128, cmd=False)

@dataclass
class ModelConfig:
    encoder: Encoder = field(default_factory=Encoder)
    decoder: Decoder = field(default_factory=Decoder)

def test_repro_issue_143():
    parser = ArgumentParser(argument_generation_mode=ArgumentGenerationMode.NESTED)
    parser.add_arguments(ModelConfig, dest="model")
    args = parser.parse_args("--model.encoder.output_dim 256".split())

    # Doesn't exactly work at the moment:
    assert args == Namespace(
        model=ModelConfig(encoder=Encoder(output_dim=128), decoder=Decoder(input_dim=128)),
        **{"model.decoder.input_dim": 256}
    )

This doesn't really do what you want, and doesn't really work as intended in its current state. (I wrote this waaaaay back, and haven't used or tested it since.)

Hope this is somewhat helpful. I'd be really curious to hear some more ideas on how we could design this! The TiedParam idea seems interesting. Perhaps we could also make use of type annotations for this somehow?

Let me know what you think! :)

janvainer commented 2 years ago

Hi @lebrice, thank you for your ideas and the code snippets! :) After considering it a bit, defining the tied parameters in the dataclasses directly is probably not the best way to go. But how about tying the parameters together in the parser? Something like this might work:

parser = ArgumentParser(argument_generation_mode=ArgumentGenerationMode.NESTED)
parser.add_arguments(Encoder, dest="enc")
parser.add_arguments(Decoder, dest="dec")
parser.tie_arguments("enc.output_dim", "dec.input_dim")

tie_arguments would check that the defaults match (if provided), and the parser would fail if the user specifies both arguments with different values. If the user specifies only one of them, its value would be copied over to the other. The user could also tie several groups of arguments:

# tie one group of parameters - ensures that connecting the modules does not raise shape errors
parser.tie_arguments("enc.output_dim", "dec.input_dim", "speaker_embedding.channels")

# tie another group of parameters - all model parts use the same dropout value
parser.tie_arguments("enc.dropout", "dec.dropout")
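The semantics described for the proposed tie_arguments can be sketched independently of any parser. The helper below is hypothetical (`resolve_tied` is not part of simple-parsing or argparse). It assumes the parser can supply two flat mappings keyed by dotted name: the defaults, and the values the user actually passed on the command line.

```python
from typing import Any, Dict, Iterable, Mapping, Sequence


def resolve_tied(
    defaults: Mapping[str, Any],
    given: Mapping[str, Any],
    groups: Iterable[Sequence[str]],
) -> Dict[str, Any]:
    """Merge defaults with user-given values, enforcing tied groups."""
    resolved = {**defaults, **given}
    for group in groups:
        # Defaults inside a group must agree with each other.
        default_values = {defaults[name] for name in group if name in defaults}
        if len(default_values) > 1:
            raise ValueError(f"Tied defaults differ for {list(group)}: {default_values}")
        # Values given explicitly inside a group must agree too.
        given_values = {given[name] for name in group if name in given}
        if len(given_values) > 1:
            raise ValueError(f"Tied arguments differ for {list(group)}: {given_values}")
        # A single explicit value is copied to every member of the group.
        if given_values:
            (value,) = given_values
            for name in group:
                resolved[name] = value
    return resolved


defaults = {"enc.output_dim": 128, "dec.input_dim": 128,
            "enc.dropout": 0.1, "dec.dropout": 0.1}
groups = [("enc.output_dim", "dec.input_dim"), ("enc.dropout", "dec.dropout")]

# Only one member of the first group is given; the other follows it.
config = resolve_tied(defaults, {"enc.output_dim": 256}, groups)
print(config["dec.input_dim"])  # 256
```

Passing the same value for both members of a group is accepted here, which matches the "different values" wording above. A stricter variant could reject any group with more than one explicitly given member.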

WDYT?

mauvilsa commented 2 years ago

@janvainer what you propose already exists in jsonargparse, see argument-linking. You might want to try that out.

lebrice commented 2 years ago

@janvainer That looks pretty good to me, and doesn't seem too hard to implement either. I'll keep you posted, not sure when exactly I'll get to work on this. I'd also be grateful if others would like to take a crack at it! :smile:

Ohai there @mauvilsa, thanks for paying us a visit! I wasn't aware of jsonargparse, looks very neat! Let me know if you'd like to chat, there could be some things we could potentially contribute to jsonargparse and vice-versa :slightly_smiling_face:!

janvainer commented 2 years ago

@mauvilsa looks cool! That's basically what I had in mind 😄 @lebrice thanks a lot. I unfortunately do not have the capacity to contribute this RN, but I will be happy to test it out and report bugs/issues etc.

mauvilsa commented 2 years ago

@lebrice I think some time ago I messaged you mentioning the overlap of SimpleParsing and jsonargparse. Probably in a thread related to pytorch-lightning. Or maybe I am confusing it with some other message I sent. Anyway, I am in the pytorch-lightning slack if you want to message there.

jsonargparse is the library that is used under the hood for LightningCLI. In that documentation you can also see the explanation for link_arguments. After all, this was implemented because it was needed there. Though, the concept goes much further than what was proposed here, since it also considers class instantiation and nested classes used via composition.