Project-MONAI / tutorials

MONAI Tutorials
https://monai.io/started.html
Apache License 2.0
1.85k stars 683 forks source link

The validation Dice score during training differs from the one obtained at prediction time #897

Closed BelieferQAQ closed 2 years ago

BelieferQAQ commented 2 years ago

Please use MONAI's Discussions tab. For questions relating to MONAI usage, please do not create an issue.

Instead, use MONAI's GitHub Discussions tab. This can be found on the MONAI repository's main page (not the tutorials) next to Issues and Pull Requests along the top.

Hello, I am training SwinUNETR on the RibSeg (rib segmentation) dataset. During training, the Dice score on the validation set reaches 0.93, but when I run prediction on the same validation set it is only about 0.7. Why do the two differ? Below are my training program, model loader and prediction program.

1.def validation(val_loader,epoch,max_epochs): model.eval() start_time = time.time() val_loss = 0 with torch.no_grad(): for step, batch in enumerate(val_loader): val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())

        val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)

        val_loss_batch = loss_function(val_outputs,val_inputs)
        val_loss += val_loss_batch.item()

        val_labels_list = decollate_batch(val_labels)
        val_labels_convert = [
            post_label(val_label_tensor) for val_label_tensor in val_labels_list
        ]
        val_outputs_list = decollate_batch(val_outputs)
        val_output_convert = [
            post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list
        ]

        dice_metric(y_pred=val_output_convert, y=val_labels_convert)

    val_loss = val_loss/(step+1)
    mean_dice_val = dice_metric.aggregate().item()
    dice_metric.reset()
    print("Val {}/{} {}/{}".format(epoch,max_epochs, step, len(val_loader)),
          "loss: {:.4f}".format(val_loss),
          "dice",mean_dice_val,
          "time {:.2f}s".format(time.time() - start_time)
         )
return val_loss

def train(train_loader,scheduler, epoch, max_epochs): model.train() epoch_loss = 0 epoch_sum_loss = 0 start_time = time.time() for step, batch in enumerate(train_loader): x, y = (batch["image"].cuda(), batch["label"].cuda()) logit_map = model(x) loss = loss_function(logit_map, y) loss.backward() epoch_sum_loss += loss.item() optimizer.step() optimizer.zero_grad() epoch_loss = epoch_sum_loss/(step+1) lr = optimizer.param_groups[0]['lr'] scheduler.step() print( "LR: {:.7f}".format(lr), "Epoch {}/{} {}/{}".format(epoch, max_epochs, step, len(train_loader)), "loss: {:.4f}".format(epoch_loss), "time {:.2f}s".format(time.time() - start_time)) return epoch_loss

post_label = AsDiscrete(to_onehot=2) post_pred = AsDiscrete(argmax=True, to_onehot=2) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) loss_val_max = 100 epoch_loss_values = [] metric_values = [] warmup_epochs = 5 max_epochs = 300 start_epoch = 0 scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=warmup_epochs, max_epochs=max_epochs) for epoch in range(start_epoch, max_epochs): epoch_loss = train(train_loader,scheduler, epoch, max_epochs) epoch_loss_values.append(epoch_loss) if (epoch+1) % 5 == 0 or epoch == 0: val_loss = validation(val_loader,epoch, max_epochs) metric_values.append(val_loss) if val_loss < loss_val_max: loss_val_max = val_loss torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth")) print( "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format( loss_val_max, val_loss ) ) else: print( "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format( loss_val_max, val_loss ) )

image

2.model = SwinUNETR(img_size=96, in_channels=1, out_channels=2, feature_size=48, drop_rate=0.0, attn_drop_rate=0.0, dropout_path_rate=0.0, use_checkpoint=True, )

pretrained_pth = "/data/jupyter/wd/swinUnetr_RibFrac/checkpoints/best_metric_model.pth"

model.load_state_dict(torch.load(path))

model_dict = torch.load(pretrained_pth)["state_dict"]

model.load_state_dict(model_dict)

weights = torch.load(pretrained_pth) weights_dict = {} for k, v in weights.items(): new_k = k.replace('module.', '') if 'module' in k else k weights_dict[new_k] = v model.load_state_dict(weights_dict)

model = torch.nn.DataParallel(model) model = model.cuda() model.eval()

print("Using SwinUNETR weights !")

3.with torch.no_grad(): dice_list_case = [] for i, batch in enumerate(test_loader): val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda()) original_affine = batch['label_metadict']['affine'][0].numpy() , _, h, w, d = val_labels.shape target_shape = (h, w, d) print("The shape:{}".format(target_shape)) img_name = batch['image_meta_dict']['filename_or_obj'][0].split('/')[-1] print("Inference on case {}".format(img_name)) val_outputs = sliding_window_inference(val_inputs, (96,96,96), 4, model, overlap=0.5, mode="gaussian" ) val_outputs = torch.softmax(val_outputs, 1).cpu().numpy() val_outputs = np.argmax(val_outputs, axis=1).astype(np.uint8)[0] val_labels = val_labels.cpu().numpy()[0, 0, :, :, :] val_outputs = resample_3d(val_outputs, target_shape)

    organ_Dice = dice(val_outputs == 1, val_labels == 1)
    dice_list_case.append(organ_Dice)

    print("the dice:{}".format(organ_Dice))

    nib.save(nib.Nifti1Image(val_outputs.astype(np.uint8), original_affine),
             os.path.join(output_directory, img_name))
