jleuschn / learned_ct_reco_comparison_paper

Code and supplementing material for the Article "Quantitative comparison of deep learning-based image reconstruction methods for low-dose and sparse-angle CT applications"

could you provide the training file of lodopab dataset #4

Open listar0810 opened 2 years ago

listar0810 commented 2 years ago

Hi, thanks for your code for these methods. I have tried the MS-D-CNN on the lodopab dataset. Here is my training code:

```python
import os
import argparse
import json
try:
    from FBPMSDReconstructor import FBPMSDNetReconstructor
    MSD_PYTORCH_AVAILABLE = True
except ImportError:
    MSD_PYTORCH_AVAILABLE = False
import torch

IMPL = 'astra_cuda'
RESULTS_PATH = '/data0/ct_logs/msdnet'

dataset = get_standard_dataset('lodopab', impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
test_data = dataset.get_data_pairs('test', 100)

NOISE_SETTING_DEFAULT = 'gaussian_noise'
NUM_ANGLES_DEFAULT = 50
METHOD_DEFAULT = 'fbpmsdnet'

parser = argparse.ArgumentParser()
parser.add_argument('--noise_setting', type=str, default='gaussian_noise')
parser.add_argument('--num_angles', type=int, default=50)
parser.add_argument('--method', type=str, default='fbpmsdnet')

options = parser.parse_args()

noise_setting = options.noise_setting  # 'gaussian_noise', 'scattering'
num_angles = options.num_angles  # 50, 10, 5, 2
method = options.method  # 'learnedpd', 'fbpunet', 'fbpmsdnet', 'cinn'
name = 'lodopab_{}_{}'.format(noise_setting, method)

from dival import get_standard_dataset
from dival.measure import PSNR
from dival.util.plot import plot_images
import numpy as np
IMPL = 'astra_cuda'
datasets = get_standard_dataset('lodopab', impl=IMPL)
test_data = datasets.get_data_pairs('test', 100)

FBP_DATASET_STATS = {
    'noisefree': {
        2: {
            'mean_fbp': 0.0020300781237049294,
            'std_fbp': 0.0036974098858769781,
            'mean_gt': 0.0018248517968347585,
            'std_gt': 0.0020251920919838714
        },
        5: {
            'mean_fbp': 0.0018914765285141003,
            'std_fbp': 0.0027988724415204552,
            'mean_gt': 0.0018248517968347585,
            'std_gt': 0.0020251920919838714
        },
        10: {
            'mean_fbp': 0.0018791806499857538,
            'std_fbp': 0.0023355593815585413,
            'mean_gt': 0.0018248517968347585,
            'std_gt': 0.0020251920919838714
        },
        50: {
            'mean_fbp': 0.0018856220845133943,
            'std_fbp': 0.002038545754978578,
            'mean_gt': 0.0018248517968347585,
            'std_gt': 0.0020251920919838714
        }
    },
    'gaussian_noise': {
        2: {
            'mean_fbp': 0.0020300515246877825,
            'std_fbp': 0.01135122820016111,
            'mean_gt': 0.0018248517968347585,
            'std_gt': 0.0020251920919838714
        },
        5: {
            'mean_fbp': 0.0018914835384669934,
            'std_fbp': 0.0073404856822226593,
            'mean_gt': 0.0018248517968347585,
            'std_gt': 0.0020251920919838714
        },
        10: {
            'mean_fbp': 0.0018791781748714272,
            'std_fbp': 0.0053367740312729459,
            'mean_gt': 0.0018248517968347585,
            'std_gt': 0.0020251920919838714
        },
        50: {
            'mean_fbp': 0.0018856252771456445,
            'std_fbp': 0.0029598508235758759,
            'mean_gt': 0.0018248517968347585,
            'std_gt': 0.0020251920919838714
        }
    },
    'scattering': {
        2: {
            'mean_fbp': 0.68570249744436962,
            'std_fbp': 1.3499668155231217,
            'mean_gt': 0.002007653630624356,  # different from gaussian_noise
            'std_gt': 0.0019931366497635745  # since subset of slices is used
        },
        5: {
            'mean_fbp': 0.67324839540841908,
            'std_fbp': 0.99012416989800478,
            'mean_gt': 0.002007653630624356,  # different from gaussian_noise
            'std_gt': 0.0019931366497635745  # since subset of slices is used
        },
        10: {
            'mean_fbp': 0.66960775275347806,
            'std_fbp': 0.80318946689776671,
            'mean_gt': 0.002007653630624356,  # different from gaussian_noise
            'std_gt': 0.0019931366497635745  # since subset of slices is used
        },
        50: {
            'mean_fbp': 0.67173917657611049,
            'std_fbp': 0.6794825395874754,
            'mean_gt': 0.002007653630624356,  # different from gaussian_noise
            'std_gt': 0.0019931366497635745  # since subset of slices is used
        }
    }
}

ray_trafo = dataset.ray_trafo

assert MSD_PYTORCH_AVAILABLE
reconstructor = FBPMSDNetReconstructor(
    ray_trafo,
    hyper_params={
        'depth': 100,
        'width': 1,
        'dilations': (1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
        'lr': 0.001,
        'batch_size': 1,
        'epochs': 50,
        'data_augmentation': True,
        'scheduler': 'none'
    },
    save_best_learned_params_path=os.path.join(RESULTS_PATH, name),
    log_dir=os.path.join(RESULTS_PATH, name),
    num_data_loader_workers=0,
)

reconstructor.save_hyper_params(
    os.path.join(RESULTS_PATH, '{}_hyper_params.json'.format(name)))

print("start training: '{}'".format(name))
print('hyper_params = {}'.format(
    json.dumps(reconstructor.hyper_params, indent=1)))
reconstructor.train(dataset)

recos = []
psnrs = []
for obs, gt in test_data:
    reco = reconstructor.reconstruct(obs)
    recos.append(reco)
    psnrs.append(PSNR(reco, gt))

print('mean psnr: {:f}'.format(np.mean(psnrs)))
import matplotlib.pyplot as plt
for i in range(3):
    _, ax = plot_images([recos[i], test_data.ground_truth[i]],
                        fig_size=(10, 4))
    ax[0].set_xlabel('PSNR: {:.2f}'.format(psnrs[i]))
    ax[0].set_title('CINNReconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle('test sample {:d}'.format(i))
plt.show()
```

But it raises an error at this line in 'odl_fourier_transform.py': [image] The error is shown below: [image] The original code for 'torch.fft.rfft(x_preproc,dim=-1)' was 'torch.rfft(x_preproc,1)'; I modified it because I use PyTorch 1.9. Is the error caused by this modification or by my dataset loading? Also, could you provide the MS-D-CNN training file for the lodopab dataset? Many thanks!

jleuschn commented 2 years ago

Hi,

Sorry, I originally planned to integrate the MS-D-Net into the dival library but never finished it, which is why the training script is missing. Here's a copy of the not-yet-pushed training script; you'll need to update some imports:

"""
Train FBPMSDNetReconstructor on 'lodopab'.
"""
import numpy as np
from dival import get_standard_dataset
from dival.measure import PSNR
from dival.reconstructors.fbpmsdnet_reconstructor import (
    FBPMSDNetReconstructor, compute_fbp_dataset_stats)
from dival.datasets.fbp_dataset import (
    generate_fbp_cache_files, get_cached_fbp_dataset)
from dival.reference_reconstructors import (
    check_for_params, download_params, get_hyper_params_path)
from dival.util.plot import plot_images

IMPL = 'astra_cuda'

LOG_DIR = './logs/lodopab_fbpmsdnet'
SAVE_BEST_LEARNED_PARAMS_PATH = './params/lodopab_fbpmsdnet'

CACHE_FILES = {
    'train':
        ('/localdata/dival_dataset_caches/cache_train_lodopab_fbp.npy', None),
    'validation':
        ('/localdata/dival_dataset_caches/cache_validation_lodopab_fbp.npy', None)}

dataset = get_standard_dataset('lodopab', impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
test_data = dataset.get_data_pairs('test', 100)

reconstructor = FBPMSDNetReconstructor(
    ray_trafo, log_dir=LOG_DIR,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH)

# #%% obtain reference hyper parameters
# if not check_for_params('fbpmsdnet', 'lodopab', include_learned=False):
#     download_params('fbpmsdnet', 'lodopab', include_learned=False)
# hyper_params_path = get_hyper_params_path('fbpmsdnet', 'lodopab')
# reconstructor.load_hyper_params(hyper_params_path)
hyper_params = {
    'depth': 200,
    'width': 1,
    'dilations': (1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
    'lr': 0.001,
    'batch_size': 1,
    'epochs': 15,
    'data_augmentation': False
}
reconstructor.hyper_params.update(hyper_params)
reconstructor.save_hyper_params(
    SAVE_BEST_LEARNED_PARAMS_PATH + '_hyper_params.json')

# number of parameters for depth 100, width 1, dilations 1...10: 45656
# number of parameters for depth 200, width 1, dilations 1...10: 181306
reconstructor.init_model()
print('number of parameters:',
      sum(t.numel() for t in reconstructor.model.state_dict().values()))

#%% expose FBP cache to reconstructor by assigning `fbp_dataset` attribute
# uncomment the next line to generate the cache files (~20 GB)
# generate_fbp_cache_files(dataset, ray_trafo, CACHE_FILES)
cached_fbp_dataset = get_cached_fbp_dataset(dataset, ray_trafo, CACHE_FILES)
dataset.fbp_dataset = cached_fbp_dataset

# stats = compute_fbp_dataset_stats(cached_fbp_dataset)
stats = {'mean_fbp': 0.17425122330415554,
         'std_fbp': 0.11476018693154483,
         'mean_gt': 0.16822296977667153,
         'std_gt': 0.11406864018442905}
dataset.fbp_dataset_stats = stats

#%% train
reconstructor.train(dataset)

#%% evaluate
recos = []
psnrs = []
for obs, gt in test_data:
    reco = reconstructor.reconstruct(obs)
    recos.append(reco)
    psnrs.append(PSNR(reco, gt))

print('mean psnr: {:f}'.format(np.mean(psnrs)))

for i in range(3):
    _, ax = plot_images([recos[i], test_data.ground_truth[i]],
                        fig_size=(10, 4))
    ax[0].set_xlabel('PSNR: {:.2f}'.format(psnrs[i]))
    ax[0].set_title('FBPMSDNetReconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle('test sample {:d}'.format(i))

Regarding the fft error, I think the change is causing it: torch now supports complex dtypes and the new torch.fft functions return complex tensors, while the code in this repository expects the old format (f[..., 0] is the real part, f[..., 1] is the imaginary part). You would need to use an older version of PyTorch, or adapt the code to work with the new complex-dtype fft functions (not much should need to change).
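A minimal sketch of one way to bridge the two formats, assuming PyTorch 1.8 or newer. The helper names `rfft_legacy` and `irfft_legacy` are made up for illustration and are not part of this repository; whether this alone is enough for 'odl_fourier_transform.py' has not been checked here.

```python
import torch
import torch.fft


def rfft_legacy(x):
    """Real FFT over the last dim in the old layout.

    Intended to mimic the removed ``torch.rfft(x, 1)``: the result is a real
    tensor with a trailing axis of size 2 holding (real, imag).
    """
    return torch.view_as_real(torch.fft.rfft(x, dim=-1))


def irfft_legacy(f, n):
    """Inverse real FFT for input in the old (..., 2) layout.

    ``n`` is the length of the original signal along the last dim.
    """
    f_complex = torch.view_as_complex(f.contiguous())
    return torch.fft.irfft(f_complex, n=n, dim=-1)


x = torch.randn(4, 16)
f = rfft_legacy(x)                       # shape (4, 9, 2)
x_rec = irfft_legacy(f, n=x.shape[-1])   # round trip back to shape (4, 16)
print(f.shape, torch.allclose(x, x_rec, atol=1e-5))
```

With wrappers like these, code that indexes `f[..., 0]` and `f[..., 1]` could stay unchanged; the alternative is to rewrite that code to operate on complex tensors directly.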

Best regards

jleuschn commented 2 years ago

And here's the reconstructor class from the unpushed code:

from warnings import warn
from copy import deepcopy

import torch
import numpy as np
import torch.nn as nn
from odl.tomo import fbp_op
from tqdm import tqdm

from dival.reconstructors.standard_learned_reconstructor import (
    StandardLearnedReconstructor)
from dival.reconstructors.networks.msdnet import MSDNet
from dival.datasets.fbp_dataset import FBPDataset

def compute_fbp_dataset_stats(fbp_dataset):
    """
    Compute means and standard deviations for the elements of an FBP dataset.
    Only the ``'train'`` part is used.
    """
    # Adapted from: https://github.com/ahendriksen/msd_pytorch/blob/162823c502701f5eedf1abcd56e137f8447a72ef/msd_pytorch/msd_model.py#L95
    mean_fbp = 0.
    mean_gt = 0.
    square_fbp = 0.
    square_gt = 0.
    n = fbp_dataset.get_len('train')
    for fbp, gt in tqdm(fbp_dataset.generator('train'), total=n,
                        desc='computing fbp dataset stats'):
        mean_fbp += np.mean(fbp)
        mean_gt += np.mean(gt)
        square_fbp += np.mean(np.square(fbp))
        square_gt += np.mean(np.square(gt))
    mean_fbp /= n
    mean_gt /= n
    square_fbp /= n
    square_gt /= n
    std_fbp = np.sqrt(square_fbp - mean_fbp**2)
    std_gt = np.sqrt(square_gt - mean_gt**2)
    stats = {'mean_fbp': mean_fbp,
             'std_fbp': std_fbp,
             'mean_gt': mean_gt,
             'std_gt': std_gt}
    return stats

class FBPMSDNetReconstructor(StandardLearnedReconstructor):
    """
    CT reconstructor applying filtered back-projection followed by a
    postprocessing MS-D network (:class:`MSDNet`), analogous to the FBP +
    U-Net approach of [1]_.

    References
    ----------
    .. [1] K. H. Jin, M. T. McCann, E. Froustey, et al., 2017,
           "Deep Convolutional Neural Network for Inverse Problems in Imaging".
           IEEE Transactions on Image Processing.
           `doi:10.1109/TIP.2017.2713099
           <https://doi.org/10.1109/TIP.2017.2713099>`_
    """

    HYPER_PARAMS = deepcopy(StandardLearnedReconstructor.HYPER_PARAMS)
    HYPER_PARAMS.update({
        'depth': {
            'default': 100,
            'retrain': True
        },
        'width': {
            'default': 1,
            'retrain': True
        },
        'dilations': {
            'default': (1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
            'retrain': True
        },
        'filter_type': {
            'default': 'Hann',
            'retrain': True
        },
        'frequency_scaling': {
            'default': 1.0,
            'retrain': True
        },
        'lr': {
            'default': 0.001,
            'retrain': True
        },
        'scheduler': {
            'default': 'none',
            'choices': ['none', 'base', 'cosine'],  # 'base': inherit
            'retrain': True
        },
        'lr_min': {  # only used if 'cosine' scheduler is selected
            'default': 1e-4,
            'retrain': True
        },
        'data_augmentation': {
            'default': True,
            'retrain': True
        }
    })

    def __init__(self, ray_trafo, **kwargs):
        """
        Parameters
        ----------
        ray_trafo : :class:`odl.tomo.RayTransform`
            Ray transform (the forward operator).

        Further keyword arguments are passed to ``super().__init__()``.
        """
        self._fbp_dataset_stats = None
        super().__init__(ray_trafo, **kwargs)

    def train(self, dataset):
        try:
            fbp_dataset = dataset.fbp_dataset
        except AttributeError:
            warn('Training FBPMSDNetReconstructor with no cached FBP dataset. '
                 'Will compute the FBPs on the fly. For faster training, '
                 'consider precomputing the FBPs with '
                 '`generate_fbp_cache_files(...)` and pass them to `train()` '
                 'by setting the attribute '
                 '``dataset.fbp_dataset = get_cached_fbp_dataset(...)``.')
            fbp_dataset = FBPDataset(
                dataset, self.non_normed_op, filter_type=self.filter_type,
                frequency_scaling=self.frequency_scaling)

        try:
            self._fbp_dataset_stats = dataset.fbp_dataset_stats
        except AttributeError:
            print('Computing statistics of FBP dataset in '
                  '`FBPMSDNetReconstructor.train()`. To avoid recomputing '
                  'them for each training, consider passing them to `train()` '
                  'by setting the attribute '
                  "``dataset.fbp_dataset_stats = {"
                  "    'mean_fbp': ..., "
                  "    'std_fbp': ..., "
                  "    'mean_gt': ..., "
                  "    'std_gt': ...}``. "
                  'This dict can be computed using '
                  '``compute_fbp_dataset_stats(fbp_dataset)``.')
            self._fbp_dataset_stats = compute_fbp_dataset_stats(fbp_dataset)

        super().train(fbp_dataset)

        self._fbp_dataset_stats = None  # reset, because the only purpose is to
                                        # expose the stats to self.init_model()

    def init_transform(self, dataset):
        if self.data_augmentation:
            def random_flip_rotate_transform(sample):
                fbp, gt = sample
                choice = torch.randint(8, (1,))[0]
                if choice % 4 == 1:
                    fbp = torch.flip(fbp, (1,))
                    gt = torch.flip(gt, (1,))
                elif choice % 4 == 2:
                    fbp = torch.flip(fbp, (2,))
                    gt = torch.flip(gt, (2,))
                elif choice % 4 == 3:
                    fbp = torch.flip(fbp, (1, 2))
                    gt = torch.flip(gt, (1, 2))
                if choice // 4 == 1:
                    fbp = torch.transpose(fbp, 1, 2)
                    gt = torch.transpose(gt, 1, 2)
                return fbp, gt
            self._transform = random_flip_rotate_transform
        else:
            self._transform = None

    def init_model(self):
        self.fbp_op = fbp_op(self.op, filter_type=self.filter_type,
                             frequency_scaling=self.frequency_scaling)
        self.model = MSDNet(in_ch=1, out_ch=1, depth=self.depth,
                            width=self.width, dilations=self.dilations)
        if self._fbp_dataset_stats is not None:
            self.model.set_normalization(
                mean_in=self._fbp_dataset_stats['mean_fbp'],
                std_in=self._fbp_dataset_stats['std_fbp'],
                mean_out=self._fbp_dataset_stats['mean_gt'],
                std_out=self._fbp_dataset_stats['std_gt'])
        # if self.init_bias_zero:
        #     def weights_init(m):
        #         if isinstance(m, torch.nn.Conv2d):
        #             m.bias.data.fill_(0.0)
        #     self.model.apply(weights_init)

        if self.use_cuda:
            self.model = nn.DataParallel(self.model).to(self.device)

    def init_optimizer(self, dataset_train):
        """
        Initialize the optimizer.
        Called in :meth:`train`, after calling :meth:`init_model` and before
        calling :meth:`init_scheduler`.

        Parameters
        ----------
        dataset_train : :class:`torch.utils.data.Dataset`
            The training (torch) dataset constructed in :meth:`train`.
        """
        # only train msd, but not scale_in and scale_out
        parameters = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.Adam(parameters, lr=self.lr)

    def init_scheduler(self, dataset_train):
        # need to set private self._scheduler because self.scheduler
        # property accesses hyper parameter of same name,
        # i.e. self.hyper_params['scheduler']
        if self.scheduler.lower() == 'none':
            self._scheduler = None
        elif self.scheduler.lower() == 'cosine':
            self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=self.epochs,
                eta_min=self.lr_min)
        else:
            super().init_scheduler(dataset_train)

    def _reconstruct(self, observation):
        self.model.eval()
        fbp = self.fbp_op(observation)
        fbp_tensor = torch.from_numpy(
            np.asarray(fbp)[None, None]).to(self.device)
        reco_tensor = self.model(fbp_tensor)
        reconstruction = reco_tensor.cpu().detach().numpy()[0, 0]
        return self.reco_space.element(reconstruction)
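
The streaming statistics in `compute_fbp_dataset_stats` above rely on the identity std = sqrt(E[x²] − E[x]²), with both expectations obtained by averaging per-image means. A small NumPy check, using synthetic arrays as a stand-in for the FBP dataset (the shapes and values are made up), suggests this matches the global statistics as long as all images have the same number of pixels:

```python
import numpy as np

# Synthetic stand-in for the images of an FBP dataset (hypothetical data).
rng = np.random.default_rng(0)
images = [rng.normal(0.17, 0.11, size=(8, 8)) for _ in range(50)]


def streaming_stats(images):
    """Accumulate per-image mean and mean square, then derive mean and std."""
    mean = 0.
    square = 0.
    for img in images:
        mean += np.mean(img)
        square += np.mean(np.square(img))
    n = len(images)
    mean /= n
    square /= n
    return mean, np.sqrt(square - mean**2)


mean, std = streaming_stats(images)

# Reference: statistics over all pixels at once (population std, ddof=0).
stacked = np.stack(images)
print(np.isclose(mean, stacked.mean()), np.isclose(std, stacked.std()))
```

If the images differed in size, the per-image means would need to be weighted by pixel count for the two results to agree.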