valeoai / STEEX

STEEX: Steering Counterfactual Explanations with Semantics

error on generate_counterfactuals #3

Closed: huifenghc closed this issue 3 months ago

huifenghc commented 11 months ago

Hello!

When I tried to reproduce the results on the CelebAMask-HQ dataset, I got an error when running `python generate_counterfactuals.py --dataset_name celebamhq --checkpoints_dir path/to/checkpoints/ --dataroot path/to/dataset/ --name_exp exp_celebamhq --target_attribute 1`.

My error:

```
Traceback (most recent call last):
  File "generate_counterfactuals.py", line 109, in <module>
    reconstructed_query_image = generator(data_i, mode='inference').detach().cpu().float().numpy()
  File "/hd12/anaconda3/envs/STEEX/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/hd12/projects/STEEX-main/models/pix2pix_model.py", line 42, in forward
    input_semantics, real_image = self.preprocess_input(data)
  File "/hd12/projects/STEEX-main/models/pix2pix_model.py", line 150, in preprocess_input
    input_semantics = input_label.scatter_(1, label_map, 1.0)
RuntimeError: index 19 is out of bounds for dimension 1 with size 19
```

Here are the shapes and contents of input_label and label_map:

```
input_label.shape: torch.Size([1, 19, 256, 256])
label_map.shape: torch.Size([1, 1, 256, 256])
label_map: tensor([[[[19, 19, 19, ..., 19, 19, 19],
          [19, 19, 19, ..., 19, 19, 19],
          [19, 19, 19, ..., 19, 19, 19],
          ...,
          [19, 19, 19, ..., 12, 12, 12],
          [19, 19, 19, ..., 12, 12, 12],
          [19, 19, 19, ..., 12, 12, 12]]]])
```

I guess something is wrong with label_map, but I'm not sure what is causing it.

Looking forward to your help.
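For context, the failing `scatter` call in `preprocess_input` builds a one-hot tensor with one channel per semantic class: at each pixel, the label value selects the channel that receives 1.0. With 19 channels only labels 0 to 18 have a channel, so the pixels labelled 19 in `label_map` cannot be encoded. The sketch below is a NumPy analogue, not the STEEX code; `one_hot_labels` is a hypothetical helper that performs the same encoding with an explicit bounds check:

```python
import numpy as np


def one_hot_labels(label_map: np.ndarray, n_classes: int) -> np.ndarray:
    """One-hot encode an integer label map of shape (H, W) into (n_classes, H, W).

    Hypothetical helper: same idea as scatter-ing 1.0 along the class dimension,
    where each pixel's value is the index of the channel set to 1.0.
    """
    lo, hi = int(label_map.min()), int(label_map.max())
    if lo < 0 or hi >= n_classes:
        raise IndexError(
            f"labels span {lo}..{hi}, but only 0..{n_classes - 1} fit in {n_classes} channels"
        )
    # np.eye(n)[labels] has shape (H, W, n); move the class axis to the front.
    return np.eye(n_classes, dtype=np.float32)[label_map].transpose(2, 0, 1)


ok = np.array([[0, 12], [18, 12]])
print(one_hot_labels(ok, 19).shape)  # (19, 2, 2)

bad = np.array([[19, 19], [19, 12]])
try:
    one_hot_labels(bad, 19)
except IndexError as err:
    print(err)  # labels span 12..19, but only 0..18 fit in 19 channels
```

The same `bad` map encodes without error with `n_classes=20`, which is why the number of classes in the masks and the number of channels the model expects have to agree.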

EloiZ commented 11 months ago

Hello,

There are 19 semantic classes for CelebA / CelebAMask-HQ and 20 semantic classes for BDD. It seems that the masks you have for CelebAMask-HQ contain up to 20 semantic classes. Can you check on that front, please?

Best,

huifenghc commented 11 months ago

Thank you for your reply.

The semantic_n parameter for the CelebAMask-HQ dataset in advanced_options.py is set to 19, and I did not change this setting.

Could the problem come from how I generated predicted_masks with infer_masks.py? Or do you know of any other possible causes?

EloiZ commented 11 months ago

Indeed, I think the problem comes from the inferred masks, which must contain at most 19 classes (values from 0 to 18) for CelebA / CelebAMask-HQ images. Did you uncomment the config block corresponding to CelebAMask-HQ in infer_masks.py (the default one is for BDD)? Did you visualize the masks?
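A quick way to check the masks is to list the label values each one contains: for CelebA / CelebAMask-HQ, anything outside 0 to 18 will trigger the same `scatter` error. This is a sketch under assumptions: it expects single-channel PNG masks whose pixel values are class indices, it assumes Pillow is installed for loading, and `out_of_range_labels` / `scan_mask_dir` are hypothetical helpers, not part of STEEX. A color-coded visualization saved in place of a label mask would also be flagged, since its pixel values go up to 255.

```python
from pathlib import Path

import numpy as np


def out_of_range_labels(mask: np.ndarray, n_classes: int) -> list:
    """Return the sorted label values in `mask` that fall outside 0..n_classes-1."""
    values = np.unique(mask)  # sorted unique pixel values
    return [int(v) for v in values if v < 0 or v >= n_classes]


def scan_mask_dir(mask_dir: str, n_classes: int = 19) -> dict:
    """Map each PNG in mask_dir that has bad labels to the list of those labels.

    Assumption: masks are PNG files whose pixel values are class indices.
    """
    from PIL import Image  # Pillow, imported lazily so the check above works without it

    problems = {}
    for path in sorted(Path(mask_dir).glob("*.png")):
        mask = np.array(Image.open(path))
        bad = out_of_range_labels(mask, n_classes)
        if bad:
            problems[path.name] = bad
    return problems


print(out_of_range_labels(np.array([[19, 19], [12, 0]]), 19))  # [19]
```

For example, `scan_mask_dir("path/to/predicted_masks")` (placeholder path) would return an empty dict if every mask only uses labels 0 to 18; `np.unique(mask)` on a single file is also a cheap way to see which classes a mask actually contains.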

huifenghc commented 11 months ago

When generating predicted_masks, I had already uncommented the code for CelebAMask-HQ.

The settings are as follows:

```python
# FOR CELEBAMASK-HQ
dataroot = '/hd12/projects/STEEX-main/path/to/dataset/CelebAMask-HQ'
segmentation_network_name = 'deeplabv3_celebamhq'
dataset_mode = 'celebamhq'
save_dir_masks = '/hd12/projects/STEEX-main/path/to/dataset/CelebAMask-HQ/CelebAMask-HQ/test/predicted_masks'
n_classes = 19
```

predicted_masks: 1000 masks

Also, after running g mask.py, do I still need to run v mask.py? (To preprocess the masks for CelebAMask-HQ, please follow the instructions provided here