print("The nuber of case: {}".format(len(dice_list_case)))
print("Overall Mean Dice: {}".format(np.mean(dice_list_case)))

image

KumoLiu commented 2 years ago

Hi @BelieferQAQ, in theory, validating and testing on the validation dataset should yield the same result, but I'm not sure whether you treated them the same way, for example whether both use sliding window inference and the same post-processing. The code you posted is not well formatted, which makes it difficult for me to extract useful information from it. Could you re-post it, please? Thanks!

BelieferQAQ commented 2 years ago
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    ToTensord,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR
from optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

import time
import torch

print_config()

# Directory that receives the best checkpoint written by the training loop.
root_dir = './checkpoints'
# exist_ok=True replaces the exists()/makedirs() pair, which could fail if the
# directory was created between the check and the call.
os.makedirs(root_dir, exist_ok=True)
print(root_dir)

# Number of patches sampled per volume; passed to RandCropByPosNegLabeld below.
num_samples = 2

# Training pipeline: deterministic preprocessing first, then random patch
# sampling and augmentation.
train_transforms = Compose(
    [
        # --- deterministic preprocessing (same steps as val_transforms) ---
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Image is interpolated bilinearly, label with nearest neighbour.
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        # Map intensities from [-200, 1000] to [0, 1], clipping values outside.
        ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=1000, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # --- random patch sampling: num_samples 96^3 patches per volume ---
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="image",
            image_threshold=0,
        ),
        # --- augmentation: one independent flip per spatial axis ---
        *[
            RandFlipd(keys=["image", "label"], spatial_axis=[axis], prob=0.10)
            for axis in (0, 1, 2)
        ],
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
        ToTensord(keys=["image", "label"]),
    ]
)
# Validation pipeline: the deterministic part of the training preprocessing,
# with no random cropping or augmentation.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Image is interpolated bilinearly, label with nearest neighbour.
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        # Map intensities from [-200, 1000] to [0, 1], clipping values outside.
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-200,
            a_max=1000,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        ToTensord(keys=["image", "label"]),
    ]
)

# Dataset root and the JSON file that lists the training/validation splits.
data_dir = "/data2/dataset/wdDATASET/RibSegFrac/"
split_JSON = "dataset.json"

# data_dir already ends with a separator, so joining yields the same path as
# plain string concatenation.
datasets = os.path.join(data_dir, split_JSON)
datalist = load_decathlon_datalist(datasets, True, "training")
val_files = load_decathlon_datalist(datasets, True, "validation")

# Training data: cached dataset, shuffled, two volumes per batch.
train_ds = CacheDataset(
    data=datalist, transform=train_transforms, cache_num=320, cache_rate=1.0, num_workers=8
)
train_loader = DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

# Validation data: cached dataset, fixed order, one volume per batch.
val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_num=50,
    cache_rate=1.0,
    num_workers=4,
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Build the network once. The original script repeated this whole section a
# second time, which constructed a second model and re-read the checkpoint
# only to overwrite the first one.
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=2,
    feature_size=48,
    use_checkpoint=True,
)

# Pretrained backbone checkpoint, passed to SwinUNETR.load_from below.
path1 = "/data/jupyter/wd/research-contributions-main/RibFrac_transformer/pretrained_checkpoint/model_swinvit.pt"
# map_location="cpu": the model is still on the CPU at this point, and this
# keeps the load independent of the device the checkpoint was saved from.
weight = torch.load(path1, map_location="cpu")
model.load_from(weights=weight)

