lxa9867 / ControlVAR

This is the official implementation for ControlVAR.

Missing functions in directory “models” #10

Closed firethatpotato closed 2 weeks ago

firethatpotato commented 3 weeks ago

Thanks for your great work, but I've run into some trouble.

When I run `train_control_var_hpu.py` for validation with the command from the "Inference" section of the README, I get errors indicating that some functions or classes are missing from `models`:

```
Traceback (most recent call last):
  File "/home/N3_3090U2/anaconda3/envs/controlvar/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 76, in _wrap
    fn(i, *args)
  File "/home/N3_3090U2/Desktop/controlVAR/ControlVAR-main/train_control_var_gpu.py", line 595, in process
    var = build_control_var(vae=vqvae, depth=args.depth, patch_nums=args.v_patch_nums, mask_type=args.mask_type,
  File "/home/N3_3090U2/Desktop/controlVAR/ControlVAR-main/models/__init__.py", line 37, in build_control_var
    return MaskVAR(
NameError: name 'MaskVAR' is not defined
  File "/home/N3_3090U2/Desktop/controlVAR/ControlVAR-main/train_control_var_gpu.py", line 697, in run
    mp.spawn(process,
  File "/home/N3_3090U2/Desktop/controlVAR/ControlVAR-main/train_control_var_gpu.py", line 704, in <module>
    run(process, args.gpus, args)
torch.multiprocessing.spawn.ProcessRaisedException:
```

'VisualProgressAutoreg' cannot be found, either.
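Because the `NameError` only surfaces inside a process started by `mp.spawn`, it can help to scan a module's source for names that are read but never bound before launching a run. The following is a minimal sketch using only the standard-library `ast` module; `find_unbound_names` is a hypothetical helper, not part of ControlVAR, and its scope-insensitive heuristic will miss names that arrive via `from x import *`.

```python
import ast
import builtins


def find_unbound_names(source: str) -> set[str]:
    """Return names that are read somewhere in `source` but never bound anywhere in it.

    Crude, scope-insensitive heuristic (hypothetical helper, not from ControlVAR):
    any import, def, class, assignment target, function argument, or exception
    alias anywhere in the file counts as a binding. Builtins are always bound.
    """
    tree = ast.parse(source)
    bound = set(dir(builtins))
    loaded = set()
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                # `import a.b` binds `a`; `from m import x as y` binds `y`
                bound.add((alias.asname or alias.name).split(".")[0])
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            bound.add(node.name)
        elif isinstance(node, ast.arg):
            # positional, keyword-only, *args and **kwargs parameters
            bound.add(node.arg)
        elif isinstance(node, ast.ExceptHandler) and node.name:
            bound.add(node.name)
        elif isinstance(node, ast.Name):
            if isinstance(node.ctx, ast.Load):
                loaded.add(node.id)
            else:
                # Store/Del targets, including loop and comprehension variables
                bound.add(node.id)
    return loaded - bound
```

Run from the repository root, `find_unbound_names(Path("models/__init__.py").read_text())` (with `from pathlib import Path`) would be expected to report `MaskVAR` for the file quoted in this issue; treat any hit as a lead to check by hand, not a verdict.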

lxa9867 commented 3 weeks ago

Hi, please remove the `VisualProgressAutoreg` class; it was not released and is not used in the pipeline.

firethatpotato commented 3 weeks ago


Thank you! What about `MaskVAR`? It is what `build_control_var` returns, and `build_control_var` is used to instantiate `var`:

In `train_control_var_hpu.py`:

```python
var = build_control_var(vae=vqvae, depth=args.depth, patch_nums=args.v_patch_nums, mask_type=args.mask_type,
                        cond_drop_rate=1.1 if args.uncond else 0.1, bidirectional=args.bidirectional,
                        separate_decoding=args.separate_decoding, separator=args.separator, type_pos=args.type_pos,
                        indep=args.indep, multi_cond=args.multi_cond)
```

In `models/__init__.py`:

```python
def build_control_var(
    vae: VQVAE, depth: int,
    patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default
    aln=1, aln_gamma_init=1e-3, shared_aln=False, layer_scale=-1,
    tau=4, cos_attn=False,
    flash_if_available=True, fused_if_available=True,
    mask_type='replace', cond_drop_rate=0.1, bidirectional=False, separate_decoding=False, separator=False,
    type_pos=False, indep=False, multi_cond=False,
):
    if mask_type == 'replace':
        mask_factor = 1
    elif mask_type == 'interleave_append':
        mask_factor = 2
    else:
        raise NotImplementedError
    return MaskVAR(  # <- 'MaskVAR' is not defined
        vae_local=vae, patch_nums=patch_nums,
        depth=depth, embed_dim=depth*64, num_heads=depth, drop_path_rate=0.1 * depth/24,
        aln=aln, aln_gamma_init=aln_gamma_init, shared_aln=shared_aln, layer_scale=layer_scale,
        tau=tau, cos_attn=cos_attn, cond_drop_rate=cond_drop_rate,
        flash_if_available=flash_if_available, fused_if_available=fused_if_available, mask_factor=mask_factor,
        bidirectional=bidirectional, separate_decoding=separate_decoding, separator=separator, type_pos=type_pos,
        indep=indep, multi_cond=multi_cond,
    )
```

"MaskVAR" cannot be found.

lxa9867 commented 3 weeks ago

Sorry about that. That's a typo; please replace it with `ControlVAR`.
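Editing the one word by hand is all the fix requires. For anyone scripting the patch, here is a minimal sketch; `replace_class_name` is a hypothetical helper, and it assumes `ControlVAR` is already imported in `models/__init__.py` (the reporter's "It works" suggests it is).

```python
import re
from pathlib import Path


def replace_class_name(path: Path, old: str = "MaskVAR", new: str = "ControlVAR") -> int:
    """Rewrite whole-word occurrences of `old` as `new` in `path`; return how many were replaced.

    Hypothetical helper: assumes `new` is already imported in the target module.
    """
    text = path.read_text(encoding="utf-8")
    # \b word boundaries leave longer identifiers such as `MaskVARConfig` untouched
    patched, count = re.subn(rf"\b{re.escape(old)}\b", new, text)
    if count:
        # only touch the file when something actually changed
        path.write_text(patched, encoding="utf-8")
    return count
```

From the repository root, `replace_class_name(Path("models/__init__.py"))` should return `1` for the file quoted above, and `0` on any later run because the name is already fixed.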

firethatpotato commented 3 weeks ago


It works. Thanks for your reply!

lxa9867 commented 2 weeks ago

Cool! Closing this issue.