Open listar0810 opened 2 years ago
Hi,
Sorry, I originally planned to integrate the MS-D-Net into the dival library but never finished it; that's why the training script is missing. Here's a copy of the not-yet-pushed training script (you'll need to update some imports):
"""
Train FBPMSDNetReconstructor on 'lodopab'.
"""
import numpy as np
from dival import get_standard_dataset
from dival.measure import PSNR
from dival.reconstructors.fbpmsdnet_reconstructor import (
FBPMSDNetReconstructor, compute_fbp_dataset_stats)
from dival.datasets.fbp_dataset import (
generate_fbp_cache_files, get_cached_fbp_dataset)
from dival.reference_reconstructors import (
check_for_params, download_params, get_hyper_params_path)
from dival.util.plot import plot_images
# Backend name handed to both the dataset and the ray transform.
IMPL = 'astra_cuda'
# Locations for the training logs and for the best learned parameters.
LOG_DIR = './logs/lodopab_fbpmsdnet'
SAVE_BEST_LEARNED_PARAMS_PATH = './params/lodopab_fbpmsdnet'
# Cache file specification per dataset part, forwarded unchanged to
# `get_cached_fbp_dataset` (and `generate_fbp_cache_files`).
# NOTE(review): presumably (fbp cache file, ground truth cache file or
# None) -- confirm against `dival.datasets.fbp_dataset`.
CACHE_FILES = {
    'train':
        ('/localdata/dival_dataset_caches/cache_train_lodopab_fbp.npy', None),
    'validation':
        ('/localdata/dival_dataset_caches/cache_validation_lodopab_fbp.npy',
         None),
}

dataset = get_standard_dataset('lodopab', impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
# The first 100 test pairs are used for the evaluation at the end.
test_data = dataset.get_data_pairs('test', 100)

reconstructor = FBPMSDNetReconstructor(
    ray_trafo,
    log_dir=LOG_DIR,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH)

# #%% obtain reference hyper parameters
# if not check_for_params('fbpmsdnet', 'lodopab', include_learned=False):
#     download_params('fbpmsdnet', 'lodopab', include_learned=False)
# hyper_params_path = get_hyper_params_path('fbpmsdnet', 'lodopab')
# reconstructor.load_hyper_params(hyper_params_path)

# Hyper parameters used for this run; they override the class defaults
# and are stored next to the learned parameters.
hyper_params = {
    'depth': 200,
    'width': 1,
    'dilations': (1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
    'lr': 0.001,
    'batch_size': 1,
    'epochs': 15,
    'data_augmentation': False
}
reconstructor.hyper_params.update(hyper_params)
hyper_params_file = SAVE_BEST_LEARNED_PARAMS_PATH + '_hyper_params.json'
reconstructor.save_hyper_params(hyper_params_file)

# number of parameters for depth 100, width 1, dilations 1...10: 45656
# number of parameters for depth 200, width 1, dilations 1...10: 181306
reconstructor.init_model()
state_tensors = reconstructor.model.state_dict().values()
print('number of parameters:',
      sum(t.numel() for t in state_tensors))

#%% expose FBP cache to reconstructor by assigning `fbp_dataset` attribute
# uncomment the next line to generate the cache files (~20 GB)
# generate_fbp_cache_files(dataset, ray_trafo, CACHE_FILES)
cached_fbp_dataset = get_cached_fbp_dataset(dataset, ray_trafo, CACHE_FILES)
dataset.fbp_dataset = cached_fbp_dataset

# Hard-coded normalization statistics; the commented-out call below
# recomputes them from the cached FBP dataset.
# stats = compute_fbp_dataset_stats(cached_fbp_dataset)
stats = {
    'mean_fbp': 0.17425122330415554,
    'std_fbp': 0.11476018693154483,
    'mean_gt': 0.16822296977667153,
    'std_gt': 0.11406864018442905,
}
dataset.fbp_dataset_stats = stats

#%% train
reconstructor.train(dataset)

#%% evaluate
recos = []
psnrs = []
for observation, ground_truth in test_data:
    reconstruction = reconstructor.reconstruct(observation)
    recos.append(reconstruction)
    psnrs.append(PSNR(reconstruction, ground_truth))

print('mean psnr: {:f}'.format(np.mean(psnrs)))

# Show the first three reconstructions next to their ground truth.
for i in range(3):
    image_pair = [recos[i], test_data.ground_truth[i]]
    _, ax = plot_images(image_pair, fig_size=(10, 4))
    reco_ax, gt_ax = ax[0], ax[1]
    reco_ax.set_xlabel('PSNR: {:.2f}'.format(psnrs[i]))
    reco_ax.set_title('FBPMSDNetReconstructor')
    gt_ax.set_title('ground truth')
    reco_ax.figure.suptitle('test sample {:d}'.format(i))
Regarding the fft error, I think your change to `torch.fft.rfft` is causing the error: torch now supports complex dtypes, and the new fft functions use them, while the code in this repository expects the old format (`f[..., 0]` is the real part, `f[..., 1]` is the imaginary part). You would need to use an old version of PyTorch, or adapt the code to work with the new complex-dtype fft functions (there shouldn't be much that needs to be changed).
Best regards
And here's the reconstructor class from the unpushed code:
from warnings import warn
from copy import deepcopy
import torch
import numpy as np
import torch.nn as nn
from odl.tomo import fbp_op
from tqdm import tqdm
from dival.reconstructors.standard_learned_reconstructor import (
StandardLearnedReconstructor)
from dival.reconstructors.networks.msdnet import MSDNet
from dival.datasets.fbp_dataset import FBPDataset
def compute_fbp_dataset_stats(fbp_dataset):
    """
    Compute means and standard deviations for the elements of an FBP dataset.

    Only the ``'train'`` part is used.

    Parameters
    ----------
    fbp_dataset : dataset
        Object providing ``get_len('train')`` and ``generator('train')``,
        the latter yielding ``(fbp, gt)`` pairs of array-like images.

    Returns
    -------
    stats : dict
        Dictionary with the keys ``'mean_fbp'``, ``'std_fbp'``,
        ``'mean_gt'`` and ``'std_gt'``, each mapping to a `float`.
        The values are averages of the per-sample statistics over all
        training samples.

    Raises
    ------
    ZeroDivisionError
        If the ``'train'`` part is empty.
    """
    # Adapted from: https://github.com/ahendriksen/msd_pytorch/blob/162823c502701f5eedf1abcd56e137f8447a72ef/msd_pytorch/msd_model.py#L95
    mean_fbp = 0.
    mean_gt = 0.
    square_fbp = 0.
    square_gt = 0.
    n = fbp_dataset.get_len('train')
    for fbp, gt in tqdm(fbp_dataset.generator('train'), total=n,
                        desc='computing fbp dataset stats'):
        # Work in float64 regardless of the sample dtype, so that neither
        # the squaring nor the running sums lose precision when the data
        # is stored in single precision.
        fbp = np.asarray(fbp).astype(np.float64, copy=False)
        gt = np.asarray(gt).astype(np.float64, copy=False)
        mean_fbp += float(np.mean(fbp))
        mean_gt += float(np.mean(gt))
        square_fbp += float(np.mean(np.square(fbp)))
        square_gt += float(np.mean(np.square(gt)))
    mean_fbp /= n
    mean_gt /= n
    square_fbp /= n
    square_gt /= n
    # var = E[x^2] - E[x]^2 can come out marginally negative due to
    # rounding (e.g. for constant data); clamp at zero to avoid NaN.
    std_fbp = float(np.sqrt(max(square_fbp - mean_fbp**2, 0.)))
    std_gt = float(np.sqrt(max(square_gt - mean_gt**2, 0.)))
    stats = {'mean_fbp': mean_fbp,
             'std_fbp': std_fbp,
             'mean_gt': mean_gt,
             'std_gt': std_gt}
    return stats
class FBPMSDNetReconstructor(StandardLearnedReconstructor):
    """
    CT reconstructor applying filtered back-projection followed by a
    postprocessing mixed-scale dense network (MS-D-Net, [2]_), analogous to
    FBP followed by a postprocessing U-Net (e.g. [1]_).

    References
    ----------
    .. [1] K. H. Jin, M. T. McCann, E. Froustey, et al., 2017,
           "Deep Convolutional Neural Network for Inverse Problems in
           Imaging". IEEE Transactions on Image Processing.
           `doi:10.1109/TIP.2017.2713099
           <https://doi.org/10.1109/TIP.2017.2713099>`_
    .. [2] D. M. Pelt and J. A. Sethian, 2018,
           "A mixed-scale dense convolutional neural network for image
           analysis". Proceedings of the National Academy of Sciences.
           `doi:10.1073/pnas.1715832114
           <https://doi.org/10.1073/pnas.1715832114>`_
    """

    HYPER_PARAMS = deepcopy(StandardLearnedReconstructor.HYPER_PARAMS)
    HYPER_PARAMS.update({
        'depth': {
            'default': 100,
            'retrain': True
        },
        'width': {
            'default': 1,
            'retrain': True
        },
        'dilations': {
            'default': (1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
            'retrain': True
        },
        'filter_type': {
            'default': 'Hann',
            'retrain': True
        },
        'frequency_scaling': {
            'default': 1.0,
            'retrain': True
        },
        'lr': {
            'default': 0.001,
            'retrain': True
        },
        'scheduler': {
            'default': 'none',
            'choices': ['none', 'base', 'cosine'],  # 'base': inherit
            'retrain': True
        },
        'lr_min': {  # only used if 'cosine' scheduler is selected
            'default': 1e-4,
            'retrain': True
        },
        'data_augmentation': {
            'default': True,
            'retrain': True
        }
    })

    def __init__(self, ray_trafo, **kwargs):
        """
        Parameters
        ----------
        ray_trafo : :class:`odl.tomo.RayTransform`
            Ray transform (the forward operator).

        Further keyword arguments are passed to ``super().__init__()``.
        """
        # Only set while `train()` is running, see there.
        self._fbp_dataset_stats = None
        super().__init__(ray_trafo, **kwargs)

    def train(self, dataset):
        """
        Train the network on pairs of FBPs and ground truth images.

        Parameters
        ----------
        dataset : dataset
            Dataset to train on. Two optional attributes are honoured:

            ``fbp_dataset``
                Dataset of (FBP, ground truth) pairs that is passed to
                ``super().train()``. If missing, a warning is issued and
                an :class:`FBPDataset` computing the FBPs on the fly is
                used instead.
            ``fbp_dataset_stats``
                Dict with the keys ``'mean_fbp'``, ``'std_fbp'``,
                ``'mean_gt'`` and ``'std_gt'`` used for the normalization
                in :meth:`init_model`. If missing, it is computed with
                :func:`compute_fbp_dataset_stats`.
        """
        try:
            fbp_dataset = dataset.fbp_dataset
        except AttributeError:
            warn('Training FBPMSDNetReconstructor with no cached FBP dataset. '
                 'Will compute the FBPs on the fly. For faster training, '
                 'consider precomputing the FBPs with '
                 '`generate_fbp_cache_files(...)` and pass them to `train()` '
                 'by setting the attribute '
                 '``dataset.fbp_dataset = get_cached_fbp_dataset(...)``.')
            fbp_dataset = FBPDataset(
                dataset, self.non_normed_op, filter_type=self.filter_type,
                frequency_scaling=self.frequency_scaling)

        try:
            self._fbp_dataset_stats = dataset.fbp_dataset_stats
        except AttributeError:
            print('Computing statistics of FBP dataset in '
                  '`FBPMSDNetReconstructor.train()`. To avoid recomputing '
                  'them for each training, consider passing them to `train()` '
                  'by setting the attribute '
                  "``dataset.fbp_dataset_stats = {"
                  " 'mean_fbp': ..., "
                  " 'std_fbp': ..., "
                  " 'mean_gt': ..., "
                  " 'std_gt': ...}``. "
                  'This dict can be computed using '
                  '``compute_fbp_dataset_stats(fbp_dataset)``.')
            self._fbp_dataset_stats = compute_fbp_dataset_stats(fbp_dataset)

        try:
            super().train(fbp_dataset)
        finally:
            # Reset even if training fails or is interrupted: the only
            # purpose of the attribute is to expose the stats to
            # self.init_model() during this call, and stale stats must not
            # leak into a later init_model() call.
            self._fbp_dataset_stats = None

    def init_transform(self, dataset):
        """
        Set ``self._transform`` to the sample transform used for training.

        If the hyper parameter ``data_augmentation`` is set, the transform
        draws one of eight variants per sample (no flip, flip along
        dimension 1, along dimension 2 or along both; each optionally
        followed by transposing dimensions 1 and 2) and applies it to both
        the FBP and the ground truth. Otherwise no transform is used.

        Parameters
        ----------
        dataset : dataset
            Not used by this implementation.
        """
        if self.data_augmentation:
            def random_flip_rotate_transform(sample):
                fbp, gt = sample
                # Convert to a Python int once, so the arithmetic below
                # operates on plain integers instead of 0-dim tensors.
                choice = torch.randint(8, (1,)).item()
                flip = choice % 4
                if flip == 1:
                    fbp = torch.flip(fbp, (1,))
                    gt = torch.flip(gt, (1,))
                elif flip == 2:
                    fbp = torch.flip(fbp, (2,))
                    gt = torch.flip(gt, (2,))
                elif flip == 3:
                    fbp = torch.flip(fbp, (1, 2))
                    gt = torch.flip(gt, (1, 2))
                if choice // 4 == 1:
                    fbp = torch.transpose(fbp, 1, 2)
                    gt = torch.transpose(gt, 1, 2)
                return fbp, gt
            self._transform = random_flip_rotate_transform
        else:
            self._transform = None

    def init_model(self):
        """
        Create the FBP operator and the MS-D network.

        The input/output normalization of the network is only set if
        dataset statistics are currently exposed via
        ``self._fbp_dataset_stats``, which is the case while :meth:`train`
        is running. If ``self.use_cuda`` is set, the model is wrapped in
        :class:`torch.nn.DataParallel` and moved to ``self.device``.
        """
        self.fbp_op = fbp_op(self.op, filter_type=self.filter_type,
                             frequency_scaling=self.frequency_scaling)
        self.model = MSDNet(in_ch=1, out_ch=1, depth=self.depth,
                            width=self.width, dilations=self.dilations)
        if self._fbp_dataset_stats is not None:
            self.model.set_normalization(
                mean_in=self._fbp_dataset_stats['mean_fbp'],
                std_in=self._fbp_dataset_stats['std_fbp'],
                mean_out=self._fbp_dataset_stats['mean_gt'],
                std_out=self._fbp_dataset_stats['std_gt'])

        if self.use_cuda:
            self.model = nn.DataParallel(self.model).to(self.device)

    def init_optimizer(self, dataset_train):
        """
        Initialize the optimizer.
        Called in :meth:`train`, after calling :meth:`init_model` and before
        calling :meth:`init_scheduler`.

        Parameters
        ----------
        dataset_train : :class:`torch.utils.data.Dataset`
            The training (torch) dataset constructed in :meth:`train`.
        """
        # only train msd, but not scale_in and scale_out
        parameters = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.Adam(parameters, lr=self.lr)

    def init_scheduler(self, dataset_train):
        """
        Initialize the learning rate scheduler.

        Depending on the hyper parameter ``scheduler``: ``'none'`` disables
        scheduling, ``'cosine'`` selects cosine annealing from ``lr`` down
        to ``lr_min`` with ``T_max=self.epochs``, and any other value
        delegates to ``super().init_scheduler()``.

        Parameters
        ----------
        dataset_train : :class:`torch.utils.data.Dataset`
            The training (torch) dataset constructed in :meth:`train`.
        """
        # need to set private self._scheduler because self.scheduler
        # property accesses hyper parameter of same name,
        # i.e. self.hyper_params['scheduler']
        if self.scheduler.lower() == 'none':
            self._scheduler = None
        elif self.scheduler.lower() == 'cosine':
            self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=self.epochs,
                eta_min=self.lr_min)
        else:
            super().init_scheduler(dataset_train)

    def _reconstruct(self, observation):
        """
        Reconstruct by applying the FBP and then the network.

        Parameters
        ----------
        observation : observation element
            Projection data; it is passed to ``self.fbp_op``.

        Returns
        -------
        reconstruction : element of ``self.reco_space``
            The network output for the FBP of `observation`.
        """
        self.model.eval()
        fbp = self.fbp_op(observation)
        # Add two leading singleton axes (presumably batch and channel,
        # matching ``in_ch=1``); they are stripped again below.
        fbp_tensor = torch.from_numpy(
            np.asarray(fbp)[None, None]).to(self.device)
        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            reco_tensor = self.model(fbp_tensor)
        reconstruction = reco_tensor.cpu().detach().numpy()[0, 0]
        return self.reco_space.element(reconstruction)
Hi, thanks for your code for these methods. I have tried the MS-D-CNN on the LoDoPaB dataset. Here is my training code: `import os import argparse import json try: from FBPMSDReconstructor import FBPMSDNetReconstructor MSD_PYTORCH_AVAILABLE = True except ImportError: MSD_PYTORCH_AVAILABLE = False import torch
IMPL = 'astra_cuda' RESULTS_PATH = '/data0/ct_logs/msdnet'
dataset = get_standard_dataset('lodopab', impl=IMPL) ray_trafo = dataset.get_ray_trafo(impl=IMPL) test_data = dataset.get_data_pairs('test', 100)
NOISE_SETTING_DEFAULT = 'gaussian_noise' NUM_ANGLES_DEFAULT = 50 METHOD_DEFAULT = 'fbpmsdnet'
parser = argparse.ArgumentParser() parser.add_argument('--noise_setting', type=str, default='gaussian_noise') parser.add_argument('--num_angles', type=int, default=50) parser.add_argument('--method', type=str, default='fbpmsdnet')
options = parser.parse_args()
noise_setting = options.noise_setting # 'gaussian_noise', 'scattering' num_angles = options.numangles # 50, 10, 5, 2 method = options.method # 'learnedpd', 'fbpunet', 'fbpmsdnet', 'cinn' name = 'lodopab{}_{}'.format(noise_setting, method)
from dival import get_standard_dataset from dival.measure import PSNR from dival.util.plot import plot_images import numpy as np IMPL = 'astra_cuda' datasets = get_standard_dataset('lodopab', impl=IMPL) test_data = datasets.get_data_pairs('test', 100)
FBP_DATASET_STATS = { 'noisefree': { 2: { 'mean_fbp': 0.0020300781237049294, 'std_fbp': 0.0036974098858769781, 'mean_gt': 0.0018248517968347585, 'std_gt': 0.0020251920919838714 }, 5: { 'mean_fbp': 0.0018914765285141003, 'std_fbp': 0.0027988724415204552, 'mean_gt': 0.0018248517968347585, 'std_gt': 0.0020251920919838714 }, 10: { 'mean_fbp': 0.0018791806499857538, 'std_fbp': 0.0023355593815585413, 'mean_gt': 0.0018248517968347585, 'std_gt': 0.0020251920919838714 }, 50: { 'mean_fbp': 0.0018856220845133943, 'std_fbp': 0.002038545754978578, 'mean_gt': 0.0018248517968347585, 'std_gt': 0.0020251920919838714 } }, 'gaussian_noise': { 2: { 'mean_fbp': 0.0020300515246877825, 'std_fbp': 0.01135122820016111, 'mean_gt': 0.0018248517968347585, 'std_gt': 0.0020251920919838714 }, 5: { 'mean_fbp': 0.0018914835384669934, 'std_fbp': 0.0073404856822226593, 'mean_gt': 0.0018248517968347585, 'std_gt': 0.0020251920919838714 }, 10: { 'mean_fbp': 0.0018791781748714272, 'std_fbp': 0.0053367740312729459, 'mean_gt': 0.0018248517968347585, 'std_gt': 0.0020251920919838714 }, 50: { 'mean_fbp': 0.0018856252771456445, 'std_fbp': 0.0029598508235758759, 'mean_gt': 0.0018248517968347585, 'std_gt': 0.0020251920919838714 } }, 'scattering': { 2: { 'mean_fbp': 0.68570249744436962, 'std_fbp': 1.3499668155231217, 'mean_gt': 0.002007653630624356, # different from gaussian_noise 'std_gt': 0.0019931366497635745 # since subset of slices is used }, 5: { 'mean_fbp': 0.67324839540841908, 'std_fbp': 0.99012416989800478, 'mean_gt': 0.002007653630624356, # different from gaussian_noise 'std_gt': 0.0019931366497635745 # since subset of slices is used }, 10: { 'mean_fbp': 0.66960775275347806, 'std_fbp': 0.80318946689776671, 'mean_gt': 0.002007653630624356, # different from gaussian_noise 'std_gt': 0.0019931366497635745 # since subset of slices is used }, 50: { 'mean_fbp': 0.67173917657611049, 'std_fbp': 0.6794825395874754, 'mean_gt': 0.002007653630624356, # different from gaussian_noise 'std_gt': 0.0019931366497635745 # 
since subset of slices is used } } }
ray_trafo = dataset.ray_trafo
assert MSD_PYTORCH_AVAILABLE reconstructor = FBPMSDNetReconstructor( ray_trafo, hyper_params={ 'depth': 100, 'width': 1, 'dilations': (1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 'lr': 0.001, 'batch_size': 1, 'epochs': 50, 'data_augmentation': True, 'scheduler': 'none' }, save_best_learned_params_path=os.path.join(RESULTS_PATH, name), log_dir=os.path.join(RESULTS_PATH, name), num_data_loader_workers=0, )
reconstructor.save_hyper_params( os.path.join(RESULTS_PATH, '{}_hyper_params.json'.format(name)))
print("start training: '{}'".format(name)) print('hyper_params = {}'.format( json.dumps(reconstructor.hyper_params, indent=1))) reconstructor.train(dataset)
recos = [] psnrs = [] for obs, gt in test_data: reco = reconstructor.reconstruct(obs) recos.append(reco) psnrs.append(PSNR(reco, gt))
print('mean psnr: {:f}'.format(np.mean(psnrs))) import matplotlib.pyplot as plt for i in range(3): _, ax = plot_images([recos[i], test_data.ground_truth[i]], fig_size=(10, 4)) ax[0].set_xlabel('PSNR: {:.2f}'.format(psnrs[i])) ax[0].set_title('CINNReconstructor') ax[1].set_title('ground truth') ax[0].figure.suptitle('test sample {:d}'.format(i)) plt.show()`
But it raises an error at this line in 'odl_fourier_transform.py', and the error is shown below. The original code for 'torch.fft.rfft(x_preproc,dim=-1)' was 'torch.rfft(x_preproc,1)'; I modified it because I use PyTorch 1.9. Is the error caused by this modification or by my dataset loading? Also, could you provide the MS-D-CNN training file for the LoDoPaB dataset? Many thanks!