lightly-ai / lightly

A python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License

ValueError: n_samples=128 should be >= n_clusters=300. RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x128 and 76800x128). #1642

Open CharisWg opened 1 week ago

CharisWg commented 1 week ago

I implemented SMoG (Smog.py) and set self.n_groups = 300. Training then failed at the 4th epoch with the following error:

    Traceback (most recent call last):
      File "/datashare3/charis/code2/SSL-yolo8/lightly-master/examples/pytorch/smog_yolo.py", line 230, in
        smogmodel.reset_group_features(memory_bank=memory_bank)
      File "/datashare3/charis/code2/SSL-yolo8/lightly-master/examples/pytorch/smog_yolo.py", line 59, in reset_group_features
        group_features = self._cluster_features(features.t())
      File "/datashare3/charis/code2/SSL-yolo8/lightly-master/examples/pytorch/smog_yolo.py", line 51, in _cluster_features
        kmeans = KMeans(self.n_groups).fit(features)
      File "/datashare3/charis/anaconda/envs/yolov8SSL/lib/python3.9/site-packages/sklearn/base.py", line 1473, in wrapper
        return fit_method(estimator, *args, **kwargs)
      File "/datashare3/charis/anaconda/envs/yolov8SSL/lib/python3.9/site-packages/sklearn/cluster/_kmeans.py", line 1473, in fit
        self._check_params_vs_input(X)
      File "/datashare3/charis/anaconda/envs/yolov8SSL/lib/python3.9/site-packages/sklearn/cluster/_kmeans.py", line 1414, in _check_params_vs_input
        super()._check_params_vs_input(X, default_n_init=10)
      File "/datashare3/charis/anaconda/envs/yolov8SSL/lib/python3.9/site-packages/sklearn/cluster/_kmeans.py", line 878, in _check_params_vs_input
        raise ValueError(
    ValueError: n_samples=128 should be >= n_clusters=300.
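To see why KMeans reports only 128 samples, here is a minimal NumPy sketch (an illustration, not the example's actual code). It assumes the sizes read off the tracebacks (128-dimensional features, 76800 memory-bank entries, n_groups = 300) and mimics the features.t() call made before clustering; kmeans_input and can_fit are hypothetical helpers, the latter standing in for scikit-learn's requirement that n_samples >= n_clusters.

```python
import numpy as np

# Sizes assumed from the tracebacks in this issue.
DIM, BANK_SIZE, N_GROUPS = 128, 76800, 300


def kmeans_input(bank: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for `_cluster_features(features.t())`:
    the bank is transposed before clustering, so the code implicitly
    expects a (dim, size) bank and hands KMeans (size, dim) rows."""
    return bank.T


def can_fit(x: np.ndarray, n_clusters: int) -> bool:
    # KMeans treats rows as samples and needs at least n_clusters of them.
    return x.shape[0] >= n_clusters


# Bank stored as (size, dim): after the transpose, KMeans sees 128 samples.
wrong = kmeans_input(np.zeros((BANK_SIZE, DIM), dtype=np.float32))
# Bank stored as (dim, size): KMeans sees 76800 samples of 128 features.
right = kmeans_input(np.zeros((DIM, BANK_SIZE), dtype=np.float32))

print(wrong.shape, can_fit(wrong, N_GROUPS))  # (128, 76800) False
print(right.shape, can_fit(right, N_GROUPS))  # (76800, 128) True
```

Lowering n_groups to 128 only hides the symptom: the 128 "samples" are still feature dimensions, not bank entries.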

After changing it to self.n_groups = 128, a different error occurred:

    Traceback (most recent call last):
      File "/datashare3/charis/code2/SSL-yolo8/lightly-master/examples/pytorch/smog_yolo.py", line 243, in
        assignments = smogmodel.smog.assign_groups(x1_encoded)
      File "/datashare3/charis/anaconda/envs/yolov8SSL/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
        return func(*args, **kwargs)
      File "/datashare3/charis/code2/SSL-yolo8/lightly-master/lightly/models/modules/heads.py", line 401, in assign_groups
        return torch.argmax(self.forward(x, self.group_features), dim=-1)
      File "/datashare3/charis/code2/SSL-yolo8/lightly-master/lightly/models/modules/heads.py", line 362, in forward
        logits = torch.mm(x, group_features.t())
    RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x128 and 76800x128).
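The second error follows from the first. The NumPy sketch below reproduces the same contraction under the sizes assumed from the traceback (batch of 256, 128-dimensional embeddings, bank of 76800 entries); logits is a hypothetical stand-in for the torch.mm call in heads.py, and NumPy raises ValueError where PyTorch raises RuntimeError. Centroids fitted on the transposed bank are 76800-dimensional, so they cannot be multiplied with 128-dimensional embeddings.

```python
import numpy as np

# Sizes assumed from the traceback in this issue.
BATCH, DIM, BANK_SIZE = 256, 128, 76800

x = np.zeros((BATCH, DIM), dtype=np.float32)  # encoded batch, (256, 128)


def logits(x: np.ndarray, group_features: np.ndarray) -> np.ndarray:
    # Same contraction as `torch.mm(x, group_features.t())`.
    return x @ group_features.T


# Clustering the transposed bank with n_groups=128 yields 128 "centroids"
# that each have 76800 entries instead of 128.
bad_groups = np.zeros((128, BANK_SIZE), dtype=np.float32)
try:
    logits(x, bad_groups)  # (256x128) @ (76800x128): 128 != 76800
    mismatch = False
except ValueError:
    mismatch = True

# With a correctly oriented bank, centroids live in the 128-d feature space.
good_groups = np.zeros((300, DIM), dtype=np.float32)
out = logits(x, good_groups)
print(mismatch, out.shape)  # True (256, 300)
```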

guarin commented 1 week ago

Hi, could you share the relevant code?

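It looks like the memory bank features need to be transposed. Your code clusters features.t(), which assumes the bank is stored as 128x76800 (feature dimension first); yours appears to be 76800x128, so KMeans receives 128 samples instead of 76800. You can change the layout with the feature_dim_first argument, i.e. MemoryBankModule(..., feature_dim_first=True) or feature_dim_first=False, depending on how your code reads the bank.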
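Another option is to make the clustering code independent of the bank layout by normalizing the bank to one row per entry before calling KMeans. The sketch below assumes a NumPy array in place of the memory bank tensor; bank_as_samples is a hypothetical helper, and its flag only mirrors the meaning of feature_dim_first (bank stored as dim x size), not lightly's implementation.

```python
import numpy as np


def bank_as_samples(bank: np.ndarray, feature_dim_first: bool) -> np.ndarray:
    """Hypothetical helper: return memory-bank entries as rows of shape
    (n_samples, dim), regardless of how the bank is stored."""
    return bank.T if feature_dim_first else bank


# Sizes assumed from the tracebacks in this issue.
DIM, BANK_SIZE, N_GROUPS = 128, 76800, 300

bank_dim_first = np.zeros((DIM, BANK_SIZE), dtype=np.float32)
bank_size_first = np.zeros((BANK_SIZE, DIM), dtype=np.float32)

for bank, flag in [(bank_dim_first, True), (bank_size_first, False)]:
    samples = bank_as_samples(bank, feature_dim_first=flag)
    # Both layouts give 76800 samples of 128 features: enough for 300 clusters.
    assert samples.shape == (BANK_SIZE, DIM)
    assert samples.shape[0] >= N_GROUPS
```

Whichever layout the bank uses, the KMeans input should end up as (number of bank entries, feature dim), so that the resulting group features have shape (n_groups, 128) and the logits have shape (batch, n_groups).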