CharisWg opened 1 week ago
Hi, could you share the relevant code?
It looks like the memory bank features need to be transposed. The shapes should be 256x128 and 128x76800. You can transpose them by setting MemoryBank(feature_dim_first=False) or feature_dim_first=True, depending on your code.
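A minimal sketch of what the orientation flag is about, assuming (as the tracebacks below suggest) a bank of 76,800 stored entries with 128-dimensional features, and assuming feature_dim_first=True means a (dim, size) layout while False means (size, dim); check the lightly docs for your version. The helper `bank_as_rows` is hypothetical, not part of lightly. NumPy stands in for torch here; `.T` behaves the same way on a 2-D `torch.Tensor`.

```python
import numpy as np


def bank_as_rows(bank: np.ndarray, feature_dim_first: bool) -> np.ndarray:
    """Return memory bank features as (size, dim): one row per stored sample.

    Hypothetical helper. Assumes feature_dim_first=True means the bank is
    laid out as (dim, size) and False means (size, dim).
    """
    if bank.ndim != 2:
        raise ValueError(f"expected a 2-D bank, got shape {bank.shape}")
    # Transpose only when the feature axis comes first.
    return bank.T if feature_dim_first else bank


# Toy sizes mirroring the issue: 76800 stored samples, 128-dim features.
size, dim = 76800, 128
bank_dim_first = np.zeros((dim, size), dtype=np.float32)
bank_size_first = np.zeros((size, dim), dtype=np.float32)

print(bank_as_rows(bank_dim_first, feature_dim_first=True).shape)    # (76800, 128)
print(bank_as_rows(bank_size_first, feature_dim_first=False).shape)  # (76800, 128)
```

Either way, the array handed to clustering ends up with one row per stored sample.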
I implemented Smog.py, set self.n_groups = 300, and ran it; at the 4th epoch I got the following error:

Traceback (most recent call last):
  File "/datashare3/charis/code2/SSL-yolo8/lightly-master/examples/pytorch/smog_yolo.py", line 230, in <module>
    smogmodel.reset_group_features(memory_bank=memory_bank)
  File "/datashare3/charis/code2/SSL-yolo8/lightly-master/examples/pytorch/smog_yolo.py", line 59, in reset_group_features
    group_features = self._cluster_features(features.t())
  File "/datashare3/charis/code2/SSL-yolo8/lightly-master/examples/pytorch/smog_yolo.py", line 51, in _cluster_features
    kmeans = KMeans(self.n_groups).fit(features)
  File "/datashare3/charis/anaconda/envs/yolov8SSL/lib/python3.9/site-packages/sklearn/base.py", line 1473, in wrapper
    return fit_method(estimator, *args, **kwargs)
  File "/datashare3/charis/anaconda/envs/yolov8SSL/lib/python3.9/site-packages/sklearn/cluster/_kmeans.py", line 1473, in fit
    self._check_params_vs_input(X)
  File "/datashare3/charis/anaconda/envs/yolov8SSL/lib/python3.9/site-packages/sklearn/cluster/_kmeans.py", line 1414, in _check_params_vs_input
    super()._check_params_vs_input(X, default_n_init=10)
  File "/datashare3/charis/anaconda/envs/yolov8SSL/lib/python3.9/site-packages/sklearn/cluster/_kmeans.py", line 878, in _check_params_vs_input
    raise ValueError(
ValueError: n_samples=128 should be >= n_clusters=300.
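scikit-learn's `KMeans.fit` treats each row of `X` as one sample, so it expects shape (n_samples, n_features) and at least as many samples as clusters. `n_samples=128` therefore means the array reaching `fit` had only 128 rows, which is what a 128-dim feature axis placed first would produce. A minimal pre-check sketch, assuming the features should already be (size, dim); `check_kmeans_input` and `expected_dim` are hypothetical names, and the check is meant to run before `KMeans(...).fit`, so the snippet itself needs only NumPy.

```python
import numpy as np


def check_kmeans_input(features: np.ndarray, n_clusters: int, expected_dim: int) -> None:
    """Fail early with a clear message before KMeans(n_clusters).fit(features).

    Hypothetical helper: `expected_dim` is the embedding size (128 in the issue).
    """
    n_samples, n_features = features.shape
    if n_features != expected_dim:
        # A (dim, size) array shows up here with n_features == size.
        raise ValueError(
            f"features look transposed: got {features.shape}, "
            f"expected (n_samples, {expected_dim})"
        )
    if n_samples < n_clusters:
        raise ValueError(f"n_samples={n_samples} < n_clusters={n_clusters}")


dim, size = 128, 76800

# (size, dim) layout: 76800 samples, enough for 300 clusters.
check_kmeans_input(np.zeros((size, dim), dtype=np.float32), n_clusters=300, expected_dim=dim)

# (dim, size) layout: caught before KMeans ever sees it.
try:
    check_kmeans_input(np.zeros((dim, size), dtype=np.float32), n_clusters=300, expected_dim=dim)
except ValueError as err:
    print(err)  # features look transposed: got (128, 76800), expected (n_samples, 128)
```

Lowering `n_clusters` to 128 would pass the sample-count check but not the orientation check, which is why a dimension test is included.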
After changing it to self.n_groups = 128, a different error arose:

Traceback (most recent call last):
  File "/datashare3/charis/code2/SSL-yolo8/lightly-master/examples/pytorch/smog_yolo.py", line 243, in <module>
    assignments = smogmodel.smog.assign_groups(x1_encoded)
  File "/datashare3/charis/anaconda/envs/yolov8SSL/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/datashare3/charis/code2/SSL-yolo8/lightly-master/lightly/models/modules/heads.py", line 401, in assign_groups
    return torch.argmax(self.forward(x, self.group_features), dim=-1)
  File "/datashare3/charis/code2/SSL-yolo8/lightly-master/lightly/models/modules/heads.py", line 362, in forward
    logits = torch.mm(x, group_features.t())
RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x128 and 76800x128).
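Read together, the two tracebacks are consistent with the transposition the reply points to, assuming `_cluster_features` returns the k-means centers: with 128 "samples" of 76,800 features each, `KMeans(128)` yields centers of shape (128, 76800), and `group_features.t()` is then exactly the 76800x128 `mat2` above. `torch.mm(x, group_features.t())` needs `group_features` as (n_groups, dim) so that a (batch, dim) `x` gives (batch, n_groups) logits. The sketch below reproduces both cases with NumPy's `@`, which follows the same 2-D shape rule as `torch.mm`; `group_logits` is a hypothetical stand-in for the head's forward, and the random centers are placeholders, not real cluster output.

```python
import numpy as np


def group_logits(x: np.ndarray, group_features: np.ndarray) -> np.ndarray:
    """Mimic logits = x @ group_features.T for a (batch, dim) x.

    Hypothetical helper; group_features must be (n_groups, dim).
    """
    if x.shape[1] != group_features.shape[1]:
        raise ValueError(
            f"cannot multiply {x.shape} by {group_features.T.shape}: "
            f"group_features must be (n_groups, {x.shape[1]})"
        )
    return x @ group_features.T


rng = np.random.default_rng(0)
batch, dim, n_groups = 256, 128, 300
x = rng.standard_normal((batch, dim))

# Correct layout: one 128-dim center per group (placeholder values).
good_centers = rng.standard_normal((n_groups, dim))
logits = group_logits(x, good_centers)
assignments = logits.argmax(axis=-1)  # analogous to torch.argmax(..., dim=-1)
print(logits.shape, assignments.shape)  # (256, 300) (256,)

# Centers from transposed clustering: 128 centers of length 76800.
bad_centers = np.zeros((128, 76800), dtype=np.float32)
try:
    group_logits(x, bad_centers)
except ValueError as err:
    print(err)  # cannot multiply (256, 128) by (76800, 128): group_features must be (n_groups, 128)
```

With correctly oriented features, `n_groups` can go back to 300, since the row count would then be the bank size rather than the feature dimension.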