cwmok / C2FViT

This is the official PyTorch implementation of "Affine Medical Image Registration with Coarse-to-Fine Vision Transformer" (CVPR 2022), written by Tony C. W. Mok and Albert C. S. Chung.
MIT License

After training for a while, the loss becomes 0 #11

Closed. Lebesgue-zyker closed this issue 11 months ago.

Lebesgue-zyker commented 1 year ago

I used the linked data you provided to run Train_C2FViT_pairwise.py, without modifying the code. At the beginning the loss was negative, but at times it became positive, and eventually it settled at 0. Do you know what happened?

[screenshot: training loss]
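A plausible reading of that curve, assuming the loss is a negated normalized cross-correlation (NCC) between the warped moving image and the fixed image: negative values mean the two images are positively correlated, 0 means the warped image carries no usable signal (for example, if it has been moved largely out of the field of view), and positive values mean anti-correlation. The sketch below uses a simplified global NCC in NumPy to show the three cases. global_ncc_loss is a hypothetical helper, not the repository's windowed multi_resolution_NCC, whose exact range may differ:

```python
import numpy as np


def global_ncc_loss(warped, fixed, eps=1e-5):
    """Negative global normalized cross-correlation between two volumes.

    Simplified stand-in, NOT the repository's multi_resolution_NCC (which is
    windowed and multi-scale). The value lies in roughly [-1, 1]:
      about -1 -> warped is a positive linear function of fixed
      0        -> warped carries no signal (all zeros or constant)
      about +1 -> warped is anti-correlated with fixed
    """
    a = np.asarray(warped, dtype=np.float64)
    b = np.asarray(fixed, dtype=np.float64)
    a = a - a.mean()
    b = b - b.mean()
    # eps keeps the division defined when one volume has zero variance
    ncc = (a * b).sum() / (np.sqrt((a * a).sum() * (b * b).sum()) + eps)
    return -float(ncc)


rng = np.random.default_rng(0)
fixed = rng.random((32, 32, 32))
print(global_ncc_loss(fixed, fixed))                 # about -1: perfect match
print(global_ncc_loss(np.zeros_like(fixed), fixed))  # zero: empty warped volume
print(global_ncc_loss(-fixed, fixed))                # about +1: anti-correlated
```

Under this simplified definition, a loss that climbs from negative values to exactly 0 is what an increasingly empty (or constant) warped image would produce.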

cwmok commented 1 year ago

Hi @Lebesgue-zyker,

I have never seen this before. My best guesses are:

1) It is caused by the CUDA build of your PyTorch. I have encountered a similar problem when using torch==2.0.0 with the default CUDA version; training became stable when I switched to torch==2.0.0+cu117.
2) It is caused by the data. Did you point the script to the correct database you just downloaded?
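To act on the second suggestion, a quick check of the data path can help. A minimal sketch, assuming the OASIS layout the training script globs for (OASIS_OAS1_*_MR1/norm.nii.gz and OASIS_OAS1_*_MR1/seg35.nii.gz); check_oasis_dir is a hypothetical helper, not part of the repository:

```python
import glob
import os


def check_oasis_dir(datapath):
    """Count the files matched by the training script's glob patterns.

    Hypothetical helper. Returns the number of images and labels found and
    the subject folders that have one of the two files but not the other.
    """
    subject_pattern = os.path.join(datapath, "OASIS_OAS1_*_MR1")
    imgs = sorted(glob.glob(os.path.join(subject_pattern, "norm.nii.gz")))
    labels = sorted(glob.glob(os.path.join(subject_pattern, "seg35.nii.gz")))
    img_dirs = {os.path.dirname(p) for p in imgs}
    label_dirs = {os.path.dirname(p) for p in labels}
    return {
        "n_images": len(imgs),
        "n_labels": len(labels),
        # symmetric difference: folders present in only one of the two lists
        "unpaired": sorted(img_dirs ^ label_dirs),
    }


if __name__ == "__main__":
    print(check_oasis_dir("/path/to/oasis"))  # placeholder path
```

An n_images of 0 would mean datapath does not point at the extracted dataset; any entry in unpaired is a subject folder missing one of the two files.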

Lebesgue-zyker commented 1 year ago

Hi @cwmok, I was using torch==1.7.1. Today I switched to torch==2.0.0+cu117 and I still get the same result. Here are screenshots of my torch version:

[screenshots: torch version]

I also examined the data by visualizing the slice [128,:,:] of the x and y volumes produced by the data generator:

[screenshots: slice [128,:,:] of x and y]

So there is no problem with the data.
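A numeric check can complement the visual one. A minimal sketch, assuming x and y are available as NumPy arrays (for example X[0, 0].cpu().numpy() inside the training loop); volume_stats and center_of_mass_normalized are hypothetical helpers. The second one mirrors how AffineCOMTransform computes an intensity-weighted center of mass on an align_corners=True grid, which divides by the total intensity and is therefore undefined for an all-zero volume:

```python
import numpy as np


def volume_stats(vol):
    """Basic intensity statistics for one 3D volume."""
    v = np.asarray(vol, dtype=np.float64)
    return {
        "shape": v.shape,
        "min": float(v.min()),
        "max": float(v.max()),
        "nonzero_frac": np.count_nonzero(v) / v.size,
    }


def center_of_mass_normalized(vol):
    """Intensity-weighted center of mass per array axis, in [-1, 1] coordinates.

    Uses the align_corners=True convention (first voxel -> -1, last -> +1).
    Note: affine_grid's x/y/z components run over the array axes in reverse
    order (x is the last axis). Returns None for an all-zero volume, where
    the division would be 0/0.
    """
    v = np.asarray(vol, dtype=np.float64)
    total = v.sum()
    if total == 0:
        return None
    com = []
    for axis, n in enumerate(v.shape):
        coords = np.linspace(-1.0, 1.0, n)  # voxel index -> normalized coordinate
        shape = [1] * v.ndim
        shape[axis] = n
        com.append(float((v * coords.reshape(shape)).sum() / total))
    return tuple(com)
```

Comparing these numbers for x and y (same shape, similar ranges, a non-trivial fraction of non-zero voxels) gives a quick sanity check beyond a single slice.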

I printed the final affine parameters (affine_para_list[-1]), and when the loss becomes 0, their values are close to 1 or -1.

[screenshot: printed affine parameters]
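The last layer of each head in C2F_ViT_stage is nn.Tanh, so each of the 12 affine parameters lies in (-1, 1), and values stuck at ±1 mean the Tanh is saturated. Because d/dz tanh(z) = 1 - tanh(z)², gradients passing back through a saturated Tanh are close to zero, which would be consistent with the model not recovering once it reaches that state. A minimal sketch for tracking this during training; tanh_saturation is a hypothetical helper and the 0.99 threshold is an arbitrary choice:

```python
import numpy as np


def tanh_saturation(params, thresh=0.99):
    """Summarize how saturated a vector of tanh outputs is.

    params: tanh outputs, e.g. the 12 values in affine_para_list[-1].
    Returns (fraction with |p| >= thresh, mean local derivative 1 - p**2).
    A mean derivative near 0 means the gradient through the final Tanh is
    almost switched off.
    """
    p = np.asarray(params, dtype=np.float64).ravel()
    frac_saturated = float(np.mean(np.abs(p) >= thresh))
    mean_local_grad = float(np.mean(1.0 - p ** 2))
    return frac_saturated, mean_local_grad


print(tanh_saturation(np.zeros(12)))      # no saturation, full local gradient
print(tanh_saturation([1.0, -1.0] * 6))   # fully saturated, no local gradient
```

In the training loop it could be called as tanh_saturation(affine_para_list[-1].detach().cpu().numpy()) next to the existing print, so the onset of saturation can be compared with the loss curve.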

Here is the code I used (Train_C2FViT_pairwise.py):

import os
import glob
import sys
from argparse import ArgumentParser
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from C2FViT_model import C2F_ViT_stage, AffineCOMTransform, Center_of_mass_initial_pairwise, multi_resolution_NCC
from Functions import Dataset_epoch

def dice(im1, atlas):
    unique_class = np.unique(atlas)
    dice = 0
    num_count = 0
    for i in unique_class:
        if (i == 0) or ((im1 == i).sum() == 0) or ((atlas == i).sum() == 0):
            continue

        sub_dice = np.sum(atlas[im1 == i] == i) * 2.0 / (np.sum(im1 == i) + np.sum(atlas == i))
        dice += sub_dice
        num_count += 1
    return dice / num_count

def train():
    print("Training C2FViT...")
    model = C2F_ViT_stage(img_size=128, patch_size=[3, 7, 15], stride=[2, 4, 8], num_classes=12,
                          embed_dims=[256, 256, 256],
                          num_heads=[2, 2, 2], mlp_ratios=[2, 2, 2], qkv_bias=False, qk_scale=None, drop_rate=0.,
                          attn_drop_rate=0., norm_layer=nn.Identity,
                          depths=[4, 4, 4], sr_ratios=[1, 1, 1], num_stages=3, linear=False).cuda()

    # model = C2F_ViT_stage(img_size=128, patch_size=[7, 15], stride=[4, 8], num_classes=12, embed_dims=[256, 256],
    #                       num_heads=[2, 2], mlp_ratios=[2, 2], qkv_bias=False, qk_scale=None, drop_rate=0.,
    #                       attn_drop_rate=0., norm_layer=nn.Identity, depths=[4, 4], sr_ratios=[1, 1], num_stages=2,
    #                       linear=False).cuda()

    # model = C2F_ViT_stage(img_size=128, patch_size=[15], stride=[8], num_classes=12, embed_dims=[256],
    #                       num_heads=[2], mlp_ratios=[2], qkv_bias=False, qk_scale=None, drop_rate=0.,
    #                       attn_drop_rate=0., norm_layer=nn.Identity, depths=[4], sr_ratios=[1], num_stages=1,
    #                       linear=False).cuda()

    # print(sum(p.numel() for p in model.parameters() if p.requires_grad))

    affine_transform = AffineCOMTransform().cuda()
    init_center = Center_of_mass_initial_pairwise()

    loss_similarity = multi_resolution_NCC(win=7, scale=3)

    # OASIS
    imgs = sorted(glob.glob(datapath + "/OASIS_OAS1_*_MR1/norm.nii.gz"))
    labels = sorted(glob.glob(datapath + "/OASIS_OAS1_*_MR1/seg35.nii.gz"))

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model_dir = '../Model/' + model_name[0:-1]

    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    lossall = np.zeros((2, iteration + 1))

    training_generator = Data.DataLoader(Dataset_epoch(imgs, labels, norm=True, use_label=False),
                                         batch_size=1,
                                         shuffle=True, num_workers=4)
    step = 0
    load_model = False
    if load_model is True:
        model_path = "/home/z05979/image_registration/C2FViT-main/Model/C2FViT_affine_COM_pairwise_stagelvl3_118000.pth"
        print("Loading weight: ", model_path)
        step = 0
        model.load_state_dict(torch.load(model_path))
        temp_lossall = np.load("../Model/loss_LDR_LPBA_NCC_lap_share_preact_1_05_3000.npy")
        lossall[:, 0:3000] = temp_lossall[:, 0:3000]

    while step <= iteration:
        for X, Y in training_generator:

            X = X.cuda().float()
            Y = Y.cuda().float()
            # COM initialization
            if com_initial:
                X, _ = init_center(X, Y)

            X = F.interpolate(X, scale_factor=0.5, mode="trilinear", align_corners=True)
            Y = F.interpolate(Y, scale_factor=0.5, mode="trilinear", align_corners=True)

            warpped_x_list, y_list, affine_para_list = model(X, Y)
            print(affine_para_list[-1])
            # NCC on the final (finest) level only; the coarser levels are not supervised here
            loss_multiNCC = loss_similarity(warpped_x_list[-1], y_list[-1])

            loss = loss_multiNCC

            optimizer.zero_grad()  # clear gradients for this training step
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients

            lossall[:, step] = np.array(
                [loss.item(), loss_multiNCC.item()])
            sys.stdout.write(
                "\r" + 'step "{0}" -> training loss "{1:.4f}" - sim_NCC "{2:.4f}"'.format(
                    step, loss.item(), loss_multiNCC.item()))
            sys.stdout.flush()

            # with lr 1e-3 + with bias
            if (step % n_checkpoint == 0):
                modelname = model_dir + '/' + model_name + "stagelvl3_" + str(step) + '.pth'
                torch.save(model.state_dict(), modelname)
                np.save(model_dir + '/loss' + model_name + "stagelvl3_" + str(step) + '.npy', lossall)

                # Put your validation code here
                # ---------------------------------------

                # imgs = sorted(glob.glob(datapath + "/OASIS_OAS1_*_MR1/norm.nii.gz"))[255:259]
                # labels = sorted(glob.glob(datapath + "/OASIS_OAS1_*_MR1/seg35.nii.gz"))[255:259]
                #
                # valid_generator = Data.DataLoader(
                #     Dataset_epoch(imgs, labels, norm=True, use_label=True),
                #     batch_size=1,
                #     shuffle=False, num_workers=2)
                #
                # use_cuda = True
                # device = torch.device("cuda" if use_cuda else "cpu")
                # dice_total = []
                # brain_dice_total = []
                # print("\nValiding...")
                # for batch_idx, data in enumerate(valid_generator):
                #     X, Y, X_label, Y_label = data[0].to(device), data[1].to(device), data[2].to(
                #         device), data[3].to(device)
                #
                #     with torch.no_grad():
                #         if com_initial:
                #             X, init_flow = init_center(X, Y)
                #             X_label = F.grid_sample(X_label, init_flow, mode="nearest", align_corners=True)
                #
                #         X_down = F.interpolate(X, scale_factor=0.5, mode="trilinear", align_corners=True)
                #         Y_down = F.interpolate(Y, scale_factor=0.5, mode="trilinear", align_corners=True)
                #
                #         warpped_x_list, y_list, affine_para_list = model(X_down, Y_down)
                #         X_Y, affine_matrix = affine_transform(X, affine_para_list[-1])
                #         F_X_Y = F.affine_grid(affine_matrix, X_label.shape, align_corners=True)
                #
                #         X_Y_label = F.grid_sample(X_label, F_X_Y, mode="nearest", align_corners=True).cpu().numpy()[0,
                #                     0, :, :, :]
                #         X_brain_label = (X_Y > 0).float().cpu().numpy()[0, 0, :, :, :]
                #
                #         # brain mask
                #         Y_brain_label = (Y > 0).float().cpu().numpy()[0, 0, :, :, :]
                #         Y_label = Y_label.data.cpu().numpy()[0, 0, :, :, :]
                #
                #         dice_score = dice(np.floor(X_Y_label), np.floor(Y_label))
                #         dice_total.append(dice_score)
                #
                #         brain_dice = dice(np.floor(X_brain_label), np.floor(Y_brain_label))
                #         brain_dice_total.append(brain_dice)
                #
                # dice_total = np.array(dice_total)
                # brain_dice_total = np.array(brain_dice_total)
                # print("Dice mean: ", dice_total.mean())
                # print("Brain Dice mean: ", brain_dice_total.mean())
                #
                # with open(log_dir, "a") as log:
                #     log.write(f"{step}: {dice_total.mean()}, {brain_dice_total.mean()} \n")

            step += 1

            if step > iteration:
                break
        print("one epoch pass")
    np.save(model_dir + '/loss' + model_name + 'stagelvl3.npy', lossall)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--modelname", type=str,
                        dest="modelname",
                        default='C2FViT_affine_COM_pairwise_',
                        help="Model name")
    parser.add_argument("--lr", type=float,
                        dest="lr", default=1e-4, help="learning rate")
    parser.add_argument("--iteration", type=int,
                        dest="iteration", default=160001,
                        help="number of total iterations")
    parser.add_argument("--checkpoint", type=int,
                        dest="checkpoint", default=1000,
                        help="frequency of saving models")
    parser.add_argument("--datapath", type=str,
                        dest="datapath",
                        default='/home/works/ykzhang/img_registration/oasis',
                        help="data path for training images")
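    # type=bool would turn any non-empty string, including "False", into True
    parser.add_argument("--com_initial", type=lambda s: s.lower() in ("true", "1", "yes"),
                        dest="com_initial", default=True,
                        help="True: Enable Center of Mass initialization, False: Disable")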
    opt = parser.parse_args()

    lr = opt.lr
    iteration = opt.iteration
    n_checkpoint = opt.checkpoint
    datapath = opt.datapath
    com_initial = opt.com_initial

    model_name = opt.modelname

    # Create and initalize log file
    if not os.path.isdir("../Log"):
        os.mkdir("../Log")

    log_dir = "../Log/" + model_name + ".txt"

    with open(log_dir, "a") as log:
        log.write("Validation Dice log for " + model_name[0:-1] + ":\n")

    print("Training %s ..." % model_name)
    print('torch.__version__', torch.__version__)
    print('torch.version.cuda', torch.version.cuda)
    print('torch.backends.cudnn.version()', torch.backends.cudnn.version())
    train()
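For reference when reading C2FViT_model.py, the token-grid arithmetic behind the three stages can be checked by hand. The sketch below assumes the tensors passed to model(X, Y) are 128³, which matches img_size=128 after the 0.5 downsampling if the loaded volumes are 256³. image_pyramid returns the levels coarsest first (32, 64, 128), and each OverlapPatchEmbed is a Conv3d with padding = patch_size // 2. conv_out and token_grids are hypothetical helpers written for this illustration:

```python
def conv_out(n, kernel, stride):
    """Output length of one Conv3d axis with padding = kernel // 2 (as in OverlapPatchEmbed)."""
    pad = kernel // 2
    return (n + 2 * pad - kernel) // stride + 1


def token_grids(img_size=128, patch_size=(3, 7, 15), stride=(2, 4, 8)):
    """Per stage: (pyramid level size, token grid side, number of tokens)."""
    num_stages = len(patch_size)
    grids = []
    for i in range(num_stages):
        # image_pyramid halves the volume with AvgPool3d(2, 2); stage 0 is the coarsest
        level = img_size // 2 ** (num_stages - 1 - i)
        side = conv_out(level, patch_size[i], stride[i])
        grids.append((level, side, side ** 3))
    return grids


print(token_grids())  # default three-stage configuration used in the training script
```

Every stage ends up with a 16³ grid of 4096 tokens. This matters because forward adds the squeezed features of stage i to the patch embedding of stage i + 1 (xy_patch_embed + xy_fea), which only works when both have the same number of tokens; the commented-out two- and one-stage configurations in the training script also yield 16³.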

C2FViT_model.py

import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import numpy as np

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg

import math

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act1 = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.act2 = act_layer()

    def forward(self, x, H, W, D):
        x = self.fc1(x)
        x = self.act1(x)
        x = self.dwconv(x, H, W, D)
        x = self.act2(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1,
                 linear=False):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.linear = linear
        self.sr_ratio = sr_ratio
        if not linear:
            if sr_ratio > 1:
                self.sr = nn.Conv3d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                self.norm = nn.LayerNorm(dim)
        else:
            self.pool = nn.AdaptiveAvgPool3d(7)
            self.sr = nn.Conv3d(dim, dim, kernel_size=1, stride=1)
            self.norm = nn.LayerNorm(dim)
            self.act = nn.GELU()

    def forward(self, x, H, W, D):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if not self.linear:
            if self.sr_ratio > 1:
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W, D)
                x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
                x_ = self.norm(x_)
                kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            else:
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W, D)
            x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            x_ = self.act(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, H, W, D):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W, D))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W, D))

        return x

class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv3d(dim, dim, 3, 1, 1, bias=False, groups=dim)

    def forward(self, x, H, W, D):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W, D)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x

class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=128, patch_size=7, stride=4, in_chans=3, embed_dim=768, flatten=True):
        super().__init__()
        img_size = (img_size, img_size, img_size)
        patch_size = (patch_size, patch_size, patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W, self.D = img_size[0] // stride, img_size[1] // stride, img_size[2] // stride
        self.num_patches = self.H * self.W * self.D
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2, patch_size[2] // 2))
        # self.norm = nn.LayerNorm(embed_dim)
        self.flatten = flatten

        self.act = nn.GELU()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W, D = x.shape
        if self.flatten:
            # B, C, H, W, D -> B, N, C (N = H * W * D tokens)
            x = x.flatten(2).transpose(1, 2)
        # x = self.norm(x)
        x = self.act(x)

        return x, H, W, D

# From "Conditional Positional Encodings for Vision Transformers" by Chu et al., 2021
# https://github.com/Meituan-AutoML/Twins/blob/fa2f80e62794eaa55e2c1fbb41679a718ff642d9/segmentation/gvt.py
class PosCNN(nn.Module):
    def __init__(self, in_chans, embed_dim=768, s=1, k=3):
        super(PosCNN, self).__init__()
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=k, stride=s, padding=k//2, bias=False, groups=embed_dim)
        self.s = s

    def forward(self, x, H, W, D):
        B, N, C = x.shape
        feat_token = x
        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W, D)
        if self.s == 1:
            x = self.proj(cnn_feat) + cnn_feat
        else:
            x = self.proj(cnn_feat)
        x = x.flatten(2).transpose(1, 2)
        return x

class C2F_ViT_stage(nn.Module):
    def __init__(self, img_size=128, patch_size=[3, 7, 15], stride=[2, 4, 8], num_classes=12, embed_dims=[256, 256, 256],
                 num_heads=[2, 2, 2], mlp_ratios=[2, 2, 2], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., norm_layer=nn.Identity, depths=[4, 4, 4], sr_ratios=[1, 1, 1], num_stages=3, linear=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        for i in range(self.num_stages):
            patch_embed = OverlapPatchEmbed(img_size=img_size,
                                                  patch_size=patch_size[i],
                                                  stride=stride[i],
                                                  in_chans=2,
                                                  embed_dim=embed_dims[i])
            stage = nn.ModuleList([
                                         Block(dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i],
                                               qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                                               attn_drop=attn_drop_rate, drop_path=0, norm_layer=norm_layer,
                                               sr_ratio=sr_ratios[i], linear=linear) for _ in range(depths[i])])

            head = nn.Sequential(
            nn.Linear(embed_dims[i], embed_dims[i] // 2, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims[i] // 2, num_classes, bias=False),
            nn.Tanh()
            )

            setattr(self, f"patch_embed_{i + 1}_xy", patch_embed)
            setattr(self, f"stage_{i + 1}", stage)
            setattr(self, f"head_{i+1}", head)

        for i in range(self.num_stages-1):
            squeeze = nn.Conv3d(embed_dims[i], embed_dims[i + 1], kernel_size=3, stride=1, padding=1)
            setattr(self, f"squeeze_{i + 1}", squeeze)

        self.avg_pool = nn.AvgPool3d(2, 2)
        self.affine_transform = AffineCOMTransform()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv3d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def image_pyramid(self, x, level=3):
        out = [x]
        for i in range(level - 1):
            x = self.avg_pool(x)
            out.append(x)

        return out[::-1]

    def forward(self, x, y):
        B = x.shape[0]

        warpped_x_list = []
        affine_list = []

        x = self.image_pyramid(x, self.num_stages)
        y = self.image_pyramid(y, self.num_stages)

        for i in range(self.num_stages):
            if i == 0:
                xy = torch.cat([x[i], y[i]], dim=1)
            else:
                xy = torch.cat([warpped_x_list[i-1], y[i]], dim=1)

            patch_embed_xy = getattr(self, f"patch_embed_{i + 1}_xy")
            xy_patch_embed, H, W, D = patch_embed_xy(xy)

            if i > 0:
                xy_patch_embed = xy_patch_embed + xy_fea

            xy_fea = xy_patch_embed
            stage_block = getattr(self, f"stage_{i + 1}")
            for blk in stage_block:
                xy_fea = blk(xy_fea, H, W, D)

            head = getattr(self, f"head_{i + 1}")
            affine = head(xy_fea.mean(dim=1))
            affine_list.append(affine)

            if i < self.num_stages - 1:
                warpped_x, _ = self.affine_transform(x[i + 1], affine)
                warpped_x_list.append(warpped_x)

                xy_fea = xy_fea.reshape(B, H, W, D, -1).permute(0, 4, 1, 2, 3)
                squeeze = getattr(self, f"squeeze_{i + 1}")
                xy_fea = squeeze(xy_fea).flatten(2).transpose(1, 2)
            else:
                warpped_x, _ = self.affine_transform(x[i], affine)
                warpped_x_list.append(warpped_x)

        return warpped_x_list, y, affine_list

class C2F_ViT_stage_pos(nn.Module):
    def __init__(self, img_size=128, patch_size=[3, 7, 15], stride=[2, 4, 8], num_classes=12, embed_dims=[256, 256, 256],
                 num_heads=[2, 2, 2], mlp_ratios=[2, 2, 2], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., norm_layer=nn.Identity, depths=[4, 4, 4], sr_ratios=[1, 1, 1], num_stages=3, linear=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        for i in range(self.num_stages):
            patch_embed = OverlapPatchEmbed(img_size=img_size,
                                                  patch_size=patch_size[i],
                                                  stride=stride[i],
                                                  in_chans=2,
                                                  embed_dim=embed_dims[i])
            stage = nn.ModuleList([
                                         Block(dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i],
                                               qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                                               attn_drop=attn_drop_rate, drop_path=0, norm_layer=norm_layer,
                                               sr_ratio=sr_ratios[i], linear=linear) for _ in range(depths[i])])

            head = nn.Sequential(
            nn.Linear(embed_dims[i], embed_dims[i] // 2, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims[i] // 2, num_classes, bias=False),
            nn.Tanh()
            )

            i_imgsize = img_size//(2**(num_stages-1-i))
            pos_embed = nn.Parameter(torch.zeros(1, (i_imgsize//stride[i]) ** 3, embed_dims[i]))
            trunc_normal_(pos_embed, std=0.02)

            setattr(self, f"patch_embed_{i + 1}_xy", patch_embed)
            setattr(self, f"stage_{i + 1}", stage)
            setattr(self, f"head_{i + 1}", head)
            setattr(self, f"pos_embed_{i + 1}", pos_embed)

        for i in range(self.num_stages-1):
            squeeze = nn.Conv3d(embed_dims[i], embed_dims[i + 1], kernel_size=3, stride=1, padding=1)
            setattr(self, f"squeeze_{i + 1}", squeeze)

        self.avg_pool = nn.AvgPool3d(2, 2)
        self.affine_transform = AffineCOMTransform()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv3d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def image_pyramid(self, x, level=3):
        out = [x]
        for i in range(level - 1):
            x = self.avg_pool(x)
            out.append(x)

        return out[::-1]

    def forward(self, x, y):
        B = x.shape[0]

        warpped_x_list = []
        affine_list = []

        x = self.image_pyramid(x, self.num_stages)
        y = self.image_pyramid(y, self.num_stages)

        for i in range(self.num_stages):
            if i == 0:
                xy = torch.cat([x[i], y[i]], dim=1)
            else:
                xy = torch.cat([warpped_x_list[i-1], y[i]], dim=1)

            patch_embed_xy = getattr(self, f"patch_embed_{i + 1}_xy")
            xy_patch_embed, H, W, D = patch_embed_xy(xy)

            # position embedding
            pos_embed = getattr(self, f"pos_embed_{i + 1}")
            xy_patch_embed = xy_patch_embed + pos_embed

            if i > 0:
                xy_patch_embed = xy_patch_embed + xy_fea

            xy_fea = xy_patch_embed
            stage_block = getattr(self, f"stage_{i + 1}")
            for blk in stage_block:
                xy_fea = blk(xy_fea, H, W, D)

            head = getattr(self, f"head_{i+1}")
            affine = head(xy_fea.mean(dim=1))
            affine_list.append(affine)

            if i < self.num_stages - 1:
                warpped_x, _ = self.affine_transform(x[i + 1], affine)
                warpped_x_list.append(warpped_x)

                xy_fea = xy_fea.reshape(B, H, W, D, -1).permute(0, 4, 1, 2, 3)
                squeeze = getattr(self, f"squeeze_{i + 1}")
                xy_fea = squeeze(xy_fea).flatten(2).transpose(1, 2)
            else:
                warpped_x, _ = self.affine_transform(x[i], affine)
                warpped_x_list.append(warpped_x)

        return warpped_x_list, y, affine_list

class C2F_ViT_stage_peg(nn.Module):
    def __init__(self, img_size=128, patch_size=[3, 7, 15], stride=[2, 4, 8], num_classes=12, embed_dims=[256, 256, 256],
                 num_heads=[2, 2, 2], mlp_ratios=[2, 2, 2], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., norm_layer=nn.Identity, depths=[4, 4, 4], sr_ratios=[1, 1, 1], num_stages=3, linear=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        for i in range(self.num_stages):
            patch_embed = OverlapPatchEmbed(img_size=img_size,
                                                  patch_size=patch_size[i],
                                                  stride=stride[i],
                                                  in_chans=2,
                                                  embed_dim=embed_dims[i])
            stage = nn.ModuleList([
                                         Block(dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i],
                                               qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                                               attn_drop=attn_drop_rate, drop_path=0, norm_layer=norm_layer,
                                               sr_ratio=sr_ratios[i], linear=linear) for _ in range(depths[i])])

            head = nn.Sequential(
            nn.Linear(embed_dims[i], embed_dims[i] // 2, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims[i] // 2, num_classes, bias=False),
            nn.Tanh()
            )

            pos_cnn = PosCNN(embed_dims[i], embed_dims[i], k=patch_size[i])

            setattr(self, f"patch_embed_{i + 1}_xy", patch_embed)
            setattr(self, f"stage_{i + 1}", stage)
            setattr(self, f"head_{i + 1}", head)
            setattr(self, f"pos_cnn_{i + 1}", pos_cnn)

        for i in range(self.num_stages-1):
            squeeze = nn.Conv3d(embed_dims[i], embed_dims[i + 1], kernel_size=3, stride=1, padding=1)
            setattr(self, f"squeeze_{i + 1}", squeeze)

        self.avg_pool = nn.AvgPool3d(2, 2)
        self.affine_transform = AffineCOMTransform()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv3d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def image_pyramid(self, x, level=3):
        out = [x]
        for i in range(level - 1):
            x = self.avg_pool(x)
            out.append(x)

        return out[::-1]

    def forward(self, x, y):
        B = x.shape[0]

        warpped_x_list = []
        affine_list = []

        x = self.image_pyramid(x, self.num_stages)
        y = self.image_pyramid(y, self.num_stages)

        for i in range(self.num_stages):
            if i == 0:
                xy = torch.cat([x[i], y[i]], dim=1)
            else:
                xy = torch.cat([warpped_x_list[i - 1], y[i]], dim=1)

            patch_embed_xy = getattr(self, f"patch_embed_{i + 1}_xy")
            xy_patch_embed, H, W, D = patch_embed_xy(xy)

            if i > 0:
                xy_patch_embed = xy_patch_embed + xy_fea

            xy_fea = xy_patch_embed
            stage_block = getattr(self, f"stage_{i + 1}")
            for index, blk in enumerate(stage_block):
                xy_fea = blk(xy_fea, H, W, D)
                if index == 0:
                    pos_cnn = getattr(self, f"pos_cnn_{i + 1}")
                    xy_fea = pos_cnn(xy_fea, H, W, D)

            head = getattr(self, f"head_{i+1}")
            affine = head(xy_fea.mean(dim=1))
            affine_list.append(affine)

            if i < self.num_stages - 1:
                warpped_x, _ = self.affine_transform(x[i + 1], affine)
                warpped_x_list.append(warpped_x)

                xy_fea = xy_fea.reshape(B, H, W, D, -1).permute(0, 4, 1, 2, 3)
                squeeze = getattr(self, f"squeeze_{i + 1}")
                xy_fea = squeeze(xy_fea).flatten(2).transpose(1, 2)
            else:
                warpped_x, _ = self.affine_transform(x[i], affine)
                warpped_x_list.append(warpped_x)

        return warpped_x_list, y, affine_list

class AffineCOMTransform(nn.Module):
    def __init__(self, use_com=True):
        super(AffineCOMTransform, self).__init__()

        self.translation_m = None
        self.rotation_x = None
        self.rotation_y = None
        self.rotation_z = None
        self.rotation_m = None
        self.shearing_m = None
        self.scaling_m = None

        self.id = torch.zeros((1, 3, 4)).cuda()
        self.id[0, 0, 0] = 1
        self.id[0, 1, 1] = 1
        self.id[0, 2, 2] = 1

        self.use_com = use_com

    def forward(self, x, affine_para):
        # Matrix that register x to its center of mass
        id_grid = F.affine_grid(self.id, x.shape, align_corners=True)

        to_center_matrix = torch.eye(4).cuda()
        reversed_to_center_matrix = torch.eye(4).cuda()
        if self.use_com:
            x_sum = torch.sum(x)
            center_mass_x = torch.sum(x.permute(0, 2, 3, 4, 1)[..., 0] * id_grid[..., 0]) / x_sum
            center_mass_y = torch.sum(x.permute(0, 2, 3, 4, 1)[..., 0] * id_grid[..., 1]) / x_sum
            center_mass_z = torch.sum(x.permute(0, 2, 3, 4, 1)[..., 0] * id_grid[..., 2]) / x_sum

            to_center_matrix[0, 3] = center_mass_x
            to_center_matrix[1, 3] = center_mass_y
            to_center_matrix[2, 3] = center_mass_z
            reversed_to_center_matrix[0, 3] = -center_mass_x
            reversed_to_center_matrix[1, 3] = -center_mass_y
            reversed_to_center_matrix[2, 3] = -center_mass_z

        self.translation_m = torch.eye(4).cuda()
        self.rotation_x = torch.eye(4).cuda()
        self.rotation_y = torch.eye(4).cuda()
        self.rotation_z = torch.eye(4).cuda()
        self.rotation_m = torch.eye(4).cuda()
        self.shearing_m = torch.eye(4).cuda()
        self.scaling_m = torch.eye(4).cuda()

        trans_xyz = affine_para[0, 0:3]
        rotate_xyz = affine_para[0, 3:6] * math.pi
        shearing_xyz = affine_para[0, 6:9] * math.pi
        scaling_xyz = 1 + (affine_para[0, 9:12] * 0.5)

        self.translation_m[0, 3] = trans_xyz[0]
        self.translation_m[1, 3] = trans_xyz[1]
        self.translation_m[2, 3] = trans_xyz[2]
        self.scaling_m[0, 0] = scaling_xyz[0]
        self.scaling_m[1, 1] = scaling_xyz[1]
        self.scaling_m[2, 2] = scaling_xyz[2]

        self.rotation_x[1, 1] = torch.cos(rotate_xyz[0])
        self.rotation_x[1, 2] = -torch.sin(rotate_xyz[0])
        self.rotation_x[2, 1] = torch.sin(rotate_xyz[0])
        self.rotation_x[2, 2] = torch.cos(rotate_xyz[0])

        self.rotation_y[0, 0] = torch.cos(rotate_xyz[1])
        self.rotation_y[0, 2] = torch.sin(rotate_xyz[1])
        self.rotation_y[2, 0] = -torch.sin(rotate_xyz[1])
        self.rotation_y[2, 2] = torch.cos(rotate_xyz[1])

        self.rotation_z[0, 0] = torch.cos(rotate_xyz[2])
        self.rotation_z[0, 1] = -torch.sin(rotate_xyz[2])
        self.rotation_z[1, 0] = torch.sin(rotate_xyz[2])
        self.rotation_z[1, 1] = torch.cos(rotate_xyz[2])

        self.rotation_m = torch.mm(torch.mm(self.rotation_z, self.rotation_y), self.rotation_x)

        self.shearing_m[0, 1] = shearing_xyz[0]
        self.shearing_m[0, 2] = shearing_xyz[1]
        self.shearing_m[1, 2] = shearing_xyz[2]

        output_affine_m = torch.mm(
            to_center_matrix,
            torch.mm(self.shearing_m,
                     torch.mm(self.scaling_m,
                              torch.mm(self.rotation_m,
                                       torch.mm(reversed_to_center_matrix, self.translation_m)))))
        grid = F.affine_grid(output_affine_m[0:3].unsqueeze(0), x.shape, align_corners=True)
        transformed_x = F.grid_sample(x, grid, mode='bilinear', align_corners=True)

        return transformed_x, output_affine_m[0:3].unsqueeze(0)

class DirectAffineTransform(nn.Module):
    def __init__(self):
        super(DirectAffineTransform, self).__init__()

        self.id = torch.zeros((1, 3, 4)).cuda()
        self.id[0, 0, 0] = 1
        self.id[0, 1, 1] = 1
        self.id[0, 2, 2] = 1

    def forward(self, x, affine_para):
        affine_matrix = affine_para.reshape(1, 3, 4) + self.id

        grid = F.affine_grid(affine_matrix, x.shape, align_corners=True)
        transformed_x = F.grid_sample(x, grid, mode='bilinear', align_corners=True)

        return transformed_x, affine_matrix

class Center_of_mass_initial_pairwise(nn.Module):
    def __init__(self):
        super(Center_of_mass_initial_pairwise, self).__init__()
        self.id = torch.zeros((1, 3, 4)).cuda()
        self.id[0, 0, 0] = 1
        self.id[0, 1, 1] = 1
        self.id[0, 2, 2] = 1

        self.to_center_matrix = torch.zeros((1, 3, 4)).cuda()
        self.to_center_matrix[0, 0, 0] = 1
        self.to_center_matrix[0, 1, 1] = 1
        self.to_center_matrix[0, 2, 2] = 1

    def forward(self, x, y):
        # center of mass of x -> center of mass of y
        id_grid = F.affine_grid(self.id, x.shape, align_corners=True)
        # mask = (x > 0).float()
        # mask_sum = torch.sum(mask)
        x_sum = torch.sum(x)
        x_center_mass_x = torch.sum(x.permute(0, 2, 3, 4, 1)[..., 0] * id_grid[..., 0])/x_sum
        x_center_mass_y = torch.sum(x.permute(0, 2, 3, 4, 1)[..., 0] * id_grid[..., 1])/x_sum
        x_center_mass_z = torch.sum(x.permute(0, 2, 3, 4, 1)[..., 0] * id_grid[..., 2])/x_sum

        y_sum = torch.sum(y)
        y_center_mass_x = torch.sum(y.permute(0, 2, 3, 4, 1)[..., 0] * id_grid[..., 0]) / y_sum
        y_center_mass_y = torch.sum(y.permute(0, 2, 3, 4, 1)[..., 0] * id_grid[..., 1]) / y_sum
        y_center_mass_z = torch.sum(y.permute(0, 2, 3, 4, 1)[..., 0] * id_grid[..., 2]) / y_sum

        self.to_center_matrix[0, 0, 3] = x_center_mass_x - y_center_mass_x
        self.to_center_matrix[0, 1, 3] = x_center_mass_y - y_center_mass_y
        self.to_center_matrix[0, 2, 3] = x_center_mass_z - y_center_mass_z

        grid = F.affine_grid(self.to_center_matrix, x.shape, align_corners=True)
        transformed_image = F.grid_sample(x, grid, align_corners=True)

        # print(affine_para)
        # print(output_affine_m[0:3])

        return transformed_image, grid

class NCC(torch.nn.Module):
    """
    local (over window) normalized cross correlation
    """
    def __init__(self, win=7, eps=1e-5):
        super(NCC, self).__init__()
        self.win = win
        self.eps = eps
        self.w_temp = win

    def forward(self, I, J):
        ndims = 3
        win_size = self.w_temp

        # set window size
        if self.win is None:
            self.win = [5] * ndims
        else:
            self.win = [self.w_temp] * ndims

        weight_win_size = self.w_temp
        weight = torch.ones((1, 1, weight_win_size, weight_win_size, weight_win_size), device=I.device, requires_grad=False)
        conv_fn = F.conv3d

        # compute CC squares
        I2 = I*I
        J2 = J*J
        IJ = I*J

        # compute filters
        # compute local sums via convolution
        I_sum = conv_fn(I, weight, padding=int(win_size/2))
        J_sum = conv_fn(J, weight, padding=int(win_size/2))
        I2_sum = conv_fn(I2, weight, padding=int(win_size/2))
        J2_sum = conv_fn(J2, weight, padding=int(win_size/2))
        IJ_sum = conv_fn(IJ, weight, padding=int(win_size/2))

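        # compute cross correlation
        # win_size = np.prod(self.win)
        # (with the line above commented out, win_size keeps the value self.w_temp,
        # i.e. the window width such as 7, instead of the window volume 7**3 = 343)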
        u_I = I_sum/win_size
        u_J = J_sum/win_size

        cross = IJ_sum - u_J*I_sum - u_I*J_sum + u_I*u_J*win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I*u_I*win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J*u_J*win_size

        cc = cross * cross / (I_var * J_var + self.eps)

        # return negative cc.
        return -1.0 * torch.mean(cc)

class multi_resolution_NCC(torch.nn.Module):
    """
    local (over window) normalized cross correlation
    """
    def __init__(self, win=None, eps=1e-5, scale=3, kernel=3):
        super(multi_resolution_NCC, self).__init__()
        self.num_scale = scale
        self.kernel = kernel
        self.similarity_metric = []

        for i in range(scale):
            self.similarity_metric.append(NCC(win=win - (i*2)))
            # self.similarity_metric.append(Normalized_Gradient_Field(eps=0.01))

    def forward(self, I, J):
        total_NCC = []
        for i in range(self.num_scale):
            current_NCC = self.similarity_metric[i](I, J)
            total_NCC.append(current_NCC/(2**i))
            # print(scale_I.size(), scale_J.size())

            I = nn.functional.avg_pool3d(I, kernel_size=self.kernel, stride=2, padding=self.kernel//2, count_include_pad=False)
            J = nn.functional.avg_pool3d(J, kernel_size=self.kernel, stride=2, padding=self.kernel//2, count_include_pad=False)

        return sum(total_NCC)

class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()
        self.eps = 1e-5

    def forward(self, input, target):
        N = target.size(0)

        input_flat = input.view(N, -1)
        target_flat = target.view(N, -1)

        intersection = input_flat * target_flat

        loss = 2. * (intersection.sum(1) + self.eps) / (input_flat.sum(1) + target_flat.sum(1) + self.eps)
        loss = 1. - loss.sum() / N

        return loss

class MulticlassDiceLossVectorize(nn.Module):
    """
    Requires a one-hot encoded target. Computes the Dice loss for every class
    at once (vectorized) and averages it over classes and batch.
    Requires input.shape[0:2] and target.shape[0:2] to be (N, C), where N is
      batch size and C is number of classes
    """
    def __init__(self):
        super(MulticlassDiceLossVectorize, self).__init__()
        self.Diceloss = DiceLoss()
        self.eps = 1e-5

    def forward(self, input, target):
        N, C, H, W, D = input.shape
        input_flat = input.view(N, C, -1)
        target_flat = target.view(N, C, -1)

        intersection = input_flat * target_flat
        loss = 2. * (torch.sum(intersection, dim=-1) + self.eps) / (torch.sum(input_flat, dim=-1) + torch.sum(target_flat, dim=-1) + self.eps)
        loss = 1. - torch.mean(loss, dim=-1)

        return torch.mean(loss)

Functions.py

import itertools

import nibabel as nib
import numpy as np
import torch
import torch.utils.data as Data
import csv
import torch.nn.functional as F
from PIL import Image

def load_4D(name):
    # X = sitk.GetArrayFromImage(sitk.ReadImage(name, sitk.sitkFloat32 ))
    # X = np.reshape(X, (1,)+ X.shape)
    X = nib.load(name)
    X = X.get_fdata()
    X = np.reshape(X, (1,) + X.shape)
    return X

def load_4D_channel(name):
    X = nib.load(name)
    X = X.get_fdata()
    X = np.transpose(X, (3, 0, 1, 2))
    return X

def min_max_norm(img):
    img_max = np.max(img)
    img_min = np.min(img)

    norm_img = (img - img_min) / (img_max - img_min)

    return norm_img

def save_img(I_img, savename, header=None, affine=None):
    if header is None or affine is None:
        affine = np.diag([1, 1, 1, 1])
        new_img = nib.nifti1.Nifti1Image(I_img, affine, header=None)
    else:
        new_img = nib.nifti1.Nifti1Image(I_img, affine, header=header)

    nib.save(new_img, savename)

def save_flow(I_img, savename, header=None, affine=None):
    if header is None or affine is None:
        affine = np.diag([1, 1, 1, 1])
        new_img = nib.nifti1.Nifti1Image(I_img, affine, header=None)
    else:
        new_img = nib.nifti1.Nifti1Image(I_img, affine, header=header)

    nib.save(new_img, savename)

class Dataset_epoch(Data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, names, labels, norm=True, use_label=False):
        'Initialization'
        self.names = names
        self.labels = labels
        self.norm = norm
        self.index_pair = list(itertools.permutations(names, 2))
        self.index_pair_label = list(itertools.permutations(labels, 2))
        self.use_label = use_label

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.index_pair)

    def __getitem__(self, step):
        'Generates one sample of data'
        # Select sample
        img_A = load_4D(self.index_pair[step][0])
        img_B = load_4D(self.index_pair[step][1])

        img_A_label = load_4D(self.index_pair_label[step][0])
        img_B_label = load_4D(self.index_pair_label[step][1])

        if self.norm:
            img_A = min_max_norm(img_A)
            img_B = min_max_norm(img_B)
        if self.use_label:
            return torch.from_numpy(img_A).float(), torch.from_numpy(img_B).float(), torch.from_numpy(img_A_label).float(), torch.from_numpy(img_B_label).float()
        else:
            return torch.from_numpy(img_A).float(), torch.from_numpy(img_B).float()

class Dataset_epoch_MNI152(Data.Dataset):
    'Characterizes a dataset for PyTorch'

    def __init__(self, img_list, label_list, fixed_img, fixed_label, need_label=True):
        'Initialization'
        super(Dataset_epoch_MNI152, self).__init__()
        # self.exp_path = exp_path
        self.img_pair = img_list
        self.label_pair = label_list
        self.need_label = need_label
        self.fixed_img = fixed_img
        self.fixed_label = fixed_label

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.img_pair)

    def __getitem__(self, step):
        'Generates one sample of data'
        # Select sample
        moving_img = load_4D(self.img_pair[step])
        fixed_img = load_4D(self.fixed_img)
        fixed_img = np.clip(fixed_img, a_min=2500, a_max=np.max(fixed_img))

        if self.need_label:
            moving_label = load_4D(self.label_pair[step])
            fixed_label = load_4D(self.fixed_label)
            return torch.from_numpy(min_max_norm(moving_img)).float(), torch.from_numpy(
                min_max_norm(fixed_img)).float(), torch.from_numpy(moving_label).float(), torch.from_numpy(fixed_label).float()
        else:
            return torch.from_numpy(min_max_norm(moving_img)).float(), torch.from_numpy(
                min_max_norm(fixed_img)).float()

class Dataset_epoch_MNI152_pre_one_hot(Data.Dataset):
    'Characterizes a dataset for PyTorch'

    def __init__(self, img_list, label_list, fixed_img, fixed_label, need_label=True):
        'Initialization'
        super(Dataset_epoch_MNI152_pre_one_hot, self).__init__()
        # self.exp_path = exp_path
        self.img_pair = img_list
        self.label_pair = label_list
        self.need_label = need_label
        self.fixed_img = fixed_img
        self.fixed_label = fixed_label

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.img_pair)

    def __getitem__(self, step):
        'Generates one sample of data'
        # Select sample
        moving_img = load_4D(self.img_pair[step])
        fixed_img = load_4D(self.fixed_img)
        fixed_img = np.clip(fixed_img, a_min=2500, a_max=np.max(fixed_img))

        if self.need_label:
            moving_label = load_4D_channel(self.label_pair[step])
            fixed_label = load_4D_channel(self.fixed_label)

            return torch.from_numpy(min_max_norm(moving_img)).float(), torch.from_numpy(
                min_max_norm(fixed_img)).float(), torch.from_numpy(moving_label).float(), torch.from_numpy(fixed_label).float()
        else:
            return torch.from_numpy(min_max_norm(moving_img)).float(), torch.from_numpy(
                min_max_norm(fixed_img)).float()

class Dataset_epoch_onehot(Data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, names, labels, norm=True, use_label=False):
        'Initialization'
        self.names = names
        self.labels = labels
        self.norm = norm
        self.index_pair = list(itertools.permutations(names, 2))
        self.index_pair_label = list(itertools.permutations(labels, 2))
        self.use_label = use_label

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.index_pair)

    def __getitem__(self, step):
        'Generates one sample of data'
        # Select sample
        img_A = load_4D(self.index_pair[step][0])
        img_B = load_4D(self.index_pair[step][1])

        img_A_label = load_4D_channel(self.index_pair_label[step][0])
        img_B_label = load_4D_channel(self.index_pair_label[step][1])

        if self.norm:
            img_A = min_max_norm(img_A)
            img_B = min_max_norm(img_B)

        if self.use_label:
            return torch.from_numpy(img_A).float(), torch.from_numpy(img_B).float(), torch.from_numpy(img_A_label).float(), torch.from_numpy(img_B_label).float()
        else:
            return torch.from_numpy(img_A).float(), torch.from_numpy(img_B).float()

Is there something wrong with how I am running it? Looking forward to your reply.

cwmok commented 1 year ago

The output affine matrix does not look normal. Could you try replacing the similarity function loss_similarity = multi_resolution_NCC(win=7, scale=3) with loss_similarity = NCC(win=5)?

I will try the above code later to see if I can reproduce the problem. I am currently attending the MICCAI conference.
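For reference, the scale schedule of the multi_resolution_NCC class pasted above can be listed directly, which shows what the suggested swap changes. This is a plain-Python sketch; the helpers ncc_schedule and loss_lower_bound are hypothetical, and the bound assumes each single-scale NCC term lies in [-1, 0], which holds when the local mean is computed over the full win³ window.

```python
def ncc_schedule(win=7, scale=3):
    """Window size and weight of each scale in multi_resolution_NCC.

    Mirrors NCC(win=win - (i*2)) and total_NCC.append(current_NCC/(2**i));
    between scales both images are average-pooled with stride 2.
    """
    return [(win - 2 * i, 1 / 2 ** i) for i in range(scale)]


def loss_lower_bound(schedule):
    # Assuming every single-scale term lies in [-1, 0], the weighted sum
    # cannot drop below minus the sum of the weights.
    return -sum(weight for _, weight in schedule)


multi = ncc_schedule(win=7, scale=3)   # multi_resolution_NCC(win=7, scale=3)
single = ncc_schedule(win=5, scale=1)  # equivalent to a single NCC(win=5)
print(multi, loss_lower_bound(multi))
print(single, loss_lower_bound(single))
```

Under that assumption, a healthy multi-resolution loss stays between -1.75 and 0, and a single NCC(win=5) loss between -1 and 0.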

Lebesgue-zyker commented 1 year ago

It looks like the problem is with the loss function: after replacing it with MAE, the model trains correctly. But why does the model converge so slowly? Is it a problem with MSE? I have trained for 80,000 iterations and still have not reached the results the pretrained model achieves at 18,000 iterations. I'll try your solution today. Thank you for your reply, and good luck with your conference!
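A minimal NumPy sketch of the two kinds of similarity term being compared here, assuming the warped moving image and the fixed image are arrays of the same shape. mae_loss and global_ncc are illustrative helpers, not functions from the repository, and global_ncc is a whole-volume correlation rather than the windowed NCC used in training. It shows one design difference: NCC is unchanged by a positive linear rescaling of intensities, whereas MAE penalizes it.

```python
import numpy as np


def mae_loss(warped, fixed):
    # Mean absolute intensity difference; 0 only when intensities match exactly.
    return float(np.mean(np.abs(warped - fixed)))


def global_ncc(warped, fixed, eps=1e-8):
    # Pearson correlation over the whole volume; close to 1 for any positive
    # linear intensity relation, so it ignores brightness/contrast changes.
    a = warped - warped.mean()
    b = fixed - fixed.mean()
    return float(np.sum(a * b) / (np.sqrt(np.sum(a * a) * np.sum(b * b)) + eps))


rng = np.random.default_rng(0)
fixed = rng.random((32, 32, 32))
rescaled = 2.0 * fixed + 0.1  # same structure, different intensity mapping

print(mae_loss(fixed, fixed), mae_loss(rescaled, fixed))
print(global_ncc(rescaled, fixed))
```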

cwmok commented 1 year ago

Hi @Lebesgue-zyker,

My model works fine with the NCC loss. However, I noticed that some users have also reported instability with the NCC loss, as mentioned in https://github.com/Project-MONAI/MONAI/discussions/3463.

In my own experiments, I also found that NCC converges faster and to better results than MAE/MSE loss.

Perhaps you could try another implementation of NCC, such as the one in https://github.com/Project-MONAI/MONAI/blob/8a70678baa976a01274f55837e339f2c2975dec7/monai/losses/image_dissimilarity.py#L51 or the one in https://github.com/uncbiag/ICON/blob/66ad08456d6676310b61a1120f21c4626fef1604/src/icon_registration/losses.py#L578. Both work fine in my environment.
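For comparison with the linked implementations, here is a minimal NumPy sketch of a windowed (local) NCC loss. It is an illustrative re-implementation, not code from MONAI or ICON: box_sum is a hypothetical helper that reproduces zero-padded conv3d window sums with cumulative sums, the local means are taken over the full win³ window, and the variances are clamped to a small floor (an assumption of this sketch) so the denominator can never be zero or negative.

```python
import numpy as np


def box_sum(x, win):
    """Sum over the win**3 cube around each voxel (odd win), with zero padding
    of win // 2 on every side, like F.conv3d with a ones kernel."""
    r = win // 2
    out = np.pad(x.astype(np.float64), r)
    for axis in range(3):
        c = np.cumsum(out, axis=axis)
        zero = np.zeros_like(np.take(c, [0], axis=axis))
        c = np.concatenate([zero, c], axis=axis)
        n = out.shape[axis] - win + 1
        # window sum along this axis: c[i + win] - c[i]
        out = np.take(c, np.arange(win, win + n), axis=axis) - np.take(c, np.arange(n), axis=axis)
    return out


def local_ncc_loss(I, J, win=7, var_floor=1e-5):
    n = win ** 3  # number of voxels in each window
    I_sum, J_sum = box_sum(I, win), box_sum(J, win)
    I2_sum, J2_sum, IJ_sum = box_sum(I * I, win), box_sum(J * J, win), box_sum(I * J, win)
    u_I, u_J = I_sum / n, J_sum / n  # local means over the whole window
    cross = IJ_sum - u_J * I_sum  # sum of (I - u_I) * (J - u_J) in the window
    I_var = np.maximum(I2_sum - u_I * I_sum, var_floor)
    J_var = np.maximum(J2_sum - u_J * J_sum, var_floor)
    cc = cross * cross / (I_var * J_var)  # squared local correlation in [0, 1]
    return -float(cc.mean())


rng = np.random.default_rng(0)
I = rng.random((16, 16, 16))
print(local_ncc_loss(I, I), local_ncc_loss(I, 3 * I), local_ncc_loss(np.zeros_like(I), I))
```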

wydilearn commented 12 months ago

Hi @cwmok, I have tried all the solutions you provided, but I still can't solve the problem. Could you provide any further suggestions or code? Thank you for your project; I'm looking forward to your reply. Screenshot 2023-11-02 194917

cwmok commented 12 months ago

Hi @wydilearn and @Lebesgue-zyker,

I finally figured out the problem. It was caused by a mistake I made when updating the NCC loss a few months ago: I accidentally commented out part of the code, as shown below: image

I have updated the code. The loss function should work from now on. Sorry for the mistake, and thanks for reporting this issue. image
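One way to see why the commented-out line matters: in the pasted NCC.forward, win_size is first set to self.w_temp (the window width, e.g. 7), so without win_size = np.prod(self.win) the local means are window sums divided by 7 rather than by 7³ = 343. The NumPy sketch below re-implements the pasted formulas with that divisor as a parameter (box_sum, ncc_terms and ncc_loss are hypothetical helpers). With the wrong divisor the "variance" terms go negative, so cc is no longer a bounded squared correlation, which is consistent with the positive loss values reported in this issue. A blank warped image gives a loss of exactly 0 under either version; one plausible reading of the loss settling at 0, not confirmed in the thread, is that the moving image was pushed out of the field of view.

```python
import numpy as np


def box_sum(x, win):
    """Zero-padded sum over the win**3 cube around each voxel (odd win),
    equivalent to F.conv3d with a ones kernel and padding=win // 2."""
    r = win // 2
    out = np.pad(x.astype(np.float64), r)
    for axis in range(3):
        c = np.cumsum(out, axis=axis)
        c = np.concatenate([np.zeros_like(np.take(c, [0], axis=axis)), c], axis=axis)
        n = out.shape[axis] - win + 1
        out = np.take(c, np.arange(win, win + n), axis=axis) - np.take(c, np.arange(n), axis=axis)
    return out


def ncc_terms(I, J, win, divisor):
    """cross, I_var, J_var as written in the pasted NCC.forward,
    with win_size replaced by `divisor`."""
    I_sum, J_sum = box_sum(I, win), box_sum(J, win)
    I2_sum, J2_sum, IJ_sum = box_sum(I * I, win), box_sum(J * J, win), box_sum(I * J, win)
    u_I, u_J = I_sum / divisor, J_sum / divisor
    cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * divisor
    I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * divisor
    J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * divisor
    return cross, I_var, J_var


def ncc_loss(I, J, win=7, eps=1e-5, buggy=False):
    # buggy=True: win_size stays equal to the window width (line commented out)
    # buggy=False: win_size = win ** 3, i.e. np.prod([win] * 3)
    divisor = win if buggy else win ** 3
    cross, I_var, J_var = ncc_terms(I, J, win, divisor)
    cc = cross * cross / (I_var * J_var + eps)
    return -float(cc.mean())


rng = np.random.default_rng(0)
I, J = rng.random((16, 16, 16)), rng.random((16, 16, 16))
for divisor in (7, 343):
    _, I_var, _ = ncc_terms(I, J, 7, divisor)
    print(divisor, I_var.min())  # negative only for divisor=7
blank = np.zeros_like(I)  # e.g. a moving image warped entirely out of view
print(ncc_loss(blank, J), ncc_loss(blank, J, buggy=True))
```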

Lebesgue-zyker commented 11 months ago

Hi @cwmok, the model now trains normally, thank you! I would also like to ask why the NCC is negated. When the NCC loss is negative, it looks as if it has not converged well.

cwmok commented 11 months ago

NCC measures the similarity between two images. Therefore, we maximize the similarity (i.e. minimize the dissimilarity) between the two images by minimizing -1 * NCC or 1 - NCC.
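A toy illustration of that sign convention, using NumPy's np.corrcoef as a stand-in for NCC (a global Pearson correlation, not the windowed NCC in the repository) and a 1-D integer shift as a stand-in for the affine transform. The shift that minimizes -NCC, or equivalently 1 - NCC, is the one that aligns the signals. The two losses differ only by a constant, so they share the same minimizer; -NCC bottoms out at -1 and 1 - NCC at 0, which is why a -NCC training loss is expected to stay negative.

```python
import numpy as np


def neg_ncc(a, b):
    # -1 * NCC: minimizing it maximizes the correlation between a and b
    return -np.corrcoef(a, b)[0, 1]


def one_minus_ncc(a, b):
    # 1 - NCC: same minimizer, shifted so that a perfect match gives 0
    return 1.0 - np.corrcoef(a, b)[0, 1]


rng = np.random.default_rng(0)
fixed = rng.random(200)
true_shift = 5
moving = np.roll(fixed, -true_shift)  # fixed, shifted by 5 samples

shifts = range(-10, 11)
best_neg = min(shifts, key=lambda s: neg_ncc(np.roll(moving, s), fixed))
best_one = min(shifts, key=lambda s: one_minus_ncc(np.roll(moving, s), fixed))
print(best_neg, best_one, neg_ncc(np.roll(moving, best_neg), fixed))
```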