Closed. Tomiinek closed this issue 1 year ago.
Hey @Tomiinek, I've checked the master code, and subgroups was refactored in #185. It seems this issue has been addressed: your example passes on the current master.
Hello, thank you for your prompt reply.
I am confused :confused: Does it really work for you?
It actually does not work for me on the current master; the example above gives me:
File ".../simple_parsing/helpers/fields.py", line 416, in subgroups
if default is not MISSING and default not in subgroups:
TypeError: unhashable type: 'ModelAConfig'
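For readers wondering where this TypeError comes from, here is a minimal sketch using only the standard library. The class below is a stand-in for the one in the example (its ModelConfig base is omitted), and the dict mirrors the shape of the subgroups dict; none of this is simple_parsing code. A dataclass with the default eq=True and no frozen=True gets `__hash__` set to None, and a dict membership test has to hash its operand.

```python
from dataclasses import dataclass


@dataclass
class ModelAConfig:  # stand-in for the class in the example above
    lr: float = 3e-4


subgroups_dict = {"model_a": ModelAConfig}

# eq=True (the default) without frozen=True disables hashing of instances.
hash_disabled = ModelAConfig.__hash__ is None

try:
    ModelAConfig() in subgroups_dict  # membership test hashes the instance
    message = ""
except TypeError as err:
    message = str(err)


@dataclass(unsafe_hash=True)
class HashableConfig:  # hypothetical hashable variant
    lr: float = 3e-4


# Hashable now, but an instance is still not one of the string keys.
instance_is_key = HashableConfig() in {"model_a": HashableConfig}
```

Making the class hashable only removes the TypeError; an instance is still not a key of the dict, so a membership check of this kind keeps failing.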
If I make the dataclasses hashable:
@dataclass(unsafe_hash=True)
class ModelAConfig(ModelConfig):
    lr: float = 3e-4
    optimizer: str = "Adam"
    betas: tuple[float, float] = 0.9, 0.999

@dataclass(unsafe_hash=True)
class ModelBConfig(ModelConfig):
    lr: float = 1e-3
    optimizer: str = "SGD"
    momentum: float = 1.234
I get
File ".../simple_parsing/helpers/fields.py", line 417, in subgroups
raise ValueError("default must be a key in the subgroups dict!")
ValueError: default must be a key in the subgroups dict!
so I have to change the default to a string key:
@dataclass
class Config:
    # Which model to use
    model: ModelConfig = subgroups(
        {"model_a": ModelAConfig, "model_b": ModelBConfig},
        default="model_a",
    )
But I lost the ability to set some parameters in the default, for example ModelAConfig(lr=0.1) (which was previously possible; I do not want to set it directly in ModelAConfig because it is used in multiple places). Is this your intention? If so, what is the recommended way to set the defaults? Some juggling with parser.set_defaults?
Please refer to #186 and field [Python Docs] for how to initialize a field. The subgroups function returns a customized Field for you, with some additional checks and specialized metadata.
You can use default_factory as indicated by field [Python Docs]: "If provided, it must be a zero-argument callable that will be called when a default value is needed for this field."
Here are some examples to achieve your goal:
@dataclass
class Config:
    # Which model to use
    model: ModelConfig = subgroups(
        {"model_a": ModelAConfig, "model_b": ModelBConfig},
        default_factory=ModelAConfig,
    )

@dataclass
class Config:
    # Which model to use
    model: ModelConfig = subgroups(
        {"model_a": ModelAConfig, "model_b": ModelBConfig},
        default_factory=lambda: ModelAConfig(lr=0.1),
    )
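The default_factory behaviour quoted from the Python docs can be seen with plain `dataclasses.field`, independent of simple_parsing. This is only a sketch of the standard-library semantics; the class is a simplified stand-in for the one in the thread, and it says nothing about what extra checks subgroups applies on top.

```python
from dataclasses import dataclass, field


@dataclass
class ModelAConfig:  # simplified stand-in for the thread's class
    lr: float = 3e-4
    optimizer: str = "Adam"


@dataclass
class Config:
    # The zero-argument callable runs each time a default is needed,
    # so the override lives here and not in ModelAConfig itself.
    model: ModelAConfig = field(default_factory=lambda: ModelAConfig(lr=0.1))


first = Config()
second = Config()
```

Each Config gets its own ModelAConfig instance with lr=0.1, while ModelAConfig() on its own still defaults to lr=3e-4.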
Very nice :slightly_smiling_face:
Is it supposed to work out of the box? Asking because there seems to be this check (so the factory itself has to be a value in the dict so I do not see the way to set different arguments here): https://github.com/lebrice/SimpleParsing/blob/5042cb4f863fbf5e1364871de0ecb8f766213e78/simple_parsing/helpers/fields.py#L418-L421
So
@dataclass
class Config:
    # Which model to use
    model: ModelConfig = subgroups(
        {"model_a": ModelAConfig, "model_b": ModelBConfig},
        default_factory=lambda: ModelAConfig(lr=0.1),
    )
actually results in a ValueError.
Maybe something like this could work?
if default_factory is not MISSING and type(default_factory()) not in subgroups.values():
    ...
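As a rough illustration of the check suggested above, here is a standalone sketch. The helper name `factory_is_allowed` is hypothetical and is not part of simple_parsing; it only shows how comparing the type of the factory's result against the dict values would accept a lambda with overridden arguments. One trade-off to keep in mind: the factory has to be called eagerly just to validate it.

```python
from dataclasses import MISSING, dataclass


@dataclass
class ModelAConfig:
    lr: float = 3e-4


@dataclass
class ModelBConfig:
    lr: float = 1e-3


def factory_is_allowed(subgroups_dict, default_factory=MISSING):
    """Hypothetical helper: accept any factory whose result is an instance
    of one of the registered subgroup classes."""
    if default_factory is MISSING:
        return True
    # Calls the factory once, purely to inspect the type of its result.
    return type(default_factory()) in subgroups_dict.values()


choices = {"model_a": ModelAConfig}
accepted = factory_is_allowed(choices, lambda: ModelAConfig(lr=0.1))
rejected = factory_is_allowed(choices, lambda: ModelBConfig())
```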
What version of simple-parsing are you working on? default_factory is not implemented in 0.0.21.post1 as per #187, but it is now implemented in the master branch by #185.
I hope it is the current master (5042cb4f863fbf5e1364871de0ecb8f766213e78) aka simple-parsing===0.0.21.post1.post4-g5042cb4 with python 3.10.8
Thanks for pointing this out. That is very interesting; I overlooked this part. Please ignore my reply above. The question is: why do we need to set default values that differ from the corresponding dataclass defaults?
Okay, my situation is the following :slightly_smiling_face:
My dataclasses are auto-generated from "serializable" classes, or better said from arguments of their constructors. For example:
class AdamW(TorchAdamW, SerializableObject, ignore_args=["self", "params"]):
    """Wrapper around AdamW to name its parameters for serialization."""

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        eps: float = 1e-6,
        weight_decay: float = 0.01,
        amsgrad: bool = False,
    ):
        super().__init__(
            params,
            lr=lr,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
produces a dataclass like this one:
@dataclass
class AdamWConfig:
    lr: float = 0.001
    eps: float = 1e-06
    weight_decay: float = 0.01
    amsgrad: bool = False
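The thread does not show how SerializableObject builds this dataclass. Purely as an assumption, one way to derive such a config class from a constructor signature with the standard library could look like the sketch below. The helper `config_class_from_init` and the torch-free AdamW stand-in are hypothetical, and the sketch assumes simple immutable defaults and ordinary positional-or-keyword parameters.

```python
import inspect
from dataclasses import field, fields, make_dataclass
from typing import Any


def config_class_from_init(cls, ignore_args=("self",)):
    """Hypothetical helper: turn the arguments of cls.__init__ into fields."""
    specs = []
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name in ignore_args:
            continue
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue  # *args / **kwargs have no field equivalent
        annotation = Any if param.annotation is param.empty else param.annotation
        if param.default is param.empty:
            specs.append((name, annotation))  # becomes a required field
        else:
            specs.append((name, annotation, field(default=param.default)))
    return make_dataclass(f"{cls.__name__}Config", specs)


class AdamW:  # stand-in without torch, same signature as in the thread
    def __init__(self, params, lr: float = 1e-3, eps: float = 1e-6,
                 weight_decay: float = 0.01, amsgrad: bool = False):
        self.params = params


AdamWConfig = config_class_from_init(AdamW, ignore_args=("self", "params"))
field_names = [f.name for f in fields(AdamWConfig)]
```

The generated class carries the constructor's defaults, so a project-specific override such as AdamWConfig(lr=0.005) leaves the shared defaults untouched.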
This config dataclass is used by several different projects, so I do not want to change the defaults here. At the same time, each project has its own config related to optimizers, something like:
class OptimizerConfig(Config):
    opt: Union[AdamW.get_config_class(), FusedLAMB.get_config_class()] = subgroups(
        {
            "adamw": AdamW.get_config_class(),
            "fusedlamb": FusedLAMB.get_config_class(),
        },
        # default_factory=lambda: AdamW.get_config_class()(lr=0.005),
    )
and I believe that here, in this optimizer config, is a suitable place for overriding the project-related default setting.
Does it make sense or is this setup weird?
I don't mind using parser.set_defaults or something like that for setting the defaults, but I do not see how to do that for subgroups.
Hi @Tomiinek, thanks for posting this.
Yes, you're right, I've temporarily restricted what can be passed to the subgroups function to only dataclass types. I'm working on this today.
By the way, your dynamically generated config classes look a lot like what I've been doing here: https://github.com/lebrice/SimpleParsing/pull/156
Just hang tight for now :sweat_smile: I'll figure out a solution to this soon, and get back to you.
In the case of my Partial feature, it would look like this:
# Dynamically create a dataclass that will be used for the above type:
# NOTE: We could use Partial[Adam] or Partial[Optimizer], however this would treat `params` as a
# required argument.
# AdamConfig = Partial[Adam] # would treat 'params' as a required argument.
# SGDConfig = Partial[SGD] # same here
AdamConfig: type[Partial[Adam]] = config_dataclass_for(Adam, ignore_args="params")
SGDConfig: type[Partial[SGD]] = config_dataclass_for(SGD, ignore_args="params")
@dataclass
class Config:
    # Which optimizer to use.
    optimizer: Partial[Optimizer] = subgroups(
        {
            "sgd": SGDConfig,
            "adam": AdamConfig,
        },
        default_factory=lambda: AdamConfig(lr=3e-4),
    )
I'm going to close this issue, since I believe it was fixed by #185, but I'll create a new one for re-allowing subgroups to have more flexible options. Thanks again for posting, @Tomiinek!
Well I'll actually first add a test to make sure that this issue is fixed. THEN, I'll close this issue :)
Great! Thank you both guys.
Hey @Tomiinek, I've now opened issue #195 and PR #196, which adds the ability to use functools.partial objects.
This isn't exactly what you're looking for; however, the PR should also re-allow dynamically created dataclasses to be used in the subgroups dict.
Let me know what you think! :)
Hey, that's a really nice solution! :slightly_smiling_face: Pls merge ASAP :sweat_smile:
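For context on why functools.partial objects are convenient here: unlike a lambda, a partial exposes the wrapped callable and the bound keyword arguments as attributes, so they can be inspected without calling it. The sketch below uses only the standard library and a stand-in dataclass; how PR #196 actually uses these attributes is not shown in this thread.

```python
import functools
from dataclasses import dataclass


@dataclass
class AdamWConfig:  # stand-in for a generated config class
    lr: float = 1e-3
    weight_decay: float = 0.01


choices = {"adamw": AdamWConfig}

# A project-specific default that leaves AdamWConfig's own defaults alone.
project_default = functools.partial(AdamWConfig, lr=0.005)

wrapped_class = project_default.func         # the class being wrapped
bound_overrides = project_default.keywords   # the overridden arguments
wrapped_is_registered = wrapped_class in choices.values()
instance = project_default()
```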
Describe the bug
I would like to have two subgroups with different destinations, let's say config and config2. However, if one of them is a substring of the other, an exception is raised saying:
To Reproduce
Expected behavior
Not to raise the exception :slightly_smiling_face:
I am not sure, but maybe this condition https://github.com/lebrice/SimpleParsing/blob/545aa469defecba6d30da02199a1ea206740ad23/simple_parsing/parsing.py#L768-L771 should be changed to: