Open AndrewForresterGit opened 1 month ago
It is not normal. You may need to send me some code so I can help you solve the problem. You can first check the input size — is the input size 256×256×256?
The file input size is 256x256x256. I've only modified slightly the 4_predict.py code in order to accept command-line arguments for shell-script purposes. I also removed the use of the label
variable since it doesn't play a role in predictions. Other than those modifications, the code is the same.
import numpy as np
from light_training.dataloading.dataset import get_train_val_test_loader_from_train, get_kfold_loader
import torch
import torch.nn as nn
from monai.inferers import SlidingWindowInferer
from light_training.evaluation.metric import dice
from light_training.trainer import Trainer
from monai.utils import set_determinism
from light_training.evaluation.metric import dice
set_determinism(123)
import os
from light_training.prediction import Predictor
import argparse
# --- Module-level configuration -------------------------------------------
# Default dataset location (overridden below by the -d CLI argument).
data_dir = "./data/fullres/train"
# Distributed-training backend identifier passed to the Trainer base class.
env = "pytorch"
# Training hyper-parameters; retained for the Trainer constructor even though
# this script only runs inference.
max_epoch = 1000
batch_size = 2
val_every = 2
num_gpus = 2
device = "cuda:0"
# Sliding-window ROI size used by the inferer; inputs larger than this are
# tiled with 50% overlap (see SlidingWindowInferer below).
patch_size = [128, 128, 128]
def input_parser(args=None):
    """Parse the command-line arguments used by this prediction script.

    Parameters
    ----------
    args : list[str] | None
        Argument strings to parse. ``None`` (the default) falls back to
        ``sys.argv[1:]``, preserving the original behavior; passing an
        explicit list makes the function usable (and testable) without
        touching the real command line.

    Returns
    -------
    argparse.Namespace
        Namespace with attributes ``data_directory``, ``test_directory``,
        ``model`` and ``log_directory`` (each ``None`` when the flag is
        not supplied).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data_directory', dest='data_directory',
                        help='directory with the preprocessed training data')
    parser.add_argument('-t', '--test_directory', dest='test_directory',
                        help='directory with the cases to predict')
    parser.add_argument('-m', '--model', dest='model',
                        help='path to the trained model checkpoint (.pt)')
    parser.add_argument('-l', '--log_directory', dest='log_directory',
                        help='directory where predictions are saved')
    return parser.parse_args(args)
# NOTE: argument parsing runs at import time (a module-level side effect);
# importing this file without the expected CLI flags leaves these as None.
dirs = vars(input_parser())
train_dir = dirs['data_directory']   # preprocessed training-data directory
test_dir = dirs['test_directory']    # directory of cases to predict
model_path = dirs['model']           # checkpoint loaded in define_model_segmamba
save_path = dirs['log_directory']    # output directory for predicted NIfTI files
class BraTSTrainer(Trainer):
    """Inference-only trainer that loads a SegMamba checkpoint and writes
    sliding-window predictions for each case to NIfTI files.

    Training is never invoked; only ``validation_step`` is used (via
    ``validation_single_gpu``), and ground-truth labels are deliberately
    ignored (``get_input`` returns ``label=None``).
    """

    # FIX: the pasted source contained line-wrap artifacts splitting the
    # identifiers "master_ip" and "training_script" ("master_i p",
    # "tr aining_script"), which is a syntax error — restored here.
    def __init__(self, env_type, max_epochs, batch_size, device="cpu",
                 val_every=1, num_gpus=1, logdir="./logs/",
                 master_ip='localhost', master_port=17750,
                 training_script="train.py"):
        super().__init__(env_type, max_epochs, batch_size, device, val_every,
                         num_gpus, logdir, master_ip, master_port,
                         training_script)
        # ROI size for sliding-window inference (module-level constant).
        self.patch_size = patch_size
        self.augmentation = False

    def convert_labels(self, labels):
        """Convert a BraTS label map to three binary channels: TC, WT, ET.

        TC (tumor core) = labels 1|3, WT (whole tumor) = labels 1|2|3,
        ET (enhancing tumor) = label 3. Channels are stacked on dim=1
        (the channel axis of a batched tensor). Unused during prediction
        but kept for interface compatibility with the training script.
        """
        result = [(labels == 1) | (labels == 3),
                  (labels == 1) | (labels == 3) | (labels == 2),
                  labels == 3]
        return torch.cat(result, dim=1).float()

    def get_input(self, batch):
        """Extract (image, label, properties) from a dataloader batch.

        ``label`` is always ``None``: ground truth is not needed for
        prediction, so the "seg" key is intentionally not read.
        """
        image = batch["data"]
        properties = batch["properties"]
        label = None
        return image, label, properties

    def define_model_segmamba(self):
        """Build the SegMamba model, load the checkpoint from the global
        ``model_path``, and return (model, predictor, save_path).

        Returns
        -------
        tuple
            ``(model, predictor, save_path)`` where ``model`` is in eval
            mode and ``predictor`` wraps a Gaussian-weighted sliding-window
            inferer with mirroring over all three spatial axes.
        """
        from model_segmamba.segmamba import SegMamba
        model = SegMamba(in_chans=6,
                         out_chans=6,
                         depths=[2, 2, 2, 2, 2, 2],
                         feat_size=[48, 96, 192, 384])
        # map_location="cpu" so the checkpoint loads regardless of the GPU
        # layout it was saved with; filte_state_dict strips DDP prefixes.
        new_sd = self.filte_state_dict(torch.load(model_path, map_location="cpu"))
        model.load_state_dict(new_sd)
        model.eval()
        window_infer = SlidingWindowInferer(roi_size=patch_size,
                                            sw_batch_size=2,
                                            overlap=0.5,
                                            progress=True,
                                            mode="gaussian")
        predictor = Predictor(window_infer=window_infer,
                              mirror_axes=[0, 1, 2])
        os.makedirs(save_path, exist_ok=True)
        return model, predictor, save_path

    def validation_step(self, batch):
        """Predict one case and save the result as a NIfTI file.

        NOTE(review): the model is rebuilt and the checkpoint reloaded on
        every batch because define_model_segmamba() is called here; hoisting
        it out of the loop (e.g. into __init__) would speed up prediction
        considerably without changing the outputs.
        """
        image, label, properties = self.get_input(batch)
        model, predictor, save_dir = self.define_model_segmamba()
        # Mirrored sliding-window forward pass -> per-class probabilities.
        model_output = predictor.maybe_mirror_and_predict(image, model,
                                                          device=device)
        model_output = predictor.predict_raw_probability(model_output,
                                                         properties=properties)
        # argmax over the class axis, then restore a leading channel dim.
        model_output = model_output.argmax(dim=0)[None]
        # Undo the preprocessing crop so the saved volume matches the raw
        # input geometry (this is the step that restores full spatial size).
        model_output = predictor.predict_noncrop_probability(model_output,
                                                             properties)
        predictor.save_to_nii(model_output,
                              raw_spacing=[1, 1, 1],
                              case_name=properties['name'][0],
                              save_dir=save_dir)
        return 0

    def convert_labels_dim0(self, labels):
        """Same TC/WT/ET conversion as convert_labels, but stacking on
        dim=0 (for unbatched, label-map-shaped tensors)."""
        result = [(labels == 1) | (labels == 3),
                  (labels == 1) | (labels == 3) | (labels == 2),
                  labels == 3]
        return torch.cat(result, dim=0).float()

    def filte_state_dict(self, sd):
        """Normalize a checkpoint state dict for load_state_dict().

        Unwraps a DistributedDataParallel-style checkpoint ({"module": sd})
        and strips the leading "module." prefix from each key.
        """
        if "module" in sd:
            sd = sd["module"]
        new_sd = {}
        for k, v in sd.items():
            k = str(k)
            # "module.foo.bar" -> "foo.bar"; 7 == len("module.")
            new_k = k[7:] if k.startswith("module") else k
            new_sd[new_k] = v
        del sd
        return new_sd
if __name__ == "__main__":
trainer = BraTSTrainer(env_type=env,
max_epochs=max_epoch,
batch_size=batch_size,
device=device,
logdir="",
val_every=val_every,
num_gpus=num_gpus,
master_port=17751,
training_script=__file__)
train_ds, val_ds, test_ds = get_kfold_loader(data_dir=train_dir, test_dir=test_dir,)
trainer.validation_single_gpu(test_ds)
# print(f"result is {v_mean}")
In this section :
You can check the output shape of each inference step.
And you can find which step causes this error.
The predictions have size 256x256x1. However, my input images have dimensions 256x256x256. Is this normal? And if not, how do I tweak the code so as to have the same dimension output as my input?