# Wrap for multi-GPU execution, then move to the GPU.
model = torch.nn.DataParallel(model)
model = model.cuda()

print("Using pretrained self-supervied Swin UNETR backbone weights !")

def validation(val_loader,epoch,max_epochs):
    """Run one validation pass and return the mean per-batch loss.

    Relies on the module-level ``model``, ``loss_function``, ``post_label``,
    ``post_pred`` and ``dice_metric``. The mean Dice is printed and the
    metric is reset; only the loss is returned.

    Args:
        val_loader: iterable of batches with ``"image"`` and ``"label"`` entries.
        epoch: current epoch index, used only in the log line.
        max_epochs: total number of epochs, used only in the log line.

    Returns:
        The validation loss averaged over the batches of ``val_loader``.
    """
    model.eval()
    start_time = time.time()
    val_loss = 0
    with torch.no_grad():
        for step, batch in enumerate(val_loader):
            val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())

            # Whole volumes are inferred as 96^3 windows, 4 windows per forward pass.
            val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)

            # Fix: the loss must compare predictions with the labels. The
            # original passed val_inputs (the image) as the target, so the
            # value used to select the best checkpoint was meaningless.
            val_loss_batch = loss_function(val_outputs, val_labels)
            val_loss += val_loss_batch.item()

            # Split the batch into per-sample tensors and post-process each
            # one before accumulating the Dice metric.
            val_labels_list = decollate_batch(val_labels)
            val_labels_convert = [
                post_label(val_label_tensor) for val_label_tensor in val_labels_list
            ]
            val_outputs_list = decollate_batch(val_outputs)
            val_output_convert = [
                post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list
            ]

            dice_metric(y_pred=val_output_convert, y=val_labels_convert)

        # step is the index of the last batch, so step + 1 is the batch count.
        val_loss = val_loss/(step+1)
        mean_dice_val = dice_metric.aggregate().item()
        # Reset so the next validation pass starts from an empty buffer.
        dice_metric.reset()
        print("Val {}/{} {}/{}".format(epoch,max_epochs, step, len(val_loader)),
              "loss: {:.4f}".format(val_loss),
              "dice",mean_dice_val,
              "time {:.2f}s".format(time.time() - start_time)
             )
    return val_loss

def train(train_loader,scheduler, epoch, max_epochs):
    """Train for one epoch and return the mean per-batch training loss.

    Relies on the module-level ``model``, ``loss_function`` and ``optimizer``.
    The scheduler is stepped once, after all batches have been processed.

    Args:
        train_loader: iterable of batches with ``"image"`` and ``"label"`` entries.
        scheduler: learning-rate scheduler, stepped once per call.
        epoch: current epoch index, used only in the log line.
        max_epochs: total number of epochs, used only in the log line.

    Returns:
        The training loss averaged over the batches of ``train_loader``.
    """
    model.train()
    began = time.time()
    running_loss = 0
    for step, batch in enumerate(train_loader):
        images = batch["image"].cuda()
        targets = batch["label"].cuda()
        batch_loss = loss_function(model(images), targets)
        batch_loss.backward()
        running_loss += batch_loss.item()
        optimizer.step()
        # Clear gradients so the next batch starts from zero.
        optimizer.zero_grad()

    # step is the index of the last batch, so step + 1 is the batch count.
    epoch_loss = running_loss / (step + 1)
    # Read the rate used during this epoch before the scheduler changes it.
    lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    progress = "Epoch {}/{} {}/{}".format(epoch, max_epochs, step, len(train_loader))
    elapsed = "time {:.2f}s".format(time.time() - began)
    print("LR: {:.7f}".format(lr), progress, "loss: {:.4f}".format(epoch_loss), elapsed)
    return epoch_loss

