MichalZnalezniak / Contrastive-Hierarchical-Clustering

This is the official code for the paper Contrastive Hierarchical Clustering (ECML PKDD 2023).

Mask type error #9

Closed: kroegern1 closed this issue 8 months ago.

kroegern1 commented 8 months ago

https://github.com/MichalZnalezniak/Contrastive-Hierarchical-Clustering/blob/2c418fcc763081e89073d83d648308d277b1fd10/tree_losses.py#L31C1-L31C28

Hello, I tried running the code but I get this error:

Train Epoch: [1/4] Loss: 5.3809: 100%|██████████| 195/195 [02:26<00:00,  1.34it/s]
Train Epoch: [2/4] Loss: 5.1769: 100%|██████████| 195/195 [02:31<00:00,  1.29it/s]
  0%|          | 0/195 [00:33<?, ?it/s]
Traceback (most recent call last):
  File "/Contrastive-Hierarchical-Clustering-main/main.py", line 195, in <module>
    total_loss, tree_loss_train, reg_loss_train, simclr_loss_train = train(model, train_loader, optimizer, epoch, cfg, device)
  File "/Contrastive-Hierarchical-Clustering-main/main.py", line 47, in train
    tree_loss_value = tree_loss(tree_output1, tree_output2, batch_size, net.masks_for_level, mean_of_probs_per_level_per_epoch, cfg.tree.tree_level, device)
  File "/Contrastive-Hierarchical-Clustering-main/tree_losses.py", line 36, in tree_loss
    labels = labels * ~mask_for_level
TypeError: bad operand type for unary ~: 'dict'

What is the purpose of this mask here? The mask is a dict, what is the intent?

MichalZnalezniak commented 8 months ago

Hi @kroegern1, the mask from line 30 is a tensor. For batch_size 4 it looks like this:

        [[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.]]

It is used to distinguish the indices of positive and negative pairs when calculating the hierarchical clustering loss in lines 41 and 43: https://github.com/MichalZnalezniak/Contrastive-Hierarchical-Clustering/blob/main/tree_losses.py#L24C1-L43C198
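For reference, a mask with the layout shown above can be sketched in plain Python. This is a minimal illustration, not the repository's code: the helper names positive_pair_mask and keep_negatives are hypothetical, the sketch assumes rows 0..B-1 hold the first augmented view of each image and rows B..2B-1 the second, and it uses 1 - m on plain floats where the repository applies ~ to a tensor.

```python
def positive_pair_mask(batch_size):
    """Hypothetical helper: 2B x 2B mask with 1.0 where the row and column are
    the two augmented views of the same image, 0.0 everywhere else."""
    n = 2 * batch_size
    mask = [[0.0] * n for _ in range(n)]
    for i in range(batch_size):
        mask[i][i + batch_size] = 1.0  # view 1 of image i -> view 2 of image i
        mask[i + batch_size][i] = 1.0  # view 2 of image i -> view 1 of image i
    return mask


def keep_negatives(values, mask):
    """Hypothetical helper: zero the positive-pair entries and keep the rest.
    Plays the role of `values * ~mask` for a 0/1 float mask."""
    return [[v * (1.0 - m) for v, m in zip(v_row, m_row)]
            for v_row, m_row in zip(values, mask)]


mask = positive_pair_mask(4)
for row in mask:
    print(row)
```

Each row has exactly one positive partner, so after keep_negatives every row retains the 2B - 1 remaining entries as candidate negatives (the diagonal is not masked in this sketch, matching the matrix above).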

mask_for_level from line 36 is a dict used for pruning. We describe the pruning strategy in Section 3.5 of our paper: https://arxiv.org/pdf/2303.03389.pdf
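The pruning mask is a different kind of object. As a rough sketch, assuming each dict entry maps a tree level to one 0/1 weight per node at that level (the exact layout in tree_losses.py may differ, and prune_level is a hypothetical name), indexing the dict by level and multiplying zeroes out the probabilities of pruned nodes, in the spirit of prob_features = prob_features * mask_for_level[level]:

```python
# Hypothetical per-level pruning masks: tree level -> one weight per node.
masks_for_level = {
    1: [1.0, 1.0],            # level 1: both nodes kept
    2: [1.0, 1.0, 1.0, 0.0],  # level 2: last node pruned
}


def prune_level(prob_features, level, masks_for_level):
    """Multiply each sample's node probabilities by the mask for `level`."""
    mask = masks_for_level[level]  # index the dict first; the dict itself is not a mask
    return [[p * m for p, m in zip(row, mask)] for row in prob_features]


batch_probs = [[0.1, 0.2, 0.3, 0.4],
               [0.4, 0.3, 0.2, 0.1]]
print(prune_level(batch_probs, 2, masks_for_level))
# [[0.1, 0.2, 0.3, 0.0], [0.4, 0.3, 0.2, 0.0]]
```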

MichalZnalezniak commented 8 months ago

Have you made any changes to the code? The error message shows that the failure happens on line 36, labels = labels * ~mask_for_level, which looks very similar to line 31, labels = labels * ~mask. On the main branch, however, line 36 is prob_features = prob_features * mask_for_level[level].
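The traceback follows directly from Python semantics: a dict defines no unary ~ (it has no __invert__ method), so ~mask_for_level fails before any tensor arithmetic happens. A minimal stdlib-only reproduction, where try_invert is a hypothetical helper and the dict contents are made up to stand in for the real per-level masks:

```python
def try_invert(obj):
    """Return ~obj, or the TypeError message if obj does not support unary ~."""
    try:
        return ~obj
    except TypeError as exc:
        return str(exc)


# Hypothetical contents; only the fact that it is a dict matters here.
masks_for_level = {1: [1.0, 1.0], 2: [1.0, 1.0, 1.0, 0.0]}

print(try_invert(masks_for_level))  # bad operand type for unary ~: 'dict'
print(try_invert(5))                # -6 (ints do support ~)
```

This is why the autocompleted line breaks while line 31 works: ~ is applied to the dict itself instead of to a tensor-like mask.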

kroegern1 commented 8 months ago

Wow, great catch, thank you so much. It autocompleted by accident when I hit Tab at the end of the line, turning "labels = labels * ~mask" into "labels = labels * ~mask_for_level". I'll run it again. Thank you for getting back to me so quickly!

MichalZnalezniak commented 8 months ago

Please let me know if the problem has been resolved. If it has, can we close the issue?