lebrice / SimpleParsing

Simple, Elegant, Typed Argument Parsing with argparse
MIT License

Flatten nested configs into single-depth dictionary with absolute path as key #237

Closed: ludwigwinkler closed this issue 1 year ago

ludwigwinkler commented 1 year ago

Thank you for the amazing package.

For my use case I'm missing something akin to a flatten function for nested configs, one that concatenates the layers of the nested config into an absolute path. I have the following nested config structure, which I can readily access in my PyTorch Lightning model: [image]

Describe the solution you'd like

Is there a flatten function that returns a single dictionary, with the absolute path to each argument (e.g. 'model.num_layers') as the key and the corresponding value from the nested config as the value? For example, something like this (a subset of the nested config above):

{'model.num_layers': 3, 'data.batch_size': 64, 'optimization.lr': 0.001, ... }
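
A minimal sketch of the requested behaviour, operating on a plain nested dictionary rather than on SimpleParsing objects. The names `flatten_dict` and `_walk` and the `sep` parameter are hypothetical and not part of the library; the values are the ones from the example above.

```python
from typing import Any, Dict, Iterator, Tuple


def _walk(nested: Dict[str, Any], prefix: str, sep: str) -> Iterator[Tuple[str, Any]]:
    """Yield (absolute_path, value) pairs for every leaf of a nested dict."""
    for key, value in nested.items():
        path = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            # Descend one level, carrying the path built so far.
            yield from _walk(value, path, sep)
        else:
            yield path, value


def flatten_dict(nested: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Hypothetical helper: collapse a nested dict into a single level."""
    return dict(_walk(nested, "", sep))


nested_config = {
    "model": {"num_layers": 3},
    "data": {"batch_size": 64},
    "optimization": {"lr": 0.001},
}
flat_config = flatten_dict(nested_config)
print(flat_config)
```

Note that in this sketch an empty nested dictionary contributes no keys at all, which may or may not be what a logging backend expects.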

Describe alternatives you've considered

When I run 'python my_script_abc.py --help', a nested structure of sorts appears, in which a prefix is added whenever argument names conflict. That makes me hope that some solution already exists that can easily be retooled for this feature.

Additional context

I'm using this for WandB, where the hyperparameter configuration requires plain strings. So I need strings like 'optimization.lr' and 'model.num_layers' that I can pair with the arguments.

Currently, WandB reduces the nested configs to a single string without revealing the nested arguments: [image]

ludwigwinkler commented 1 year ago

I hacked together a solution for my needs:

# Standard library
import copy
import itertools
from dataclasses import dataclass, field

# Bookkeeping
import simple_parsing
from simple_parsing import ArgumentParser

# a bunch of data classes

@dataclass
class Config:
    '''General Config for General Stuff'''
    logging: str = ["disabled", "online"][1]
    seed: int = 12345
    fast_dev_run: int = 0

@dataclass
class OptimizerConfig:
    '''Optimizer Config'''
    constructor: str = 'Adam'  # annotated so that it is registered as a dataclass field
    lr: float = 1e-3

@dataclass
class TrainerConfig:
    '''Lightning Trainer Config'''
    max_epochs: int = 5

@dataclass
class MNISTDataSetConfig:
    '''MNIST DataSet Config'''
    datamodule: str = "mnist"
    data_dir: str = "data/mnist"
    download: bool = True
    batch_size: int = 64
    num_classes: int = 10
    resize: tuple = (28, 28)

@dataclass
class TinyDebugNetConfig:
    num_layers: int = 3
    num_hidden: int = 100
    # default_factory: a dataclass instance is rejected as a mutable default on Python 3.11+
    another_config: Config = field(default_factory=Config)

@dataclass
class MNISTTinyDebugNetConfig:
    '''Parent Config for Tiny Network with MNIST'''
    stuff: str = "abd"
    config: Config = field(default_factory=Config)
    model: TinyDebugNetConfig = field(default_factory=TinyDebugNetConfig)
    data: MNISTDataSetConfig = field(default_factory=MNISTDataSetConfig)
    optimization: OptimizerConfig = field(default_factory=OptimizerConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)

# recursively going through nested dictionaries
def flatten_config(config):
    def abs_path_nested_dict(dictionary, hparams_dict: dict, prefix=""):
        for key, value in dictionary.items():
            # build the absolute path, e.g. "model" + "num_layers" -> "model.num_layers"
            current_key = f"{prefix}.{key}" if prefix else key
            if isinstance(value, dict):
                abs_path_nested_dict(value, hparams_dict, current_key)
            else:
                hparams_dict[current_key] = value

    nested_config_dict = simple_parsing.helpers.serialization.to_dict(config)
    flattened_config_dict = {}
    abs_path_nested_dict(nested_config_dict, flattened_config_dict)
    return flattened_config_dict

# pretty print nested config using the first layer as an organizing layer and printing full path
def print_config(config):
    flattened_config = flatten_config(config)
    # groupby only merges consecutive items; the flattened dict keeps each top-level section contiguous
    grouped = itertools.groupby(
        flattened_config.items(), key=lambda keyvalue: keyvalue[0].split(".")[0]
    )

    print()
    print("Config")
    for key, group in grouped:
        # materialize the group: it is a one-shot iterator and is needed twice below
        group = list(group)
        print(key)
        right_tabs = max(len(key_) for key_, _ in group) + 1
        for key_, val in group:
            print(f"\t {key_:{right_tabs}}: {type(val).__name__} = {val}")
    print()

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--config", default="mnisttinydebugnet"
    )
    parser.add_arguments(MNISTTinyDebugNetConfig, dest="hparams")
    hparams = parser.parse_args().hparams

    print_config(hparams)
    for key, value in flatten_config(hparams).items():
        print(key, ' : ', value)

zhiruiluo commented 1 year ago

Check this out: https://github.com/lebrice/SimpleParsing/blob/3576a12d9f14036bc30850e6e46eee309bfaeb90/simple_parsing/helpers/flatten.py#L9-L19

https://github.com/lebrice/SimpleParsing/blob/3576a12d9f14036bc30850e6e46eee309bfaeb90/test/utils/test_flattened.py#L126-L136

To iterate through all flattened pairs:

c = Config()
for i in c.attributes():
    print(i)
lebrice commented 1 year ago

Hello there @ludwigwinkler , thanks for posting this!

Sorry @zhiruiluo, but I don't think this is what @ludwigwinkler is looking for.

@ludwigwinkler your solution is correct: logging dictionaries to wandb, instead of dataclasses, is the way to go. But I might have a simpler suggestion: What about using dataclasses.asdict? (This is what I do in my own experiments).

wandb.config["hparams"] = dataclasses.asdict(hparams)
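
A hedged sketch of what `dataclasses.asdict` returns for nested dataclasses. `ModelConfig` and `HParams` are hypothetical stand-ins for the configs in this thread, and the `wandb` line is left as a comment so that the snippet runs without WandB installed. The result is a nested dictionary; dotted keys such as 'model.num_layers' would still need a separate flattening step.

```python
import dataclasses
from dataclasses import dataclass, field


@dataclass
class ModelConfig:  # hypothetical stand-in for a nested config
    num_layers: int = 3
    num_hidden: int = 100


@dataclass
class HParams:  # hypothetical stand-in for the parsed top-level config
    seed: int = 12345
    model: ModelConfig = field(default_factory=ModelConfig)


hparams = HParams()
# asdict recurses into nested dataclasses and returns plain dicts.
hparams_dict = dataclasses.asdict(hparams)
print(hparams_dict)

# With wandb installed and a run initialised, the assignment suggested above would be:
# wandb.config["hparams"] = hparams_dict
```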
lebrice commented 1 year ago

LMK if you have any questions. I'll close this issue for now if that's alright with you.

Thanks for posting!