Closed · BailiangJ closed this issue 11 months ago
Hi Bailiang,
Thanks for your interest in our work!
To adapt the code to your own dataset, you may follow the code in ./core/trainers/Brainweb/BrainWebGroupRegTrainer.py to create a new trainer module. In particular, first create a PyTorch Dataset module similar to ./core/data/data_providers/Brainweb/BrainwebImageProvider.py to load your data. Then, create a model file similar to ./core/models/Brainweb/BrainWebRegModel.py, which instantiates the registration method you want to use and sets the evaluation metrics. Finally, go back to the trainer module to load your data and initialize the model and optimizer. Training can be done with
for j in range(model.reg.num_reg_levels):
    model.reg.activate_params([j])  # optimize only the j-th transformation level
    opt = opts[j]
    for step in range(0 if j == 0 else cum_steps[j - 1], cum_steps[j]):
        opt.zero_grad()
        warped_images = model.reg()  # warp images with the currently composed transformation
        loss = model.reg.loss_function(warped_images)
        loss.backward()
        opt.step()
You may also check ./core/configs/ConfigBrainWeb.py to see what arguments you may need to configure the registration. For example, to use FFD you may need to set --ffd_spacing and --zero_avg_flow, while for SVF you may set --zero_avg_vec. To use XCoRegUN/XCoRegGT, you may also need to tune the parameter --num_classes. Note that for XCoRegGT you also need to load the anatomical labels for the corresponding input images.
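For instance, a hypothetical invocation combining these options might look like the following (the script name and exact flag syntax here are assumptions; only the option names are taken from ConfigBrainWeb.py as discussed above):

    python train.py --transform_type FFD --ffd_spacing 8 --zero_avg_flow --num_classes 4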
If you have any further questions, please feel free to contact me.
Good luck!
Hi @xzluo97,
Thank you for the quick reply.
Yes, I am adopting the code from the BrainWeb example.
Regarding the parameters, I have a few questions.
1. From _verify_hyerparameters, I can see transform_type is a multiple choice from ['TRA', 'RIG', 'AFF', 'FFD', 'DDF', 'SVF']. Could it be ['FFD', 'DDF']? I can see that 'FFD' has its own pre_ffd_level, which equals the length of pred_ffd_spacing when using isotropic spacing across dimensions, while 'DDF' is a full-size vector field. From the code I also see that using ['FFD', 'SVF'] does not mean the velocity field is parameterized by B-splines, is that correct?
2. What are the meanings of num_reg_levels, max_levels and num_res? Are they related to the usual definition of pyramid levels in conventional methods, where one level down means the images are downsampled by half? Could you also explain the difference between reg_level and res_level?
3. If I understand correctly from the paper, XCoRegUN does the clustering with num_classes for the appearance model, while XCoRegGT estimates the appearance model based on the input segmentations. Should I then also specify num_classes for XCoRegGT?
The code is very well structured and the variable naming is informative, which helped a lot when reading and understanding it; some explanatory comments would be a great addition.
Thanks a lot for your help and patience with my possible upcoming questions. :)
Cheers, Bailiang
Hi Bailiang,
Thank you for the questions. I will answer them point-by-point as follows:
1. transform_type can be ['FFD', 'DDF'], which means these two types of transformations are composed together for the registration. Besides, by using the method self.activate_params(levels) you can optimize them individually. Also, you are right that using transform_type=['FFD', 'SVF'] doesn't mean the velocity field is parameterized by B-splines. However, you can easily set that up with the current code.
2. num_reg_levels means the total number of transformations composed together for registration, whereas max_reg_level means the current highest registration level (higher means the registration is finer), which can be altered by the method self.activate_params(levels) to tell the model not to compose any higher-level transformations that are not yet optimized when predicting the total transformation. On the other hand, num_res is what relates to the pyramid levels you mentioned. For example, if you are using transform_type=['AFF', 'FFD'], then setting num_res=[3,3,3,1,1,1] means that for the AFF transformation you will use three resolutions across all dimensions of the images to compute the similarity measure, while for FFD you will use a single resolution. For each resolution, the original images are downsampled by a factor of 2. More specifically, when using num_res=[3,3,3] for AFF and steps=300, the images will first be downsampled by a factor of 4 for optimization during step=0:100, then by a factor of 2 during step=100:200, and finally not downsampled during step=200:300.
3. For XCoRegGT, you should set num_classes to be the same as the number of label classes in your input segmentations.
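To make that schedule concrete, here is a minimal sketch of how the downsampling factor could follow the step index under the num_res=[3,3,3], steps=300 setting above (an illustration of the behaviour described here, not the repository's actual implementation):

def downsample_factor(step, steps=300, num_res=3):
    # split the steps evenly across resolutions, coarsest first
    res_idx = min(step * num_res // steps, num_res - 1)
    return 2 ** (num_res - 1 - res_idx)

assert downsample_factor(50) == 4    # step=0:100 -> downsample by 4
assert downsample_factor(150) == 2   # step=100:200 -> downsample by 2
assert downsample_factor(250) == 1   # step=200:300 -> full resolution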
Cheers, Xinzhe
Hi @xzluo97,
Thank you for your answers. Sorry to bother you again, I still have the following questions. I am only considering the flow field transformation.
1. For num_res, it only rescales the warped images and computes the loss at the different resolutions, but it does not generate corresponding flow fields at those resolutions, is that correct? To put it another way, the full-size flow field is computed throughout the registration.
2. For num_ffd_levels and pred_ffd_spacing, it's actually similar to question 1: if I have three sets of isotropic spacing [[1],[3],[5]], the transformation would be 3 full-size FFD flow fields, and the final predict_flow would be the composition of these three?
3. What is the difference between the parameters label and mask? I can see that label is also taken as weight in SVCD. So is mask the mask of the ROI, e.g. the mask of the brain area?
4. Are images tensors of shape (B, num_subjects, H, W, D) and labels one-hot tensors of shape (B, num_subjects, num_classes, H, W, D)?
Thanks a lot for your help!
Cheers, Bailiang
I am attaching my modification of the model; I would really appreciate it if you could check it when you have time and see if something is missing due to my understanding. (It runs well, except that when I use XCoRegGT it hits a CUDA OOM, but that might be because my num_classes is too big.)
import numpy as np
import torch
import torch.nn as nn

# XCoRegUnRegModel, XCoRegGTRegModel, GMMRegModel, APERegModel, CTERegModel
# and utils are imported from the repository, as in the BrainWeb example.

class MyRegModel(nn.Module):
    def __init__(self,
                 dimension=3,
                 img_size=(160, 160, 192),
                 eps=1e-8,
                 **kwargs):
        '''
        kwargs: num_classes, num_bins, sample_rate, kernel_sigma,
                eps, momentum, min_prob
        '''
        super(MyRegModel, self).__init__()
        self.dimension = dimension
        self.img_size = img_size
        self.eps = eps
        self.kwargs = kwargs
        self.model_type = self.kwargs.pop('model_type', None)
        assert self.model_type in ['XCoRegUn', 'XCoRegGT', 'GMM', 'APE', 'CTE']
        if self.model_type == 'XCoRegUn':
            Model = XCoRegUnRegModel.XCoRegUnRegModel
        elif self.model_type == 'XCoRegGT':
            Model = XCoRegGTRegModel.XCoRegGTRegModel
        elif self.model_type == 'GMM':
            Model = GMMRegModel.GMMRegModel
        elif self.model_type == 'APE':
            Model = APERegModel.APERegModel
        elif self.model_type == 'CTE':
            Model = CTERegModel.CTERegModel
        else:
            raise NotImplementedError
        # pop the smoothing parameters before they reach the registration model
        self.mask_sigma = self.kwargs.pop('mask_sigma', -1)
        self.prior_sigma = self.kwargs.pop('prior_sigma', -1)
        self.reg = Model(self.dimension, self.img_size, eps=self.eps,
                         num_subjects=3, **self.kwargs)
        self.num_subjects = self.reg.num_subjects

    def init_model_params(self, images, label=None):
        B = images.shape[0]
        assert images.shape[1] == self.num_subjects
        self.reg.init_reg_params(images=images)
        if label is None:
            # uniform prior and all-ones mask when no segmentation is given
            prior = torch.full((B, self.reg.num_classes, *self.reg.img_size),
                               fill_value=1 / self.reg.num_classes,
                               device=images.device, dtype=images.dtype)
            mask = torch.ones(B, 1, *self.reg.img_size,
                              dtype=images.dtype, device=images.device)
        else:
            if self.mask_sigma == -1:
                mask = torch.ones(B, 1, *self.reg.img_size,
                                  dtype=images.dtype, device=images.device)
            else:
                # smooth the foreground labels to obtain a soft ROI mask;
                # _spatial_filter is assumed available as in the BrainWeb model
                mask = self._spatial_filter(
                    label[:, 1:].sum(dim=1, keepdim=True),
                    utils.gauss_kernel1d(self.mask_sigma)
                ).gt(self.eps).to(images.dtype)  # was self.images.dtype: fixed
            if self.prior_sigma == -1:
                prior = torch.full((B, self.reg.num_classes, *self.reg.img_size),
                                   fill_value=1 / self.reg.num_classes,
                                   device=images.device, dtype=images.dtype)
            else:
                prior = utils.compute_normalized_prob(
                    self._spatial_filter(label,
                                         utils.gauss_kernel1d(self.prior_sigma)),
                    dim=1)
        self.reg.mask = mask
        self.reg.prior = prior
        # guard against label=None (unbind would fail on None)
        labels = torch.unbind(label, dim=1) if label is not None else None
        if self.model_type == 'XCoRegUn':
            self.reg.init_app_params()
        if self.model_type == 'XCoRegGT':
            self.reg.init_app_params(labels=labels)
        if self.model_type == 'GMM':
            self.reg.init_gmm_params()
For the RegModel, I have just removed the part where you add noise and generate the ground-truth misalignment in the BrainWeb model.
class IterGroupRegTrainer(Trainer):
    def train(self, images, label, device='cuda:0', **kwargs):
        steps = kwargs.pop('steps', [50])
        assert len(steps) == self.net.reg.num_reg_levels
        cum_steps = np.cumsum(steps)
        if isinstance(self.lr, float):
            lr = [self.lr] * self.net.reg.num_reg_levels
        elif isinstance(self.lr, (tuple, list)):
            if len(self.lr) == 1:
                lr = list(self.lr) * self.net.reg.num_reg_levels
            else:
                assert len(self.lr) == self.net.reg.num_reg_levels
                lr = self.lr
        else:
            raise NotImplementedError
        self.writer = self._get_writer(self.save_path)
        model = self.net.to(device)  # MyRegModel
        # model.reg -> XCoRegUn/XCoRegGT/GMM/APE/CTE
        # images: (B, num_subjects, H, W, D); label: one-hot
        model.init_model_params(images, label)
        # one optimizer per registration level
        opts = [self._get_optimizer([model.reg.params[model.reg.reg_level_type[j]]], lr=lr[j])
                for j in range(model.reg.num_reg_levels)]
        for j in range(model.reg.num_reg_levels):
            print('reg_level', j)
            model.reg.activate_params([j])
            opt = opts[j]
            for step in range(0 if j == 0 else cum_steps[j - 1], cum_steps[j]):
                print('step', step)
                opt.zero_grad()
                warped_images = model.reg()
                loss = model.reg.loss_function(warped_images)
                loss.backward()
                opt.step()
        # model.reg.num_reg_levels = 1
        reg_flows = model.reg.predict_flows()
        # save the reg_flows: they map each subject to the common space
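# Hypothetical helper (not from the repository): save each predicted flow as a
# NIfTI file with nibabel. The exact layout returned by predict_flows() is an
# assumption here; adjust the indexing/transpose to match your tensors.
def save_flows(reg_flows, affine=np.eye(4)):
    import nibabel as nib
    for i, flow in enumerate(reg_flows):
        arr = flow.detach().cpu().numpy()[0]  # drop the batch dim -> (ndim, H, W, D)
        arr = np.moveaxis(arr, 0, -1)         # -> (H, W, D, ndim) for NIfTI
        nib.save(nib.Nifti1Image(arr, affine), f'flow_subject_{i}.nii.gz')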
if __name__ == '__main__':
    # configurations and data loading; save_path, logger and device
    # are assumed to be defined elsewhere
    images = ...  # (B=1, num_subjects=3, 160, 160, 192)
    labels = ...  # (B=1, num_subjects=3, num_classes=20, 160, 160, 192)
    net = MyRegModel(
        dimension=3,
        img_size=(160, 160, 192),
        num_classes=4,
        model_type='XCoRegGT',  # or 'XCoRegUn', 'GMM', 'APE', 'CTE'
        num_bins=64,
        sample_rate=0.1,
        kernel_sigma=1,
        mask_sigma=-1,
        prior_sigma=-1,
        alpha=1,
        transform_type=['SVF'],  # 'DDF', 'FFD'
        # pred_ffd_spacing=,
        # pred_ffd_iso=,
        group2ref=False,
        # zero_avg_flow=True,
        zero_avg_vec=True,
        norm_img=False,
        eps=1e-8,
    )
    trainer = IterGroupRegTrainer(net,
                                  verbose=0,
                                  save_path=save_path,
                                  optimizer_name='Adam',
                                  learning_rate=0.1,
                                  weight_decay=0.0,
                                  scheduler_name='CyclicLR',  # None, 'OneCyclicLR'
                                  max_lr=1e-4,
                                  base_lr=1e-5,
                                  logger=logger)
    trainer.train(
        images=images,
        label=labels,  # was `label`: fixed to match the variable defined above
        device=device,
        steps=[10, ],
        num_workers=8,
    )
Here is my running script. Do you have any suggestions for the parameter settings, or which parameters I should pay more attention to if I want more accurate results? I am registering 3 brain T1 MRIs (taken at different times) of the same patient.
Also, there is one tiny bug I found when running the code: since torch.unbind removes the dimension we specify, in GroupRegModel.forward() the input images to the SpatialTransformer lack one dimension for grid_sample to run properly. For example, torch.unbind gives a tuple of tensors of shape (B, H, W, D), but SpatialTransformer requires images of shape (B, 1, H, W, D). :)
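A minimal sketch of the shape issue (tensor sizes are just for illustration):

import torch
images = torch.rand(1, 3, 160, 160, 192)               # (B, num_subjects, H, W, D)
subject_images = torch.unbind(images, dim=1)           # tuple of (B, H, W, D) tensors
fixed = [img.unsqueeze(1) for img in subject_images]   # each restored to (B, 1, H, W, D)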
Thanks a lot for your help!
Cheers, Bailiang
Hi Bailiang,
Thank you for the questions. For the first 4 questions you posted, I will answer them point-by-point as follows:
- num_res only downsamples the warped images, not the transformation fields.
- The mask parameter confines the region of ROI samples used for density estimation. This is typically used for images with complex backgrounds; for brain images you may not need to set it up, because the background is much simpler.
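For instance, such a mask could be built by thresholding intensities (a rough sketch, not repository code; the threshold value is arbitrary):

import torch
images = torch.rand(1, 3, 160, 160, 192)                  # (B, num_subjects, H, W, D)
mask = (images.mean(dim=1, keepdim=True) > 0.05).float()  # (B, 1, H, W, D) foreground mask
# model.reg.mask = mask  (assignment point, as in init_model_params above)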
Cheers, Xinzhe
As for the attached code for your own data, I am glad to see it runs well. But you may not need to set the parameter scheduler_name, as it may alter the step size unexpectedly. And since you run the registration with SVF transformations, you may need to tune the parameter alpha for the velocity field regularization.
Cheers, Xinzhe
@xzluo97 Thank you so much for all the information!
> the images should be of shape (B, num_subjects, 1, H, W, D)
I see, so that's the reason why I had a problem with the SpatialTransformer.
Yes, I found that BendingEnergy on the velocity field gives staircase artifacts in the boundary regions of the resampled images, so I also added a MembraneEnergy term to GroupRegModel._get_regularization:
def _get_regularization(self):
    r = 0
    for j in self.activated_reg_levels:
        if self.reg_level_type[j] not in ['TRA', 'AFF', 'RIG']:
            # regularize non-rigid levels with both bending and membrane energy
            r += self.bending_energy(self.params[self.reg_level_type[j]]) * self.group_num
            r += self.membrane_energy(self.params[self.reg_level_type[j]]) * self.group_num
    return r
For the LR scheduler, would it help if I used a normal exponential decay, or is it fine to use a constant LR? (Or I can just try it out myself and see the results.)
Thanks again for your help. I think our conversation here will also be beneficial to other researchers who want to run the algorithms! :)
Cheers, Bailiang
Hi Bailiang,
I have only used a constant LR in my experiments and didn't try setting an LR scheduler, so I'm afraid I cannot give you more advice on this. But I think using an LR scheduler may help avoid local optima, depending on the landscape of the similarity measure w.r.t. the misalignment.
Glad to see our conversation helps you. Good luck with your experiments!
Best, Xinzhe
Hi @xzluo97,
Thank you for the great work and for kindly sharing the code.
I am now trying to run the group-wise registration (FFD or SVF) on a set of three T1 brain 3D MRIs (.nii.gz files) and save the final displacement field.
If I want to use XCoRegUN and XCoRegGT (both without segmentation), how should I start?
I am a bit lost in the code and don't know where to modify it; could you give me some hints/guidance?
Thanks a lot!