# Post-processing applied before the Dice metric: labels are one-hot encoded,
# predictions are arg-maxed and then one-hot encoded (2 classes).
post_label = AsDiscrete(to_onehot=2)
post_pred = AsDiscrete(argmax=True, to_onehot=2)
# NOTE(review): include_background=True is passed here, so the reported Dice
# presumably includes the background class as well as the foreground class.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
# Lowest validation loss seen so far (despite the "max" in the name);
# initialised high so the first validation always saves a checkpoint.
loss_val_max = 100
epoch_loss_values = []
metric_values = []
warmup_epochs = 5
max_epochs = 300
start_epoch = 0
# NOTE(review): `optimizer` and `loss_function` are not defined anywhere in
# the visible script; they must exist before this point.
scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=warmup_epochs, max_epochs=max_epochs)
for epoch in range(start_epoch, max_epochs):
    epoch_loss = train(train_loader,scheduler, epoch, max_epochs)
    epoch_loss_values.append(epoch_loss)
    # Validate after the first epoch and then every fifth epoch.
    if (epoch+1) % 5 == 0 or epoch == 0:
        val_loss = validation(val_loader,epoch, max_epochs)
        metric_values.append(val_loss)
        if val_loss < loss_val_max:
            loss_val_max = val_loss
            # Save the unwrapped network so the checkpoint keys carry no
            # "module." prefix and load directly into a plain SwinUNETR.
            net_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
            torch.save(net_to_save.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
            # The tracked value is the validation loss (lower is better), not
            # a Dice score, so the messages say so.
            print(
                    "Model Was Saved ! Current Best Avg. Val Loss: {} Current Avg. Val Loss: {}".format(
                        loss_val_max, val_loss
                    )
                )
        else:
            print(
                    "Model Was Not Saved ! Current Best Avg. Val Loss: {} Current Avg. Val Loss: {}".format(
                        loss_val_max, val_loss
                    )
                )
BelieferQAQ commented 2 years ago
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference
from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    ToTensord,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)

import time
import torch

from monai import transforms, data
from monai.data import load_decathlon_datalist
import scipy.ndimage as ndimage
import nibabel as nib

print_config()

# Inference preprocessing. Only the image is resampled to the training
# spacing; the label keeps its native grid, and predictions are resampled
# back to the label's shape later in this script.
test_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        transforms.AddChanneld(keys=["image", "label"]),
        # Fix: reorient to RAS, as the training/validation pipelines do.
        # Without this the network can see volumes in a different axis
        # orientation than it was trained on. Both keys are reoriented so the
        # image and label stay aligned.
        transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
        transforms.Spacingd(keys="image",
                            pixdim=(1.5, 1.5, 2),
                            mode="bilinear"),
        # Map intensities from [-200, 1000] to [0, 1], clipping values outside.
        transforms.ScaleIntensityRanged(keys=["image"],
                                        a_min=-200,
                                        a_max=1000,
                                        b_min=0.0,
                                        b_max=1.0,
                                        clip=True),
        # NOTE(review): no foreground crop is applied here; cropping only the
        # resampled image would break its alignment with the native-grid label.
        transforms.ToTensord(keys=["image", "label"]),
    ]
)

# Single-process inference by default.
distributed = False
datalist_json = '/data2/dataset/wdDATASET/RibSegFrac/dataset.json'
data_dir = '/data2/dataset/wdDATASET/RibSegFrac'
# The "validation" split is evaluated so the score can be compared with the
# one reported during training.
test_files = load_decathlon_datalist(datalist_json,
                                    True,
                                    "validation",
                                    base_dir=data_dir)
test_ds = data.Dataset(data=test_files, transform=test_transform)
# Fix: the original referenced `Sampler`, which is neither imported nor
# defined in this script. torch's DistributedSampler is used instead; it
# requires an initialised process group when distributed is True.
test_sampler = (
    torch.utils.data.DistributedSampler(test_ds, shuffle=False) if distributed else None
)
test_loader = data.DataLoader(test_ds,
                             batch_size=1,
                             shuffle=False,
                             num_workers=4,
                             sampler=test_sampler,
                             pin_memory=True,
                             persistent_workers=False)

model = SwinUNETR(img_size=96,
                  in_channels=1,
                  out_channels=2,
                  feature_size=48,
                  drop_rate=0.0,
                  attn_drop_rate=0.0,
                  dropout_path_rate=0.0,
                  use_checkpoint=True,
                  )

pretrained_pth = "/data/jupyter/wd/swinUnetr_RibFrac/checkpoints/best_metric_model.pth"

# Load onto the CPU first; the model is moved to the GPU after the weights
# have been copied in.
weights = torch.load(pretrained_pth, map_location="cpu")
# A checkpoint saved from a DataParallel-wrapped model has every key prefixed
# with "module.". Strip only that leading prefix; the original replace()
# call would also have rewritten "module." occurring in the middle of a key.
_dp_prefix = 'module.'
weights_dict = {}
for k, v in weights.items():
    new_k = k[len(_dp_prefix):] if k.startswith(_dp_prefix) else k
    weights_dict[new_k] = v
model.load_state_dict(weights_dict)

# Wrap for multi-GPU execution only after the plain state dict is loaded.
model = torch.nn.DataParallel(model)
model = model.cuda()
model.eval()

