MeioJane / CHR

SIXray : A Large-scale Security Inspection X-ray Benchmark in CVPR 2019
80 stars 37 forks source link

read_object_labels not found in ray.py:97 #4

Open codestorm04 opened 5 years ago

codestorm04 commented 5 years ago

Actually, it does not seem to be defined anywhere in the whole project, so maybe it's a piece of testing code left in before release?

xiying-wei commented 5 years ago

How did you solve the problem?

codestorm04 commented 5 years ago

Here is the git diff from my repo:

diff --git a/CHR/CHR/engine.py b/CHR/CHR/engine.py
index 2845706..1dd7694 100644
--- a/CHR/CHR/engine.py
+++ b/CHR/CHR/engine.py
@@ -13,7 +13,7 @@ from tqdm import tqdm
 import numpy as np

 from CHR.util import AveragePrecisionMeter, Warp
-
+from CHR.ray import read_image_label

 class Engine(object):
     def __init__(self, state={}):
diff --git a/CHR/CHR/main.py b/CHR/CHR/main.py
index e4a2533..7ed94be 100644
--- a/CHR/CHR/main.py
+++ b/CHR/CHR/main.py
@@ -72,7 +72,7 @@ def main_ray():
     global args, best_prec1, use_gpu
     args = parser.parse_args()

-    args.data='/DATA/disk1/mcj/dataset/'
+    args.data='/mnt/lyz/SIXray-data/'
     args.resume = './CHR/models-/checkpoint.pth.tar'

@@ -81,7 +81,7 @@ def main_ray():

     # define dataset
     train_dataset = XrayClassification(args.data, 'train')
-    val_dataset = XrayClassification(args.data, 'test_new')
+    val_dataset = XrayClassification(args.data, 'test')
     num_classes = 5

     # load model
diff --git a/CHR/CHR/ray.py b/CHR/CHR/ray.py
index 720472a..bab4fbb 100644
--- a/CHR/CHR/ray.py
+++ b/CHR/CHR/ray.py
@@ -85,18 +85,19 @@ class XrayClassification(data.Dataset):

         # define path of csv file
-        path_csv = os.path.join(self.root, 'ImageSet','train_test_10-5')
+        path_csv = os.path.join(self.root, 'ImageSet', '10')
         # define filename of csv file
         file_csv = os.path.join(path_csv,  set + '.csv')

         # create the csv file if necessary
         if not os.path.exists(file_csv):
-            if not os.path.exists(path_csv):  # create dir if necessary
-                os.makedirs(path_csv)
-            # generate csv file
-            labeled_data = read_object_labels(self.root, self.set)
-            # write csv file
-            write_object_labels_csv(file_csv, labeled_data)
+            # if not os.path.exists(path_csv):  # create dir if necessary
+            #     os.makedirs(path_csv)
+            # # generate csv file
+            # labeled_data = read_object_labels(self.root, self.set)
+            # # write csv file
+            # write_object_labels_csv(file_csv, labeled_data)
+            raise ValueError(file_csv + " not found.")

         self.classes = object_categories
         self.images = read_object_labels_csv(file_csv)

How did you solve the problem?

MeioJane commented 5 years ago

The annotations are provided in the SIXray project.

MeioJane commented 4 years ago

Here is the git diffs in my repos

diff --git a/CHR/CHR/engine.py b/CHR/CHR/engine.py
index 2845706..1dd7694 100644
--- a/CHR/CHR/engine.py
+++ b/CHR/CHR/engine.py
@@ -13,7 +13,7 @@ from tqdm import tqdm
 import numpy as np

 from CHR.util import AveragePrecisionMeter, Warp
-
+from CHR.ray import read_image_label

 class Engine(object):
     def __init__(self, state={}):
diff --git a/CHR/CHR/main.py b/CHR/CHR/main.py
index e4a2533..7ed94be 100644
--- a/CHR/CHR/main.py
+++ b/CHR/CHR/main.py
@@ -72,7 +72,7 @@ def main_ray():
     global args, best_prec1, use_gpu
     args = parser.parse_args()

-    args.data='/DATA/disk1/mcj/dataset/'
+    args.data='/mnt/lyz/SIXray-data/'
     args.resume = './CHR/models-/checkpoint.pth.tar'

@@ -81,7 +81,7 @@ def main_ray():

     # define dataset
     train_dataset = XrayClassification(args.data, 'train')
-    val_dataset = XrayClassification(args.data, 'test_new')
+    val_dataset = XrayClassification(args.data, 'test')
     num_classes = 5

     # load model
diff --git a/CHR/CHR/ray.py b/CHR/CHR/ray.py
index 720472a..bab4fbb 100644
--- a/CHR/CHR/ray.py
+++ b/CHR/CHR/ray.py
@@ -85,18 +85,19 @@ class XrayClassification(data.Dataset):

         # define path of csv file
-        path_csv = os.path.join(self.root, 'ImageSet','train_test_10-5')
+        path_csv = os.path.join(self.root, 'ImageSet', '10')
         # define filename of csv file
         file_csv = os.path.join(path_csv,  set + '.csv')

         # create the csv file if necessary
         if not os.path.exists(file_csv):
-            if not os.path.exists(path_csv):  # create dir if necessary
-                os.makedirs(path_csv)
-            # generate csv file
-            labeled_data = read_object_labels(self.root, self.set)
-            # write csv file
-            write_object_labels_csv(file_csv, labeled_data)
+            # if not os.path.exists(path_csv):  # create dir if necessary
+            #     os.makedirs(path_csv)
+            # # generate csv file
+            # labeled_data = read_object_labels(self.root, self.set)
+            # # write csv file
+            # write_object_labels_csv(file_csv, labeled_data)
+            raise ValueError(file_csv + " not found.")

         self.classes = object_categories
         self.images = read_object_labels_csv(file_csv)

How did you solve the problem?

I have updated the code. You can git clone the new version. Thank you.