FusionBrainLab / HairFastGAN

Official Implementation for "HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach"
https://airi-institute.github.io/HairFastGAN/
MIT License
413 stars, 60 forks

KeyError: 'ACE' #9

Closed LinMu7177 closed 4 months ago

LinMu7177 commented 4 months ago

First of all, thank you for open-sourcing such an outstanding project. When I tried to run the blending_gen.py script to generate data, I encountered the following error:

/home/wenwen/anaconda3/envs/hair/bin/python /home/wenwen/workspace/HairFastGAN/scripts/blending_gen.py
Loading StyleGAN2 from checkpoint: ../pretrained_models/StyleGAN/ffhq.pt
Loading e4e over the pSp framework from checkpoint: ../pretrained_models/encoder4editing/e4e_ffhq_encode.pt
Network [SPADEGenerator] was created. Total number of parameters: 266.9 million. To see the architecture, do print(network).
0it [00:01, ?it/s]
Traceback (most recent call last):
  File "/home/wenwen/workspace/HairFastGAN/scripts/blending_gen.py", line 80, in <module>
    main(args)
  File "/home/wenwen/workspace/HairFastGAN/scripts/blending_gen.py", line 64, in main
    align_shape, align_color, name_to_embed = hair_fast(pt1, pt2, pt3, align_flag=True)
  File "/home/wenwen/workspace/HairFastGAN/hair_swap.py", line 105, in __call__
    return self.swap(*args, **kwargs)
  File "/home/wenwen/workspace/HairFastGAN/hair_swap.py", line 97, in swap
    final_image = self.__swap_from_tensors(*images, seed=seed, benchmark=benchmark, exp_name=exp_name, **kwargs)
  File "/home/wenwen/workspace/HairFastGAN/utils/seed.py", line 28, in wraps
    result = func(*args, **kwargs)
  File "/home/wenwen/workspace/HairFastGAN/utils/time.py", line 34, in wraps
    return func(*args, **kwargs)
  File "/home/wenwen/workspace/HairFastGAN/hair_swap.py", line 51, in __swap_from_tensors
    align_shape = self.align.align_images('face', 'shape', name_to_embed, **kwargs)
  File "/home/wenwen/workspace/HairFastGAN/scripts/blending_gen.py", line 35, in wrapper
    return func(*args, **kwargs)
  File "/home/wenwen/anaconda3/envs/hair/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/wenwen/workspace/HairFastGAN/models/Alignment.py", line 130, in align_images
    gen1_sean = decode_sean(self.sean_model, img1_code.unsqueeze(0), target_mask)
  File "/home/wenwen/workspace/HairFastGAN/models/sean_codes/models/pix2pix_model.py", line 324, in decode_sean
    generated = sean_model(data, mode='UI_mode')[0] # [3, 256, 256]
  File "/home/wenwen/anaconda3/envs/hair/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wenwen/workspace/HairFastGAN/models/sean_codes/models/pix2pix_model.py", line 45, in forward
    input_semantics, real_image = self.preprocess_input(data)
  File "/home/wenwen/workspace/HairFastGAN/models/sean_codes/models/pix2pix_model.py", line 133, in preprocess_input
    if data['obj_dic'][str(idx)]['ACE'].device.type == 'cpu':
KeyError: 'ACE'

Process finished with exit code 1

As shown in the figure, further investigation revealed that after executing img1_code, img2_code = encode_sean(self.sean_model, images, labels), img1_code contains rows that are entirely zeros. For those rows, the check in decode_sean(), if not torch.all(cur_code == 0): obj_dic[str(idx)]['ACE'] = cur_code, does not pass, so 'ACE' is never set for the corresponding index.

Finally, the KeyError is raised in preprocess_input() when it evaluates if data['obj_dic'][str(idx)]['ACE'].device.type == 'cpu':.
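To make the failure mode above concrete, here is a minimal sketch that uses plain Python lists in place of tensors. The function name missing_ace_indices and the sample data are hypothetical and not part of the repository; the sketch only assumes that an index fails when its row is all zeros and no 'ACE' entry was preloaded for it.

```python
# Plain-Python stand-in for the logic described above; the real code uses torch tensors.
def missing_ace_indices(obj_dic, image_code):
    """Return the region indices that would hit KeyError: 'ACE' downstream."""
    missing = []
    for idx, cur_code in enumerate(image_code):
        # An all-zero row is skipped by the check in decode_sean().
        is_all_zero = all(value == 0 for value in cur_code)
        # Without a preloaded entry, nothing ever sets 'ACE' for this index.
        has_default = 'ACE' in obj_dic.get(str(idx), {})
        if is_all_zero and not has_default:
            missing.append(idx)
    return missing


# Hypothetical data: region 1 has an all-zero code and no preloaded 'ACE'.
obj_dic = {'0': {'ACE': [0.5, 0.5]}, '1': {}, '2': {}}
image_code = [[0.0, 0.0], [0.0, 0.0], [0.3, 0.1]]
print(missing_ace_indices(obj_dic, image_code))  # prints [1]
```

Region 0 is also all zeros but survives because a default is present, which suggests the missing defaults, rather than the zero rows themselves, are what trigger the error.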

How can I fix this bug?

Thanks!

[image]

LinMu7177 commented 4 months ago

I have removed the condition checking for cur_code == 0 in the decode_sean() method. Is this appropriate? I would appreciate any advice!

def decode_sean(sean_model, image_code, target_mask):
    obj_dic = load_average_feature()

    for idx in range(19):
        cur_code = image_code[0, idx]
        # if not torch.all(cur_code == 0):
        #     obj_dic[str(idx)]['ACE'] = cur_code
        obj_dic[str(idx)]['ACE'] = cur_code

    temp_face_image = torch.zeros((0, 3, 256, 256))  # place holder

    data = {'label': target_mask,
            'instance': torch.tensor(0),
            'image': temp_face_image.clone().detach(),
            'obj_dic': obj_dic}
    change_status(sean_model, 'UI_mode')
    generated = sean_model(data, mode='UI_mode')[0]  # [3, 256, 256]
    return generated

maximkm commented 4 months ago

Hello, it seems that the error occurs because load_average_feature() does not load median latents in decode_sean. These latents were added by the second commit: https://github.com/AIRI-Institute/HairFastGAN/commit/fc6d9e4e0a04fd5a3084f45e67972df65423f08d.

Possibly you just have an outdated version of the repository?
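One way to check for an outdated checkout is to ask git whether the commit linked above is an ancestor of the current HEAD. The sketch below is only an illustration: the helper name checkout_contains and the injectable run parameter are hypothetical, while git merge-base --is-ancestor is a standard git command that exits with status 0 when the commit is an ancestor.

```python
import subprocess

# Commit that added the median latents, as linked in this thread.
MEDIAN_LATENTS_COMMIT = "fc6d9e4e0a04fd5a3084f45e67972df65423f08d"


def checkout_contains(commit, repo_dir=".", run=subprocess.run):
    """Return True if `commit` is an ancestor of HEAD in `repo_dir`."""
    # Exit status 0 means the commit is part of the current history;
    # any other status (not an ancestor, or unknown commit) is treated as False.
    result = run(
        ["git", "merge-base", "--is-ancestor", commit, "HEAD"],
        cwd=repo_dir,
    )
    return result.returncode == 0
```

Calling checkout_contains(MEDIAN_LATENTS_COMMIT) from the repository root should return False on a checkout that predates that commit, in which case pulling the latest version is the first thing to try.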

This function is responsible for inpainting after hair transfer, so it is very important, and you should definitely not remove the condition check.
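As a rough illustration of why the check matters, the sketch below compares the two behaviours with plain Python lists instead of tensors. It assumes, based on this thread, that the preloaded latents act as a fallback for regions whose encoded code is all zeros; merge_codes, keep_check, and the sample values are hypothetical.

```python
def merge_codes(defaults, image_code, keep_check=True):
    """Combine preloaded per-region codes with codes encoded from an image."""
    merged = {key: dict(value) for key, value in defaults.items()}
    for idx, cur_code in enumerate(image_code):
        is_all_zero = all(value == 0 for value in cur_code)
        if keep_check and is_all_zero:
            continue  # keep the preloaded fallback for this region
        merged.setdefault(str(idx), {})['ACE'] = cur_code
    return merged


# Hypothetical data: region 1 is absent from the image, so its code is all zeros.
defaults = {'0': {'ACE': [0.5, 0.5]}, '1': {'ACE': [0.2, 0.7]}}
image_code = [[0.9, 0.1], [0.0, 0.0]]

with_check = merge_codes(defaults, image_code, keep_check=True)
without_check = merge_codes(defaults, image_code, keep_check=False)
print(with_check['1']['ACE'])     # prints [0.2, 0.7]
print(without_check['1']['ACE'])  # prints [0.0, 0.0]
```

With the check removed, the all-zero code silently replaces the fallback, so the KeyError disappears but the region is presumably generated from a meaningless style code.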

maximkm commented 4 months ago

The error can also occur if you run blending_gen.py from inside the scripts folder.

All scripts should be run from the root folder of the project, for example: python scripts/blending_gen.py
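A small guard can catch the wrong working directory before any models are loaded. This is only a sketch: looks_like_project_root is a hypothetical helper, and it assumes that hair_swap.py and the scripts folder sit directly in the project root, as the paths in the traceback suggest.

```python
from pathlib import Path


def looks_like_project_root(directory="."):
    """Heuristic: the root holds hair_swap.py and a scripts/ folder."""
    root = Path(directory)
    return (root / "hair_swap.py").is_file() and (root / "scripts").is_dir()
```

A script could call looks_like_project_root() at startup and exit with a clear message when it returns False, instead of failing later with an unrelated KeyError.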