print("Using SwinUNETR weights !")

test_mode = True
# Predicted segmentations are written here, one NIfTI file per case.
output_directory = os.path.join('./outputs', 'test')
# exist_ok=True replaces the exists()/makedirs() pair, which could fail if the
# directory was created between the check and the call.
os.makedirs(output_directory, exist_ok=True)

def resample_3d(img, target_size):
    """Resize a 3-D array to ``target_size`` with order-0 interpolation.

    Args:
        img: array with exactly three dimensions.
        target_size: the three desired output extents ``(tx, ty, tz)``.

    Returns:
        The array produced by ``ndimage.zoom`` with per-axis zoom factors
        ``target / source``, ``order=0`` and no prefiltering.
    """
    src_x, src_y, src_z = img.shape
    dst_x, dst_y, dst_z = target_size
    # One zoom factor per axis: desired extent divided by current extent.
    factors = tuple(
        float(dst) / float(src)
        for dst, src in ((dst_x, src_x), (dst_y, src_y), (dst_z, src_z))
    )
    return ndimage.zoom(img, factors, order=0, prefilter=False)

def dice(x, y):
    """Return the Dice coefficient ``2 * sum(x * y) / (sum(x) + sum(y))``.

    Args:
        x: predicted mask (array).
        y: reference mask (array) that can be multiplied element-wise with ``x``.

    Returns:
        ``0.0`` when ``y`` sums to zero, otherwise the Dice coefficient.
    """
    overlap = np.sum(x * y)
    reference_total = np.sum(y)
    # An empty reference mask scores 0.0 regardless of the prediction.
    if reference_total == 0:
        return 0.0
    predicted_total = np.sum(x)
    return 2 * overlap / (predicted_total + reference_total)

# Run inference on every case, score it against its label and save the
# predicted mask as a NIfTI file.
with torch.no_grad():
    dice_list_case = []
    for i, batch in enumerate(test_loader):
        val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
        original_affine = batch['label_meta_dict']['affine'][0].numpy()
        # The label's spatial shape is the grid predictions are resampled to.
        _, _, h, w, d = val_labels.shape
        target_shape = (h, w, d)
        print("The shape:{}".format(target_shape))
        # basename instead of split('/') so the file name does not depend on
        # the path separator.
        img_name = os.path.basename(batch['image_meta_dict']['filename_or_obj'][0])
        print("Inference on case {}".format(img_name))
        val_outputs = sliding_window_inference(val_inputs,
                                               (96,96,96),
                                               4,
                                               model,
                                               overlap=0.5,
                                               mode="gaussian"
                                               )
        # Softmax is monotonic per voxel, so the argmax of the raw logits
        # gives the same labels. Taking it on the GPU also avoids copying the
        # full probability volume to the CPU.
        val_outputs = torch.argmax(val_outputs, dim=1).cpu().numpy().astype(np.uint8)[0]
        val_labels = val_labels.cpu().numpy()[0, 0, :, :, :]
        # Bring the prediction back to the label's grid before scoring.
        val_outputs = resample_3d(val_outputs, target_shape)

        # Dice is computed on class 1 only.
        organ_Dice = dice(val_outputs == 1, val_labels == 1)
        dice_list_case.append(organ_Dice)

        print("the dice:{}".format(organ_Dice))

        nib.save(nib.Nifti1Image(val_outputs.astype(np.uint8), original_affine),
                 os.path.join(output_directory, img_name))
    print("The nuber of case: {}".format(len(dice_list_case)))
    print("Overall Mean Dice: {}".format(np.mean(dice_list_case)))
BelieferQAQ commented 2 years ago

The first code is the training code, and the second is the prediction code. Thank you very much for your help

KumoLiu commented 2 years ago

Hi @BelieferQAQ, from your code, your test_transform is different from val_transforms. If you want to reproduce the validation result, just use val_transforms and the same post-processing transforms at inference time. Thanks!

BelieferQAQ commented 2 years ago

> Hi @BelieferQAQ, from your code, your test_transform is different from val_transforms. If you want to reproduce the validation result, just use val_transforms and the same post-processing transforms at inference time. Thanks!

Thank you very much for your reply. In addition, I would like to ask whether the parts of my code that load the weights and save the prediction results as a .nii.gz file are correct. Thank you.

KumoLiu commented 2 years ago

Hi @BelieferQAQ, you can simply load weight like this:

model.load_state_dict(torch.load("best_metric_model_classification3d_dict.pth"))

And using SaveImage to save images like:

saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")

You can refer to this tutorial for more guidance. Hope it can help you, thanks!

BelieferQAQ commented 1 year ago

Hello, when I use the load_state_dict function to load weights, I encounter the following issue: Error(s) in loading state_dict for SwinUNETR: Missing key(s) in state_dict: "swinViT.patch_embed.proj.weight", "swinViT.patch_embed.proj.bias", "swinViT.layers1.0.blocks.0.norm1.weight", "swinViT.layers1.0.blocks.0.norm1.bias", "swinViT.layers1.0.blocks.0.attn.relative_position_bias_table", "swinViT.layers1.0.blocks.0.attn.relative_position_index", ... My weights come from training on multiple GPUs. Were your weights obtained from training on a single GPU?

BelieferQAQ commented 1 year ago

I want to resample the prediction to the original image size. How should I use the SaveImage function to do this?

KumoLiu commented 1 year ago

Hi @BelieferQAQ,

Hope it helps, thanks!

BelieferQAQ commented 1 year ago

Hi @BelieferQAQ,

Hope it helps, thanks! thanks your reply。i have some questions: root = "/data1/DataSets/wdDataset/lymphNodeDATA/LNQ2023/train/" saver = SaveImage(output_dir="./test_output_unetr", output_ext=".nii.gz", output_postfix="seg",resample=True,separate_folder=False) with torch.no_grad(): for i, batch in enumerate(val_loader): val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda()) img_name = batch['image_meta_dict']['filename_or_obj'][0].split('/')[-1] print("Inference on case {}".format(img_name)) val_outputs = sliding_window_inference(val_inputs, (96,96,96), 4, model, overlap=0.5, mode="gaussian" ) val_outputs = torch.softmax(val_outputs, 1).cpu().numpy() val_outputs = np.argmax(val_outputs, axis=1).astype(np.uint8)[0] saver(val_outputs,batch['image_meta_dict']) print("saved!") I want to perform inference on the validation set and save the results as NIFTI files with the original shape of the validation samples. However, I encountered the following error: RuntimeError: SaveImage cannot find a suitable writer for test_output_unetr/lnq2023-train-0943-ct.nii_seg.nii.gz. Please install the writer libraries, see also the installation instructions: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies. The current registered writers for .nii.gz: (<class 'monai.data.image_writer.NibabelWriter'>, <class 'monai.data.image_writer.ITKWriter'>). 
Traceback (most recent call last): File "/home/hdu/anaconda3/envs/wxh/lib/python3.9/site-packages/monai/transforms/io/array.py", line 405, in call writer_obj.set_metadata(meta_dict=meta_data, **self.meta_kwargs) File "/home/hdu/anaconda3/envs/wxh/lib/python3.9/site-packages/monai/data/image_writer.py", line 564, in set_metadata self.data_obj, self.affine = self.resample_if_needed( File "/home/hdu/anaconda3/envs/wxh/lib/python3.9/site-packages/monai/data/image_writer.py", line 262, in resample_if_needed output_array, target_affine = resampler( File "/home/hdu/anaconda3/envs/wxh/lib/python3.9/site-packages/monai/transforms/spatial/array.py", line 196, in call src_affine = to_affine_nd(spatial_rank, src_affine) File "/home/hdu/anaconda3/envs/wxh/lib/python3.9/site-packages/monai/data/utils.py", line 841, in to_affine_nd raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.") ValueError: affine must have 2 dimensions, got 3.

Traceback (most recent call last): File "/home/hdu/anaconda3/envs/wxh/lib/python3.9/site-packages/monai/transforms/io/array.py", line 405, in call writer_obj.set_metadata(meta_dict=meta_data, **self.meta_kwargs) File "/home/hdu/anaconda3/envs/wxh/lib/python3.9/site-packages/monai/data/image_writer.py", line 413, in set_metadata self.data_obj, self.affine = self.resample_if_needed( File "/home/hdu/anaconda3/envs/wxh/lib/python3.9/site-packages/monai/data/image_writer.py", line 262, in resample_if_needed output_array, target_affine = resampler( File "/home/hdu/anaconda3/envs/wxh/lib/python3.9/site-packages/monai/transforms/spatial/array.py", line 196, in call src_affine = to_affine_nd(spatial_rank, src_affine) File "/home/hdu/anaconda3/envs/wxh/lib/python3.9/site-packages/monai/data/utils.py", line 841, in to_affine_nd raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.") ValueError: affine must have 2 dimensions, got 3.