Mq-Zhang1 / HOIDiffusion

Official Code Release for HOIDiffusion (CVPR 2024)
MIT License

Using three adapter models at the same time gives poor results #4

Open LDS666888 opened 2 months ago

LDS666888 commented 2 months ago

The results are shown below: [images: ToyCar_0_2_6, ToyCar_0_2_4]

I noticed that the get_adapters function in inference_base.py builds the model with adapter['model'] = CoAdapter(w1 = 1, w2 = 1, w3 = 1).to(opt.device), so I modified the function as follows:

def get_adapters(opt, cond_type: ExtraCondition):
    # ExtraCondition, CoAdapter and read_state_dict are assumed to be imported at the top of inference_base.py
    adapter = {}
    # A per-condition weight ('<name>_weight') takes precedence over the global cond_weight
    cond_weight = getattr(opt, f'{cond_type.name}_weight', None)
    if cond_weight is None:
        cond_weight = getattr(opt, 'cond_weight')
    adapter['cond_weight'] = cond_weight

    # One CoAdapter holding three sub-adapters (pose, depth, mask), each weighted 1
    adapter['model'] = CoAdapter(w1=1, w2=1, w3=1).to(opt.device)

    # Released T2I-Adapter checkpoints, hard-coded instead of
    # getattr(opt, f'{cond_type.name}_adapter_ckpt', None)
    ckpt_paths = {
        'pose_ada.': "F:/data_enhancement/HOIDiffusion-main/midas_models/t2iadapter_openpose_sd14v1.pth",
        'depth_ada.': "F:/data_enhancement/HOIDiffusion-main/midas_models/t2iadapter_depth_sd14v1.pth",
        'mask_ada.': "F:/data_enhancement/HOIDiffusion-main/midas_models/t2iadapter_seg_sd14v1.pth",
    }

    new_state_dict = {}
    for prefix, ckpt_path in ckpt_paths.items():
        for k, v in read_state_dict(ckpt_path).items():
            # Strip a leading 'adapter.' if present
            if k.startswith('adapter.'):
                k = k[len('adapter.'):]
            # If a key lacks the sub-adapter prefix, add it manually
            if not k.startswith(prefix):
                k = prefix + k
            # Merge this checkpoint's entries into the combined state dict
            new_state_dict[k] = v

    adapter['model'].load_state_dict(new_state_dict)

    return adapter
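Checkpoint keys only load cleanly when they match the model's parameter names exactly. The remapping above can be sketched as small, testable helpers; the function names are hypothetical (not part of the repository), and plain dicts stand in for PyTorch state dicts so the logic can be checked without loading any weights:

```python
from typing import Dict, Iterable, List, Mapping, Tuple, TypeVar

V = TypeVar("V")


def remap_keys(state_dict: Mapping[str, V], prefix: str) -> Dict[str, V]:
    """Strip a leading 'adapter.' (if any), then add prefix (e.g. 'pose_ada.')."""
    out: Dict[str, V] = {}
    for k, v in state_dict.items():
        if k.startswith("adapter."):
            k = k[len("adapter."):]
        if not k.startswith(prefix):
            k = prefix + k
        out[k] = v
    return out


def merge_state_dicts(parts: Iterable[Mapping[str, V]]) -> Dict[str, V]:
    """Merge several remapped state dicts, refusing silent key collisions."""
    merged: Dict[str, V] = {}
    for part in parts:
        for k, v in part.items():
            if k in merged:
                raise KeyError(f"duplicate key after remapping: {k}")
            merged[k] = v
    return merged


def compare_keys(
    expected: Iterable[str], provided: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, similar to what strict loading reports."""
    exp, got = set(expected), set(provided)
    return sorted(exp - got), sorted(got - exp)
```

With PyTorch itself, model.load_state_dict(sd, strict=False) returns the missing_keys and unexpected_keys directly, which is a quick way to confirm that every pose_ada., depth_ada. and mask_ada. parameter was actually filled. Matching names only shows that the shapes and names line up; it says nothing about whether the weights suit the model.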

These are the arguments I used when debugging:

--which_cond dex
--bs 1
--cond_weight 1
--sd_ckpt "F:\data_enhancement\HOIDiffusion-main\stable_difussion_model\sd-v1-4.ckpt"
--cond_tau 1
--cond_inp_type image
--input "F:\data_enhancement\HOIDiffusion-main\output\depth"
--file "F:\data_enhancement\HOIDiffusion-main\output\train.csv"
--outdir "F:\data_enhancement\HOIDiffusion-main\test_dex_outdir"
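For reference, get_adapters reads the weight from a per-condition attribute (here dex_weight) and falls back to cond_weight when that attribute is absent. Assuming the parser defines no dex_weight option, the arguments above give the adapter a weight of 1. A minimal sketch of that lookup follows; the parser is hypothetical, declares only the flags listed above, and its types and defaults are assumptions, while ExtraConditionStub is a stand-in for the repository's ExtraCondition enum:

```python
import argparse
from enum import Enum


class ExtraConditionStub(Enum):
    """Hypothetical stand-in for the repository's ExtraCondition enum."""
    dex = 0


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser covering only the flags used above; types and
    # defaults are assumptions, not copied from the repository.
    p = argparse.ArgumentParser()
    p.add_argument("--which_cond", type=str)
    p.add_argument("--bs", type=int, default=1)
    p.add_argument("--cond_weight", type=float, default=1.0)
    p.add_argument("--sd_ckpt", type=str)
    p.add_argument("--cond_tau", type=float, default=1.0)
    p.add_argument("--cond_inp_type", type=str, default="image")
    p.add_argument("--input", type=str)
    p.add_argument("--file", type=str)
    p.add_argument("--outdir", type=str)
    return p


def resolve_cond_weight(opt: argparse.Namespace, cond_type: Enum) -> float:
    # Same lookup order as get_adapters: a per-condition '<name>_weight'
    # attribute wins, otherwise fall back to the global cond_weight.
    weight = getattr(opt, f"{cond_type.name}_weight", None)
    if weight is None:
        weight = getattr(opt, "cond_weight")
    return weight
```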

How can I fix this? Thank you.

Mq-Zhang1 commented 2 months ago

If I understand correctly, the condition models used in this script are the released T2I-Adapter checkpoints? The OpenPose keypoint arrangement may differ from ours; please refer to the DexYCB keypoint layout. Also, we trained with normal maps rather than depth maps, and our model was not trained on top of these released condition checkpoints, so it may not work when they are used directly.
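Since the model was trained on normal maps, depth inputs such as those in the --input folder would need converting first. As a rough sketch only, using a generic finite-difference approximation rather than HOIDiffusion's own preprocessing (which is not described here), per-pixel normals can be estimated from depth gradients and packed into RGB; normals_from_depth and to_rgb are hypothetical helper names:

```python
import math
from typing import List, Tuple

Normal = Tuple[float, float, float]


def normals_from_depth(depth: List[List[float]]) -> List[List[Normal]]:
    """Approximate per-pixel unit normals from a depth map.

    Generic approximation n ~ (-dz/dx, -dz/dy, 1), normalised. Interior pixels
    use central differences, border pixels one-sided differences.
    """
    h, w = len(depth), len(depth[0])
    out: List[List[Normal]] = []
    for y in range(h):
        row: List[Normal] = []
        for x in range(w):
            x0, x1 = max(x - 1, 0), min(x + 1, w - 1)
            y0, y1 = max(y - 1, 0), min(y + 1, h - 1)
            # max(..., 1) avoids division by zero on 1-pixel-wide or tall maps
            dzdx = (depth[y][x1] - depth[y][x0]) / max(x1 - x0, 1)
            dzdy = (depth[y1][x] - depth[y0][x]) / max(y1 - y0, 1)
            nx, ny, nz = -dzdx, -dzdy, 1.0
            norm = math.sqrt(nx * nx + ny * ny + nz * nz)
            row.append((nx / norm, ny / norm, nz / norm))
        out.append(row)
    return out


def to_rgb(n: Normal) -> Tuple[int, int, int]:
    """Map each component of a unit normal from [-1, 1] to an 8-bit value."""
    r, g, b = (int(round((c + 1.0) * 127.5)) for c in n)
    return (r, g, b)
```

Sign conventions (for example, whether +y points up or down in image space) differ between tools, so the output should be compared against a normal map from the training data before it is used as a condition.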

LDS666888 commented 2 months ago

Okay, thank you.