Closed kapilb7 closed 3 years ago
Hi @kapilb7, I believe this is caused by the hard-coded dataset keys in data.py. You should be able to get it working by replacing lines 118-125 in data.py with:
datasets = {'negative': [], 'positive': []}
for l in self.dataset:
    datasets[l.split()[2]].append(l)
self.datasets = [
    datasets['negative'],
    datasets['positive'],
]
print(len(self.datasets[0]), len(self.datasets[1]))
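For anyone who wants to try the grouping outside the class first, here is a minimal standalone sketch. It assumes, as `l.split()[2]` implies, that the label is the third whitespace-separated field of each line. The `group_by_label` helper and the sample lines are made up for illustration and are not part of data.py.

```python
def group_by_label(lines, labels=("negative", "positive")):
    """Group split-file lines by their label (third whitespace-separated field)."""
    groups = {label: [] for label in labels}
    for line in lines:
        label = line.split()[2]
        if label not in groups:
            # Fail with a readable message instead of a bare KeyError.
            raise ValueError(f"unexpected label {label!r}; expected one of {labels}")
        groups[label].append(line)
    # Return the groups in the same order as `labels`.
    return [groups[label] for label in labels]


# Made-up sample lines; only the position of the label matters here.
sample = [
    "p1 img1.png negative src_a",
    "p2 img2.png positive src_b",
    "p3 img3.png negative src_a",
]
negative, positive = group_by_label(sample)
print(len(negative), len(positive))  # 2 1
```

Raising a `ValueError` that names the offending label makes a label mismatch easier to spot than the raw `KeyError`.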
You would also need to pass an appropriate mapping and list of class_weights to the BalanceCovidDataset constructor, for example:
mapping = {'negative': 0, 'positive': 1}
class_weights = [1., 4.]
Also note that train_tf.py is only set up for multiclass classification at the moment, and so it may not work for binary training.
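How the constructor uses the mapping and class_weights internally is not shown in this thread. Purely as an illustration of the idea, the sketch below assumes the mapping turns each label string into a class index and the weight list is indexed by that class, giving one weight per sample. `sample_weights` is a hypothetical helper, not a function from data.py.

```python
mapping = {"negative": 0, "positive": 1}
class_weights = [1.0, 4.0]


def sample_weights(labels, mapping, class_weights):
    """Look up one weight per sample via its class index (assumed usage)."""
    if len(class_weights) != len(mapping):
        raise ValueError("need exactly one weight per class")
    return [class_weights[mapping[label]] for label in labels]


batch_labels = ["negative", "positive", "negative"]
weights = sample_weights(batch_labels, mapping, class_weights)
print(weights)  # [1.0, 4.0, 1.0]
```

With these values a positive sample would count four times as much as a negative one, which is the usual reason for weighting the rarer class more heavily.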
Hi, I just tried with only these modifications and it didn't work. As you said, I think train_tf.py isn't set up for binary classification yet. Since I'm a novice, I hope someone can help me out with it. This is what I'm getting now:
(base) kapil@Kapils-MacBook-Pro COVIDNet-CXR % python3 train_tf.py
Output: /Users/kapil/Documents/FYP/COVIDNet-CXRCOVIDNet-lr0.0002
Traceback (most recent call last):
  File "train_tf.py", line 50, in <module>
Hi @kapilb7, the data.py script has now been updated to handle binary positive/negative labels. This is controlled in train_tf.py through the --n_classes flag: set --n_classes to 2 for negative/positive classification, or to 3 for the normal/pneumonia/COVID-19 multi-class classification.
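To make the idea concrete, here is a rough sketch of how a flag like --n_classes could select the label mapping. It is an illustrative stand-in, not the actual code in train_tf.py: the three-class index order, the default value, and the names `MAPPINGS`, `build_parser`, and `mapping_for` are all assumptions.

```python
import argparse

# Assumed label-to-index mappings; the indices used by the repo may differ.
MAPPINGS = {
    2: {"negative": 0, "positive": 1},
    3: {"normal": 0, "pneumonia": 1, "COVID-19": 2},
}


def build_parser():
    parser = argparse.ArgumentParser(description="Illustrative stand-in, not train_tf.py")
    # The default of 3 is an assumption made for this sketch.
    parser.add_argument("--n_classes", type=int, choices=sorted(MAPPINGS), default=3)
    return parser


def mapping_for(n_classes):
    """Return the label mapping that matches the requested number of classes."""
    return MAPPINGS[n_classes]


# Equivalent to passing `--n_classes 2` on the command line.
args = build_parser().parse_args(["--n_classes", "2"])
mapping = mapping_for(args.n_classes)
print(mapping)  # {'negative': 0, 'positive': 1}
```

Whatever mapping is selected has to match the labels in the training split file, otherwise the lookup on the label string fails.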
Cool! I haven't tried it yet and can't for a few days, so I'll close this anyway.
I wanted to create a binary classifier, so I used create_COVIDx_binary.ipynb to create the dataset accordingly. But when I tried to train a model with train_tf.py using the train_COVIDx7B.txt file, I got this error:
(base) kapil@Kapils-MacBook-Pro COVIDNet-CXR % python3 train_tf.py
Output: /Users/kapil/Documents/FYP/COVIDNet-CXRCOVIDNet-lr0.0002
Traceback (most recent call last):
  File "train_tf.py", line 50, in <module>
    generator = BalanceCovidDataset(data_dir=args.datadir,
  File "/Users/kapil/Documents/FYP/COVIDNet-CXR/data.py", line 120, in __init__
    datasets[l.split()[2]].append(l)
KeyError: 'negative'
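A quick way to check for this kind of mismatch is to count the labels in the split file and compare them with the keys the loader expects. The sketch below reproduces the error on made-up lines. The three-class keys are an assumption based on the normal/pneumonia/COVID-19 classes mentioned in this thread, and `label_counts` is a hypothetical helper.

```python
from collections import Counter


def label_counts(lines):
    """Count the labels found in the third field of each split-file line."""
    return Counter(line.split()[2] for line in lines)


# Made-up lines in the binary format.
lines = ["p1 img1.png negative src_a", "p2 img2.png positive src_b"]

# Keys assumed to be hard-coded for the three-class task.
datasets = {"normal": [], "pneumonia": [], "COVID-19": []}

counts = label_counts(lines)
missing = sorted(set(counts) - set(datasets))
print(missing)  # ['negative', 'positive']

try:
    for l in lines:
        datasets[l.split()[2]].append(l)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'negative'
```

Any label that shows up in `missing` is one the hard-coded dictionary does not know about, which is exactly what triggers the `KeyError`.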