TaoHuang2018 / Neighbor2Neighbor

Neighbor2Neighbor: Self-Supervised Denoising from Single Noisy Images
BSD 3-Clause "New" or "Revised" License
244 stars 37 forks source link

Do you plan to release your training code for the raw-image dataset (SIDD)? #7

Closed madfff closed 2 years ago

madfff commented 2 years ago

Thanks for your great work. I applied the neighbor sub-sampler to the packed 4-channel raw images but cannot reproduce the results. Did I do anything wrong? Here is my data-processing code.

# File-name suffixes of the two noisy captures stored in each scene folder;
# the scene-instance number is prepended to form the full file name.
NOISY_PATH = [
    '_NOISY_RAW_010.MAT',
    '_NOISY_RAW_011.MAT',
]

# Native Bayer layout per smartphone code (the code is the third
# '_'-separated field of a scene-folder name, as read by meta_read).
MODEL_BAYER = {
    'GP': 'BGGR',
    'IP': 'RGGB',
    'S6': 'GRBG',
    'N6': 'BGGR',
    'G4': 'BGGR',
}

# Layout every image is unified to before packing.
TARGET_PATTERN = 'RGGB'

class Dataset(data.Dataset):
    """Paired noisy / ground-truth SIDD raw images.

    Each item is ``(noisy, gt, bayer_pattern)``. ``noisy`` and ``gt`` are
    channel-first torch tensors built by ``pack_raw_np`` (4 packed Bayer
    planes), optionally randomly cropped. ``bayer_pattern`` is the scene's
    native Bayer layout string from ``MODEL_BAYER``.
    """

    def __init__(self, path, crop_size, is_train=True):
        """
        Args:
            path: dataset root; scene folders are read from ``<path>/Data``.
            crop_size: ``(h, w)`` crop applied after packing; a zero in either
                entry disables cropping.
            is_train: enables random ``bayer_aug`` augmentation and selects the
                ``bayer_unify`` mode ('crop' when training, 'pad' otherwise).
        """
        super(Dataset, self).__init__()
        self.crop_size = crop_size
        self.is_train = is_train
        self.unify_mode = 'crop' if self.is_train else 'pad'

        self.file_lists = []
        data_dir = os.path.join(path, 'Data')
        # Sorted so the sample order is reproducible (os.listdir order is
        # arbitrary); non-directory entries are skipped because they are not
        # scene folders and meta_read cannot parse them.
        for folder_name in sorted(os.listdir(data_dir)):
            if not os.path.isdir(os.path.join(data_dir, folder_name)):
                continue
            scene_instance_number, _, _ = meta_read(folder_name)
            for file_path in NOISY_PATH:
                self.file_lists.append(
                    os.path.join(data_dir, folder_name, scene_instance_number + file_path))

    def __getitem__(self, index):
        """Load, unify, (optionally) augment, pack and crop one noisy/GT pair."""
        noisy_path = self.file_lists[index]
        folder, noisy_name = os.path.split(noisy_path)
        # The GT file differs only in its name ('NOISY' -> 'GT'); replacing in
        # the file name alone leaves any 'NOISY' in the directory path intact.
        gt_path = os.path.join(folder, noisy_name.replace('NOISY', 'GT'))
        # The scene folder name carries the phone code; basename() works with
        # any path separator, unlike split('/').
        _, _, bayer_pattern = meta_read(os.path.basename(folder))

        noisy = h5py_loadmat(noisy_path)
        noisy = BayerUnifyAug.bayer_unify(noisy, bayer_pattern, TARGET_PATTERN, self.unify_mode)

        gt = h5py_loadmat(gt_path)
        gt = BayerUnifyAug.bayer_unify(gt, bayer_pattern, TARGET_PATTERN, self.unify_mode)

        if self.is_train:
            # One set of random flags, applied to both images so the pair
            # stays aligned.
            augment = np.random.rand(3) > 0.5
            noisy = BayerUnifyAug.bayer_aug(noisy, augment[0], augment[1], augment[2], TARGET_PATTERN)
            gt = BayerUnifyAug.bayer_aug(gt, augment[0], augment[1], augment[2], TARGET_PATTERN)

        # Assumes bayer_unify/bayer_aug return a 2-D mosaic (TODO confirm);
        # a singleton channel axis is added before packing.
        noisy = pack_raw_np(noisy[:, :, None])
        gt = pack_raw_np(gt[:, :, None])

        if self.crop_size[0] != 0 and self.crop_size[1] != 0:
            # Same random window for noisy and gt.
            H, W, _ = noisy.shape
            rnd_h = random.randint(0, max(0, H - self.crop_size[0]))
            rnd_w = random.randint(0, max(0, W - self.crop_size[1]))

            noisy = noisy[rnd_h:rnd_h + self.crop_size[0], rnd_w:rnd_w + self.crop_size[1], :]
            gt = gt[rnd_h:rnd_h + self.crop_size[0], rnd_w:rnd_w + self.crop_size[1], :]

        # HWC -> CHW
        noisy = torch.from_numpy(noisy.transpose(2, 0, 1))
        gt = torch.from_numpy(gt.transpose(2, 0, 1))
        return noisy, gt, bayer_pattern

    def __len__(self):
        """Number of noisy captures (two per scene folder)."""
        return len(self.file_lists)

