zhmiao / OpenLongTailRecognition-OLTR

PyTorch implementation for "Large-Scale Long-Tailed Recognition in an Open World" (CVPR 2019 ORAL)
BSD 3-Clause "New" or "Revised" License

Code Error for Training (Stage 1) #31

Closed BehzadBozorgtabar closed 4 years ago

BehzadBozorgtabar commented 5 years ago

Hello, when I run python main.py --config ./config/ImageNet_LT/stage_1.py, I get the following error:

Loading Dot Product Classifier.
Traceback (most recent call last):
  File "/media/Elements/OLTR/OpenLongTailRecognition-OLTR/main.py", line 55, in <module>
    training_model = model(config, data, test=False)
  File "/media/Elements/OLTR/OpenLongTailRecognition-OLTR/run_networks.py", line 26, in __init__
    self.init_models()
  File "/media/Elements/OLTR/OpenLongTailRecognition-OLTR/run_networks.py", line 69, in init_models
    self.networks[key] = source_import(def_file).create_model(*model_args)
  File "./models/DotProductClassifier.py", line 16, in create_model
    clf = DotProduct_Classifier(num_classes, feat_dim)
  File "./models/DotProductClassifier.py", line 8, in __init__
    self.fc = nn.Linear(feat_dim, num_classes)
  File "/home/.local/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 81, in __init__
    self.reset_parameters()
  File "/home/.local/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 84, in reset_parameters
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))
  File "/home/.local/lib/python3.5/site-packages/torch/nn/init.py", line 325, in kaiming_uniform_
    std = gain / math.sqrt(fan)
ZeroDivisionError: float division by zero

Could you give some advice to solve this problem?

zhmiao commented 5 years ago

Hello @BehzadBozorgtabar, thanks for asking. Could you please print out feat_dim and num_classes in this case? We are not able to reproduce this error, but we suspect there is a problem with how the parameters are read. Depending on the Python version, configuration values read from files may come out in a different order. Thanks again.
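To illustrate the ordering concern: the traceback shows create_model(*model_args), which binds arguments by position. The sketch below is only an illustration under assumptions; create_model and params here are hypothetical stand-ins with made-up values, not the repository's code. Before Python 3.7 the language did not guarantee dict ordering, so values unpacked positionally from a dict can land in the wrong argument, while keyword unpacking does not depend on order.

```python
def create_model(feat_dim, num_classes):
    # Hypothetical stand-in for a model factory: print what was received,
    # as suggested in the comment above, then return it for inspection.
    print("feat_dim =", feat_dim, "num_classes =", num_classes)
    return {"feat_dim": feat_dim, "num_classes": num_classes}


# Hypothetical config values, for illustration only.
params = {"num_classes": 1000, "feat_dim": 512}

# Fragile: the result depends on the order in which the dict yields values.
positional = create_model(*list(params.values()))

# Robust: each value is bound to the argument with the matching name.
by_keyword = create_model(**params)
```

If the two printed lines differ, positional unpacking is mixing up the arguments on that interpreter.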

iwzy7071 commented 3 years ago

I ran into the same error. It happens with net = GraphSAGE(dataset.num_features, dataset.num_classes) when dataset.num_features equals zero.
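A guard along these lines could surface a zero dimension before the layer is built. This is a minimal sketch: check_layer_dims is a hypothetical helper, not part of PyTorch or of either codebase mentioned in this thread.

```python
def check_layer_dims(in_features, out_features):
    """Raise a descriptive error unless both dimensions are positive integers."""
    for name, value in (("in_features", in_features), ("out_features", out_features)):
        # bool is a subclass of int, so reject it explicitly: False would act as 0.
        if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
    return in_features, out_features


# Valid dimensions pass through unchanged.
dims = check_layer_dims(512, 1000)

# A zero input dimension is reported by name instead of failing later.
try:
    check_layer_dims(0, 7)
except ValueError as err:
    message = str(err)
    print(message)
```

Calling such a check right before constructing the layer or model would name the offending argument, rather than leaving a ZeroDivisionError to surface inside the weight initializer.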