NVIDIA / semantic-segmentation

Nvidia Semantic Segmentation monorepo
BSD 3-Clause "New" or "Revised" License
1.76k stars 388 forks source link

Why is loss a negative number? #181

Open xu19971109 opened 1 year ago

xu19971109 commented 1 year ago

image

This is my dataset.py: `
import os
import os.path as osp
from config import cfg
from runx.logx import logx
from datasets.base_loader import BaseLoader
from datasets.utils import make_dataset_folder
from datasets import uniform

  class Loader(BaseLoader):
      """Two-class segmentation loader driven by ``<image>,<mask>`` list files.

      Image lists come from ``cfg.DATASET.NEILIX_TRAIN_SET`` (mode 'train') or
      ``cfg.DATASET.NEILIX_VAL_SET`` (any other non-'folder' mode); list
      entries are resolved against ``cfg.DATASET.NEILIX_IMG_ROOT``.

      NOTE(review): the negative training loss reported with this loader was
      attributed in the discussion to the RMI loss when a ground-truth mask
      holds a single value, not to anything in this class — confirm upstream.
      """
      num_classes = 2
      # Pixels labelled 2 are ignored; presumably masks contain only 0, 1 and
      # 2 — TODO confirm against the data.
      ignore_label = 2
      trainid_to_name = {}
      color_mapping = []

      def __init__(self, mode, quality='semantic', joint_transform_list=None,
                   img_transform=None, label_transform=None, eval_folder=None):
          """Assemble the image list for *mode*, build centroids and the epoch.

          Args:
              mode: 'folder' to load every image under *eval_folder*, 'train'
                  for the training list files, anything else for the
                  validation list files.
              quality, joint_transform_list, img_transform, label_transform:
                  forwarded unchanged to ``BaseLoader``.
              eval_folder: directory scanned when *mode* is 'folder'.
          """
          super(Loader, self).__init__(quality=quality,
                                       mode=mode,
                                       joint_transform_list=joint_transform_list,
                                       img_transform=img_transform,
                                       label_transform=label_transform)

          # Colors/names are not read from a Mapillary-style config.json for
          # this dataset, so fill_colormap_and_names() is not called here.

          ######################################################################
          # Assemble image lists
          ######################################################################
          if mode == 'folder':
              self.all_imgs = make_dataset_folder(eval_folder)
          else:
              if mode == 'train':
                  txt_files = cfg.DATASET.NEILIX_TRAIN_SET
              else:
                  txt_files = cfg.DATASET.NEILIX_VAL_SET
              if isinstance(txt_files, str):
                  # A single path would otherwise be iterated one character
                  # at a time, trying to open a "file" per character.
                  txt_files = [txt_files]
              img_root = cfg.DATASET.NEILIX_IMG_ROOT
              self.all_imgs = [pair
                               for txt_file in txt_files
                               for pair in self.find_image_by_txt(img_root,
                                                                  txt_file)]
          logx.msg('all {} imgs {}'.format(mode, len(self.all_imgs)))
          self.centroids = uniform.build_centroids(self.all_imgs,
                                                   self.num_classes,
                                                   self.train,
                                                   cv=cfg.DATASET.CV)
          self.build_epoch()

      def find_image_by_txt(self, img_root, txt_file):
          """Read ``(image_path, mask_path)`` pairs from a list file.

          Every non-blank line must be ``<image>,<mask>``. Both entries are
          prefixed with *img_root* by plain string concatenation (not
          ``os.path.join``), so *img_root* must end with a path separator
          unless the entries start with one.

          Raises:
              ValueError: if a non-blank line does not contain exactly one
                  comma.
          """
          res = []
          with open(txt_file) as f:
              for line_no, raw_line in enumerate(f, 1):
                  line = raw_line.strip()
                  if not line:
                      # Tolerate blank lines, e.g. a trailing newline at EOF.
                      continue
                  parts = line.split(',')
                  if len(parts) != 2:
                      raise ValueError(
                          '{}:{}: expected "<image>,<mask>", got {!r}'.format(
                              txt_file, line_no, line))
                  img, mask = parts
                  res.append((img_root + img, img_root + mask))
          return res

      def fill_colormap_and_names(self, config_fn):
          """
          Mapillary code for color map and class names

          Reads the ``labels`` list from the JSON file *config_fn*; each label
          supplies a ``color`` sequence and a ``readable`` name.

          Outputs
          -------
          self.trainid_to_name
              label index -> readable name with spaces replaced by '_'
          self.color_mapping
              flat list of all label colors concatenated in label order
          """
          import json  # json is not among this file's module-level imports

          with open(config_fn) as config_file:
              config_labels = json.load(config_file)['labels']

          colormap = []
          trainid_to_name = {}
          for i, label in enumerate(config_labels):
              colormap.extend(label['color'])
              trainid_to_name[i] = label['readable'].replace(' ', '_')
          self.trainid_to_name = trainid_to_name
          self.color_mapping = colormap

      def fill_colormap(self):
          """Set ``self.color_mapping`` to a flat 256 * 3 RGB palette.

          Index 0 maps to (0, 0, 0) and index 1 to (1, 1, 1); every other
          entry is zero-padded.
          """
          # Reporter's note: this was originally [0, 0, 255], [255, 0, 0];
          # the loss still became negative after a few iterations, so the
          # palette does not appear to be the cause.
          palette = [0, 0, 0,
                     1, 1, 1]
          palette.extend([0] * (256 * 3 - len(palette)))
          self.color_mapping = palette

`

Ahrenat commented 1 year ago

I also encountered the same problem. It seems that when the ground-truth (gt) mask contains only one value, the RMI loss function produces a negative value.

ajtao commented 1 year ago

https://github.com/ZJULearning/RMI/issues/2