def meta_read(info):
    """Parse an SIDD scene-folder name.

    The name is split on '_'. Field 0 is the scene-instance number, field 1
    the scene number and field 2 the smartphone code; later fields (ISO level,
    shutter speed, illuminant temperature, illuminant brightness code) are
    ignored.

    Returns:
        (scene_instance_number, scene_number, bayer_pattern), where
        bayer_pattern is looked up in MODEL_BAYER from the smartphone code.
    """
    fields = info.split('_')
    instance_id = fields[0]
    scene_id = fields[1]
    phone_code = fields[2]
    return instance_id, scene_id, MODEL_BAYER[phone_code]

def pack_raw_np(im):
    """Pack a Bayer mosaic into four half-resolution planes.

    Args:
        im: array of shape (H, W, C), normally C == 1 (a single mosaic).

    Returns:
        Array of shape (H // 2, W // 2, 4 * C). The channels are the samples
        at offsets (0, 0), (0, 1), (1, 0), (1, 1) of every 2x2 cell, in that
        order — i.e. R, G, G, B for an RGGB mosaic. When C > 1 the channels
        are grouped by offset.

    A trailing odd row/column cannot form a complete 2x2 cell and is dropped,
    so the four strided slices always have the same shape (otherwise
    np.concatenate would raise).
    """
    H = im.shape[0] - im.shape[0] % 2
    W = im.shape[1] - im.shape[1] % 2
    return np.concatenate((im[0:H:2, 0:W:2, :],
                           im[0:H:2, 1:W:2, :],
                           im[1:H:2, 0:W:2, :],
                           im[1:H:2, 1:W:2, :]), axis=2)

def h5py_loadmat(file_path: str, key: str = 'x') -> np.ndarray:
    """Read one variable from an HDF5-based (MATLAB v7.3) .mat file as float32.

    Args:
        file_path: path to the .MAT file.
        key: name of the variable to read; defaults to 'x'.

    Returns:
        The variable as a float32 NumPy array.

    Raises:
        KeyError: if the file has no variable named ``key``.
    """
    with h5py.File(file_path, 'r') as f:
        # f[key] instead of f.get(key): get() returns None for a missing key,
        # and np.array(None, dtype=np.float32) silently becomes a 0-d NaN array.
        # NOTE(review): h5py typically exposes MATLAB (column-major) arrays
        # transposed relative to MATLAB's (rows, cols) view. A transpose swaps
        # GRBG <-> GBRG layouts — confirm callers account for it.
        return np.array(f[key], dtype=np.float32)
TaoHuang2018 commented 2 years ago

I am sorry for my late reply. For the training of Neighbor2Neighbor in the SIDD dataset:

  1. The code for data preparation is in the newly uploaded file dataset_tool_raw.py. It would be better to account for the differences between camera devices in the SIDD Medium set. However, since device information is not provided for the SIDD validation set, we simply ignore device differences when preparing the training set.
  2. The training code for the SIDD dataset is similar to that for the ImageNet validation set, except for some operations specific to raw images. The training scheme for the SIDD dataset is also similar: the number of epochs is 100 and the learning-rate decay schedule is the same, except that we use a smaller $\gamma = 1$. Here is some code for processing raw images.
    
    def space_to_depth(x, block_size):
        """Rearrange non-overlapping block_size x block_size patches into channels.

        (N, C, H, W) -> (N, C * block_size**2, H // block_size, W // block_size).
        For C == 1 and block_size == 2 the output channels hold the pixels at
        offsets (0, 0), (0, 1), (1, 0), (1, 1) of each 2x2 patch.
        """
        batch, channels, height, width = x.size()
        patches = torch.nn.functional.unfold(x, block_size, stride=block_size)
        out_channels = channels * block_size ** 2
        out_h = height // block_size
        out_w = width // block_size
        return patches.view(batch, out_channels, out_h, out_w)

