MedicineToken / MedSegDiff

Medical Image Segmentation with Diffusion Model
MIT License

error in create_argparser #42

Status: Open. TMullerSG opened this issue 1 year ago.

TMullerSG commented 1 year ago

Hi, I believe this is what you want:

defaults.update({k: v for k, v in model_and_diffusion_defaults().items() if k not in defaults})

Otherwise, the values set in defaults will be overwritten by the predefined ones returned by model_and_diffusion_defaults() whenever the two share a key.
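To make the difference concrete, here is a small self-contained sketch. The model_and_diffusion_defaults() below is a hypothetical stand-in (the real one in the repo returns many more keys), and it deliberately shares the key batch_size with the script defaults so the overwrite is visible.

```python
# Hypothetical stand-in for model_and_diffusion_defaults(); it shares 'batch_size'
# with the script defaults on purpose to show the overwrite.
def model_and_diffusion_defaults():
    return {"image_size": 256, "num_channels": 128, "batch_size": 8}


def merge_overwriting(script_defaults):
    # Original pattern: dict.update lets the model defaults win on shared keys
    merged = dict(script_defaults)
    merged.update(model_and_diffusion_defaults())
    return merged


def merge_keeping_script_values(script_defaults):
    # Suggested pattern: only add model defaults for keys the script did not set
    merged = dict(script_defaults)
    merged.update({k: v for k, v in model_and_diffusion_defaults().items() if k not in merged})
    return merged


def merge_with_unpacking(script_defaults):
    # Equivalent one-liner: later mappings win, so the script values take precedence
    return {**model_and_diffusion_defaults(), **script_defaults}


script_defaults = {"batch_size": 1, "num_ensemble": 5}
print(merge_overwriting(script_defaults)["batch_size"])            # 8
print(merge_keeping_script_values(script_defaults)["batch_size"])  # 1
```

The dict-unpacking form produces the same key/value pairs as the conditional update (only the key order differs), so either works as a fix.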

Saharsh1005 commented 1 year ago

Instead of passing arguments via the CLI, you can also do something like this:

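from guided_diffusion.script_util import model_and_diffusion_defaults


def create_argparser():  # Taken from the inference/sampling script; the same approach works for training
    # Despite its name, this returns a plain dict rather than an argparse parser
    defaults = dict(
        data_name='',          # dataset name
        data_dir="",           # path to the dataset directory
        clip_denoised=True,
        num_samples=1,
        batch_size=1,
        use_ddim=False,
        model_path="",
        num_ensemble=5,        # number of samples in the ensemble
        gpu_dev="0",
        out_dir='./results/',
        multi_gpu=None,        # e.g. "0,1,2"
        debug=True,
    )
    # Only add model/diffusion defaults for keys not set above, so they cannot overwrite them
    defaults.update({k: v for k, v in model_and_diffusion_defaults().items() if k not in defaults})

    return defaults


class CFG:
    def __init__(self, arg_dict=None):
        # Build the defaults on each call, not once when the class is defined
        if arg_dict is None:
            arg_dict = create_argparser()
        # Expose every config entry as an attribute, e.g. args.batch_size
        for key, value in arg_dict.items():
            setattr(self, key, value)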
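The CFG class above is essentially an attribute-style view of a dict, which the standard library already provides as types.SimpleNamespace. A minimal sketch of the same idea follows; build_config is a hypothetical helper, only a few of the keys from the snippet are reproduced, and the model/diffusion defaults are left out to keep it self-contained.

```python
from types import SimpleNamespace


def build_config(overrides=None):
    # Hypothetical helper: a fresh defaults dict is built on every call
    defaults = dict(
        data_name="",
        data_dir="",
        batch_size=1,
        num_ensemble=5,
        gpu_dev="0",
        out_dir="./results/",
    )
    if overrides:
        # Catch typos such as 'batchsize' instead of silently adding a new key
        unknown = set(overrides) - set(defaults)
        if unknown:
            raise KeyError(f"unknown config keys: {sorted(unknown)}")
        defaults.update(overrides)
    return SimpleNamespace(**defaults)


args = build_config({"data_name": "my_dataset", "batch_size": 4})  # hypothetical values
print(args.batch_size)  # 4
```

Rejecting unknown keys is a design choice here, not something the original snippet does; drop the check if you want to allow extra entries.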

Then, in main(), create the config with:
args = CFG()
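If you would rather keep an argparse.Namespace (so the rest of the script and optional CLI overrides keep working), you can build the parser from the same defaults dict and call parse_args with an explicit list; an empty list means the command line is ignored. The helper below is a simplified sketch modelled on the add_dict_to_argparser/str2bool pattern from guided-diffusion-style scripts, assuming that layout; it is not the repo's own code.

```python
import argparse


def str2bool(v):
    # Accept common spellings of booleans passed as strings
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")


def dict_to_parser(defaults):
    # Hypothetical helper: one --flag per key, typed after its default value
    parser = argparse.ArgumentParser()
    for key, value in defaults.items():
        if value is None:
            arg_type = str
        elif isinstance(value, bool):
            arg_type = str2bool
        else:
            arg_type = type(value)
        parser.add_argument(f"--{key}", default=value, type=arg_type)
    return parser


defaults = dict(data_dir="", batch_size=1, use_ddim=False, multi_gpu=None)
parser = dict_to_parser(defaults)
args = parser.parse_args([])  # no CLI: pure defaults
custom = parser.parse_args(["--batch_size", "4", "--use_ddim", "true"])
print(args.batch_size, custom.batch_size, custom.use_ddim)  # 1 4 True
```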