Issue (open), opened by TMullerSG 1 year ago
Instead of passing arguments on the command line (CLI), you can also define them directly in code, like this:
def create_argparser():
    """Return the default configuration dict used for inference/sampling.

    Despite its name, this does not build an ``argparse.ArgumentParser``; it
    returns a plain ``dict`` mapping option names to default values. The same
    approach can be followed for training.

    The values defined here take precedence: entries from
    ``model_and_diffusion_defaults()`` are only added for keys that are not
    already defined below, so they cannot overwrite these settings.
    """
    defaults = dict(
        data_name="",  # dataset name
        data_dir="",  # path to the dataset directory
        clip_denoised=True,
        num_samples=1,
        batch_size=1,
        use_ddim=False,
        model_path="",
        num_ensemble=5,  # number of samples in the ensemble
        gpu_dev="0",
        out_dir="./results/",
        multi_gpu=None,  # e.g. "0,1,2"
        debug=True,
    )
    # A plain dict.update() here would let the predefined model/diffusion
    # defaults overwrite the values above; only fill in the missing keys.
    for key, value in model_and_diffusion_defaults().items():
        defaults.setdefault(key, value)
    return defaults
class CFG:
    """Attribute-style configuration: each dict entry becomes an attribute.

    Args:
        arg_dict: mapping of option names to values. When omitted (or None),
            a fresh dict from ``create_argparser()`` is built at construction
            time.
    """

    def __init__(self, arg_dict=None):
        if arg_dict is None:
            # Built per instance rather than once at class-definition time,
            # which would run create_argparser() on import and share one dict
            # between every CFG instance.
            arg_dict = create_argparser()
        for key, value in arg_dict.items():
            setattr(self, key, value)
Then, inside `main()`:
# Build the config object; with no argument, CFG uses create_argparser()'s defaults.
args = CFG()
defaults.update({k: v for k, v in model_and_diffusion_defaults().items() if k not in defaults}) Hi, I believe this is what you want; otherwise your values will be overwritten by the predefined defaults.