def depth_to_space(x, block_size):
    """Fold channels back into space: (N, C*r*r, H, W) -> (N, C, H*r, W*r), r = block_size.

    Uses torch.nn.functional.pixel_shuffle.
    """
    shuffle = torch.nn.functional.pixel_shuffle
    return shuffle(x, block_size)

def generate_mask_pair(img, generator=None):
    """Draw a random pair of neighbouring-pixel masks (Neighbor2Neighbor sampler).

    The image is viewed as a grid of 2x2 cells whose pixels are numbered
    0 1 / 2 3 (the channel order of ``space_to_depth(x, 2)``). For every cell,
    one of the 8 ordered pairs of horizontally or vertically adjacent pixels
    is drawn uniformly. mask1 marks the first pixel of the pair and mask2 the
    second.

    Args:
        img: tensor of shape (N, C, H, W); only its shape and device are used.
        generator: optional ``torch.Generator`` for the random draw; when None,
            ``get_generator()`` is used.

    Returns:
        (mask1, mask2): flat boolean tensors of length N * (H//2) * (W//2) * 4,
        laid out as (N, H//2, W//2, 4), with exactly one True per cell.
    """
    if generator is None:
        generator = get_generator()
    n, c, h, w = img.shape
    # Number of complete 2x2 cells. The floor divisions must be grouped: an
    # ungrouped `n * h // 2 * w // 2` evaluates left to right and miscounts
    # when H or W is odd and N > 1.
    n_cells = n * (h // 2) * (w // 2)
    mask1 = torch.zeros(size=(n_cells * 4, ), dtype=torch.bool, device=img.device)
    mask2 = torch.zeros(size=(n_cells * 4, ), dtype=torch.bool, device=img.device)
    # All ordered pairs of adjacent (non-diagonal) pixels within a cell.
    idx_pair = torch.tensor(
        [[0, 1], [0, 2], [1, 3], [2, 3], [1, 0], [2, 0], [3, 1], [3, 2]],
        dtype=torch.int64,
        device=img.device)
    rd_idx = torch.zeros(size=(n_cells, ), dtype=torch.int64, device=img.device)
    torch.randint(low=0,
                  high=8,
                  size=(n_cells, ),
                  generator=generator,
                  out=rd_idx)
    rd_pair_idx = idx_pair[rd_idx]
    # Shift each cell's local indices (0..3) to its position in the flat mask.
    rd_pair_idx += torch.arange(start=0,
                                end=n_cells * 4,
                                step=4,
                                dtype=torch.int64,
                                device=img.device).reshape(-1, 1)
    mask1[rd_pair_idx[:, 0]] = 1
    mask2[rd_pair_idx[:, 1]] = 1
    return mask1, mask2

def generate_subimages(img, mask):
    """Gather the pixels selected by ``mask`` into a half-resolution image.

    Args:
        img: tensor of shape (N, C, H, W).
        mask: flat boolean tensor of length N * (H//2) * (W//2) * 4, laid out
            as (N, H//2, W//2, 4) with exactly one True per 2x2 cell (as
            returned by ``generate_mask_pair(img)``).

    Returns:
        Tensor of shape (N, C, H//2, W//2) with img's dtype, layout and device.

    Raises:
        ValueError: if ``mask`` does not have N * (H//2) * (W//2) * 4 elements.
    """
    n, c, h, w = img.shape
    expected = n * (h // 2) * (w // 2) * 4
    if mask.numel() != expected:
        raise ValueError(
            "mask has {} elements, expected {} for an image of shape {}".format(
                mask.numel(), expected, tuple(img.shape)))
    subimage = torch.zeros(n, c, h // 2, w // 2,
                           dtype=img.dtype,
                           layout=img.layout,
                           device=img.device)
    # Per channel, so the same per-cell choice is applied to every channel.
    for i in range(c):
        # (N, 1, H, W) -> (N, 4, H/2, W/2) -> flat in (N, H/2, W/2, 4) order,
        # matching the mask layout.
        img_per_channel = space_to_depth(img[:, i:i + 1, :, :], block_size=2)
        img_per_channel = img_per_channel.permute(0, 2, 3, 1).reshape(-1)
        subimage[:, i:i + 1, :, :] = img_per_channel[mask].reshape(
            n, h // 2, w // 2, 1).permute(0, 3, 1, 2)
    return subimage

class DataLoader_SIDD_Medium_Raw(data.Dataset):
    """Dataset over every file in ``data_dir``; each is a .mat file whose
    variable ``x`` holds a 2-D raw image.

    Items are tensors of shape (1, H, W): a singleton channel axis is added.
    """

    def __init__(self, data_dir):
        # Must be the dunder __init__: a method named `init` is never called
        # by the constructor, and the base class has no `init` to chain to.
        super(DataLoader_SIDD_Medium_Raw, self).__init__()
        self.data_dir = data_dir
        # Collect image paths, sorted for a deterministic order.
        self.train_fns = glob.glob(os.path.join(self.data_dir, "*"))
        self.train_fns.sort()
        print('fetch {} samples for training'.format(len(self.train_fns)))

    def __getitem__(self, index):
        fn = self.train_fns[index]
        # NOTE(review): the tensor dtype follows the stored array (float64
        # unless the file was saved as single) — confirm it matches the network.
        im = loadmat(fn)["x"]
        im = im[np.newaxis, :, :]
        im = torch.from_numpy(im)
        return im

    def __len__(self):
        return len(self.train_fns)

def get_SIDD_validation(dataset_dir):
    """Load the SIDD raw validation blocks from ``dataset_dir``.

    Reads ``ValidationNoisyBlocksRaw.mat`` and ``ValidationGtBlocksRaw.mat``.

    Returns:
        (num_img, num_block, val_data_noisy, val_data_gt), where num_img and
        num_block are the first two dimensions of the ground-truth array.

    Raises:
        ValueError: if the noisy and ground-truth arrays differ in shape.
    """
    val_data_dict = loadmat(
        os.path.join(dataset_dir, "ValidationNoisyBlocksRaw.mat"))
    val_data_noisy = val_data_dict['ValidationNoisyBlocksRaw']
    val_data_dict = loadmat(
        os.path.join(dataset_dir, 'ValidationGtBlocksRaw.mat'))
    val_data_gt = val_data_dict['ValidationGtBlocksRaw']
    if val_data_noisy.shape != val_data_gt.shape:
        raise ValueError("noisy {} and GT {} validation blocks differ in shape".format(
            val_data_noisy.shape, val_data_gt.shape))
    num_img, num_block, _, _ = val_data_gt.shape
    return num_img, num_block, val_data_noisy, val_data_gt

madfff commented 2 years ago

I still have a small question: is it right to generate the sub-images from the packed 4-channel raw images? Thank you for your patience.

TaoHuang2018 commented 2 years ago

yes, on the packed 4-channel raw images.

zejinwang commented 2 years ago

I still have a small question: is it right to generate the sub-images from the packed 4-channel raw images? Thank you for your patience.

Hello, I have also been unable to reproduce the reported 51.06 dB on SIDD raw-RGB. I ran the author's released code directly, and the PSNR only reaches 46.7 dB. This is my code:

# Neighbor2Neighbor training on packed raw (Bayer) images, with periodic
# checkpointing and validation. Relies on names defined elsewhere in the
# script: opt, network, optimizer, scheduler, TrainingLoader, valid_dict,
# systime, checkpoint, calculate_psnr, calculate_ssim, generate_mask_pair,
# generate_subimages, space_to_depth, depth_to_space, transforms, Image, time.


def _to_uint8(x):
    """Map a float array in [0, 1] to uint8 using round-half-up and clipping."""
    return np.clip(x * 255.0 + 0.5, 0, 255).astype(np.uint8)


to_tensor = transforms.Compose([transforms.ToTensor()])

for epoch in range(1, opt.n_epoch + 1):
    for param_group in optimizer.param_groups:
        current_lr = param_group['lr']
    print("LearningRate of Epoch {} = {}".format(epoch, current_lr))

    network.train()
    for iteration, noisy in enumerate(TrainingLoader):
        st = time.time()
        noisy = noisy.cuda()
        # Pack the mosaic into 4 channels (presumably (N, 1, H, W) ->
        # (N, 4, H/2, W/2) — confirm against TrainingLoader).
        noisy = space_to_depth(noisy, 2)

        optimizer.zero_grad()

        # Two sub-images from different, adjacent pixels of every 2x2 cell.
        mask1, mask2 = generate_mask_pair(noisy)
        noisy_sub1 = generate_subimages(noisy, mask1)
        noisy_sub2 = generate_subimages(noisy, mask2)
        with torch.no_grad():
            noisy_denoised = network(noisy)
        noisy_sub1_denoised = generate_subimages(noisy_denoised, mask1)
        noisy_sub2_denoised = generate_subimages(noisy_denoised, mask2)

        noisy_output = network(noisy_sub1)
        noisy_target = noisy_sub2
        # Regularizer weight grows linearly with the epoch.
        Lambda = epoch / opt.n_epoch * opt.increase_ratio
        diff = noisy_output - noisy_target
        exp_diff = noisy_sub1_denoised - noisy_sub2_denoised

        loss1 = torch.mean(diff**2)
        loss2 = Lambda * torch.mean((diff - exp_diff)**2)
        loss_all = opt.Lambda1 * loss1 + opt.Lambda2 * loss2

        loss_all.backward()
        optimizer.step()
        print(
            '{:04d} {:05d} Loss1={:.6f}, Lambda={}, Loss2={:.6f}, Loss_Full={:.6f}, Time={:.4f}'
            .format(epoch, iteration, loss1.item(), Lambda, loss2.item(),
                    loss_all.item(), time.time() - st))

    scheduler.step()

    if epoch % opt.n_snapshot == 0 or epoch == opt.n_epoch:
        network.eval()
        # save checkpoint
        checkpoint(network, epoch, "model")
        # validation
        save_model_path = os.path.join(opt.save_model_path, opt.log_name,
                                       systime)
        validation_path = os.path.join(save_model_path, "validation")
        os.makedirs(validation_path, exist_ok=True)
        np.random.seed(101)

        for valid_name, valid_data in valid_dict.items():
            psnr_result = []
            ssim_result = []
            num_img, num_block, valid_noisy, valid_gt = valid_data
            for idx in range(num_img):
                for idy in range(num_block):
                    im = valid_gt[idx, idy][:, :, np.newaxis]
                    noisy_im = valid_noisy[idx, idy][:, :, np.newaxis]

                    # GT and prediction must be quantized identically. A bare
                    # astype(np.uint8) truncates (and wraps values > 1) while
                    # the prediction is rounded, which biases PSNR/SSIM.
                    # NOTE(review): SIDD raw results are usually reported on
                    # the float [0, 1] data; 8-bit quantization itself also
                    # lowers the attainable PSNR — confirm the intended metric.
                    origin255 = _to_uint8(im)
                    noisy255 = _to_uint8(noisy_im)
                    # Reflect-pad to a square whose side is a multiple of 32.
                    H = noisy_im.shape[0]
                    W = noisy_im.shape[1]
                    val_size = (max(H, W) + 31) // 32 * 32
                    noisy_im = np.pad(
                        noisy_im,
                        [[0, val_size - H], [0, val_size - W], [0, 0]],
                        'reflect')
                    noisy_im = to_tensor(noisy_im)
                    noisy_im = torch.unsqueeze(noisy_im, 0)
                    noisy_im = noisy_im.cuda()
                    # pack raw data
                    noisy_im = space_to_depth(noisy_im, block_size=2)
                    with torch.no_grad():
                        prediction = network(noisy_im)
                        # unpack raw data and drop the padding
                        prediction = depth_to_space(prediction, block_size=2)
                        prediction = prediction[:, :, :H, :W]
                    prediction = prediction.permute(0, 2, 3, 1)
                    prediction = prediction.cpu().data.clamp(0, 1).numpy()
                    prediction = prediction.squeeze(0)
                    pred255 = _to_uint8(prediction)
                    # calculate psnr / ssim
                    cur_psnr = calculate_psnr(origin255.astype(np.float32),
                                              pred255.astype(np.float32))
                    psnr_result.append(cur_psnr)
                    cur_ssim = calculate_ssim(origin255.astype(np.float32),
                                              pred255.astype(np.float32))
                    ssim_result.append(cur_ssim)

                    # visualization
                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_clean.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(origin255.squeeze()).save(save_path)
                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_noisy.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(noisy255.squeeze()).save(save_path)

                    save_path = os.path.join(
                        validation_path,
                        "{}_{:03d}-{:03d}-{:03d}_denoised.png".format(
                            valid_name, idx, idy, epoch))
                    Image.fromarray(pred255.squeeze()).save(save_path)

            avg_psnr = np.mean(psnr_result)
            avg_ssim = np.mean(ssim_result)
            log_path = os.path.join(validation_path,
                                    "A_log_{}.csv".format(valid_name))
            with open(log_path, "a") as f:
                f.write("{},{},{}\n".format(epoch, avg_psnr, avg_ssim))