WenkeHuang / RethinkFL

CVPR2023 - Rethinking Federated Learning with Domain Shift: A Prototype View

Crash at hierarchical_info_loss on office dataset #4

Open smart0eddie opened 3 months ago

smart0eddie commented 3 months ago

Hi, the crash happens at the first line of `hierarchical_info_loss`:

```python
f_pos = np.array(all_f)[all_global_protos_keys == label.item()][0].to(self.device)
```

The debugger reports:

```
Exception has occurred: ValueError
setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (10,) + inhomogeneous part.
```

I think this happens because different classes end up with different numbers of clusters:

```
Partition 0: 3 clusters
Partition 0: 2 clusters
Partition 0: 3 clusters
Partition 0: 3 clusters
Partition 0: 2 clusters
Partition 0: 3 clusters
Partition 0: 3 clusters
Partition 0: 2 clusters
Partition 0: 1 clusters
Partition 0: 2 clusters
```

`fl_digits` also runs into the same problem:

```
Partition 0: 3 clusters
Partition 0: 2 clusters
Partition 0: 4 clusters
Partition 0: 4 clusters
Partition 0: 4 clusters
Partition 0: 2 clusters
Partition 0: 5 clusters
Partition 1: 2 clusters
Partition 0: 2 clusters
Partition 0: 4 clusters
Partition 1: 2 clusters
Partition 0: 3 clusters
```

Could you help me with this issue? Thanks!

bzHunter commented 1 month ago

I'm having the same problem. Have you solved it?

smart0eddie commented 1 month ago

Sadly, no. The other three comparison methods work fine.

bzHunter commented 1 month ago

I solved it! This Stack Overflow question may help: https://stackoverflow.com/questions/18665873/filtering-a-list-based-on-a-list-of-booleans
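The linked question is about selecting items from a plain Python list with a list of booleans, which avoids building `np.array(all_f)` at all. A minimal sketch of that idea follows; `keys`, `protos` and the cluster counts are hypothetical stand-ins, not the repository's actual variables.

```python
from itertools import compress

import numpy as np

# Hypothetical stand-ins: class keys plus one ragged prototype array per class,
# each filled with its class id so the selection is easy to check.
keys = np.array([0, 1, 2, 3])
protos = [np.full((k, 4), float(c)) for c, k in zip(keys, [3, 2, 3, 1])]

label = 2
mask = keys == label  # boolean NumPy array, one entry per class

# Filter the Python list directly; no ragged np.array() is ever built.
pos = list(compress(protos, mask))
neg = [p for p, keep in zip(protos, mask) if not keep]

print(pos[0].shape)               # (3, 4)
print(np.concatenate(neg).shape)  # (6, 4): 3 + 2 + 1 negative prototypes
```

The selected items are still individual arrays (or tensors), so they can be moved to the device or concatenated afterwards.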

smart0eddie commented 1 month ago

Could you share how you modified the code? Do you first select the elements from the list and then convert them to tensors?

Thanks!

smart0eddie commented 2 weeks ago

This should work. I look up the index of the matching prototype separately (strictly speaking, the prototypes should be sorted before use):

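```python
# Index of the prototype entry whose key matches this sample's label
f_idx = np.where(all_global_protos_keys == label.item())[0][0]
f_pos = all_f[f_idx].to(self.device)  # cluster prototypes of the label's class
# Negatives: concatenate every other class's prototypes along dim 0
f_neg = torch.cat([f for i, f in enumerate(all_f) if i != f_idx]).to(self.device)

mean_f_pos = mean_f[f_idx].to(self.device)  # mean prototype of the label's class
```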
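An alternative that does not depend on the order of the key list is to keep the prototypes in a dictionary keyed by class and look them up directly. This is only a sketch: `split_prototypes` and `protos_by_class` are hypothetical names, NumPy arrays stand in for the tensors, and with PyTorch tensors `torch.cat` would replace `np.concatenate`.

```python
import numpy as np


def split_prototypes(protos_by_class, label):
    """Return (positive, negative) prototypes for `label`.

    protos_by_class maps class id -> (num_clusters, dim) array; the
    cluster count may differ per class. Assumes at least two classes.
    """
    if label not in protos_by_class:
        raise KeyError(f"no global prototypes for class {label}")
    positive = protos_by_class[label]
    # Sorted keys give a deterministic order for the negatives.
    negatives = [protos_by_class[c] for c in sorted(protos_by_class) if c != label]
    return positive, np.concatenate(negatives, axis=0)


dim = 4  # hypothetical feature dimension
protos_by_class = {c: np.random.rand(k, dim) for c, k in enumerate([3, 2, 3, 1])}
f_pos, f_neg = split_prototypes(protos_by_class, label=2)
print(f_pos.shape, f_neg.shape)  # (3, 4) (6, 4)
```

Since `label.item()` returns a plain Python int for an integer tensor, it can be passed as `label` directly. A missing class raises a clear `KeyError` instead of the `IndexError` that `np.where(...)[0][0]` gives on an empty match.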