xuebinqin / U-2-Net

The code for our newly accepted paper in Pattern Recognition 2020: "U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection."
Apache License 2.0
8.67k stars 1.5k forks

Inference speed #359

Open cindymuji opened 1 year ago

cindymuji commented 1 year ago

After actual testing, the main cause of the slow speed is the 'RescaleT' method. Replace its __call__ with this:

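    # Needs at module level: import numpy as np; from PIL import Image
    def __call__(self, sample):
        imidx, image, label = sample['imidx'], sample['image'], sample['label']
        # Old aspect-ratio computation, no longer used:
        # h, w = image.shape[:2]
        # if isinstance(self.output_size,int):
        #     if h > w:
        #         new_h, new_w = self.output_size*h/w,self.output_size
        #     else:
        #         new_h, new_w = self.output_size,self.output_size*w/h
        # else:
        #     new_h, new_w = self.output_size
        # new_h, new_w = int(new_h), int(new_w)

        # PIL expects uint8 data; the mask is reduced to a 2-D (H, W) array first.
        # This assumes self.output_size is an int and image is a uint8 array.
        label = np.squeeze(label).astype(np.uint8)
        image = Image.fromarray(image)
        label = Image.fromarray(label)  # a 2-D uint8 array is read as mode 'L'

        # Bilinear for the image; nearest-neighbour for the mask so no new gray levels appear.
        image = image.resize((self.output_size, self.output_size), resample=Image.BILINEAR)
        label = label.resize((self.output_size, self.output_size), resample=Image.NEAREST)

        # Back to NumPy: image is (H, W, C) uint8, label becomes (1, H, W) uint8.
        label = np.array(label)
        label = np.expand_dims(label, axis=0)
        image = np.array(image)

        # Original skimage calls, replaced by the PIL calls above:
        # img = transform.resize(image,(self.output_size,self.output_size),mode='constant')
        # lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True)
        return {'imidx': imidx, 'image': image, 'label': label}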

But will this cause accuracy issues?
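One way to start probing the accuracy question is to look at what the PIL path actually returns. The sketch below is only an illustration on synthetic data: `resize_pair` is a hypothetical helper, not part of the repository, and the array sizes are made up. It shows that the image comes back as uint8 (0..255) and that nearest-neighbour resizing leaves the mask with only its original values. Note that skimage's transform.resize, called without preserve_range=True, returns floats in [0, 1] for uint8 input, so whether results match probably depends on how the later steps of the pipeline normalize the image.

```python
import numpy as np
from PIL import Image


def resize_pair(image, label, size):
    """Resize an (H, W, 3) uint8 image and an (H, W) uint8 mask to size x size.

    Hypothetical helper for illustration; mirrors the PIL calls in the snippet above.
    """
    img = Image.fromarray(image).resize((size, size), resample=Image.BILINEAR)
    lbl = Image.fromarray(label).resize((size, size), resample=Image.NEAREST)
    return np.array(img), np.array(lbl)


# Synthetic inputs (assumed shapes, not taken from the dataset).
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(48, 64, 3), dtype=np.uint8)
label = np.zeros((48, 64), dtype=np.uint8)
label[12:36, 16:48] = 255  # a rectangular "salient object"

img_out, lbl_out = resize_pair(image, label, 32)

print(img_out.shape, img_out.dtype)  # image stays uint8, not float in [0, 1]
print(np.unique(lbl_out))            # mask values after nearest-neighbour resize
```

If the image is later divided by its maximum, the difference in value range between the two libraries should largely cancel out, but that is an assumption worth checking against the actual data loader.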

cindymuji commented 1 year ago

About transform.resize(image,(self.output_size,self.output_size),mode='constant'): the input and output of the transform.resize function are in width × height × channels (w, h, c) format, but in the provided code the image shape is (h, w, c).
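Axis order is easy to mix up between libraries, so it may be worth checking rather than relying on memory. The sketch below uses made-up array sizes and shows the two conventions that are certain here: NumPy image arrays are (height, width, channels), while PIL reports and accepts sizes as (width, height). The scikit-image documentation describes the output_shape of transform.resize as (rows, cols), which would match the NumPy order; either way, with a square target such as (output_size, output_size) the order makes no difference.

```python
import numpy as np
from PIL import Image

# NumPy convention: (height, width, channels). Sizes here are arbitrary.
arr = np.zeros((48, 64, 3), dtype=np.uint8)

img = Image.fromarray(arr)
size_wh = img.size                    # PIL reports (width, height)

resized = img.resize((20, 10))        # PIL takes (width, height)
out_shape = np.array(resized).shape   # back in NumPy: (height, width, channels)

print(size_wh)
print(out_shape)
```

For a non-square target, swapping the two values would transpose the output dimensions, so this is the case where the convention matters.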