Closed jillianlee closed 4 years ago
Hi @jillianlee ,
Thanks for your interest here.
For a 2D segmentation task, it should be only slightly different from 3D segmentation: the network dims and the transform shapes change.
Could you please paste your test program and let me check some details?
Thanks.
Hi @Nic-Ma
I assumed as such, and I'm not getting any errors, it just seems like a vanishing gradient problem but the usual solutions are not helping. I've tried changing activation functions in the UNet, loss functions, optimizers and nothing has helped.
class ConvertLabel(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WC (Whole tumor)*** we want WC because flair shows whole tumor
    and ET (Enhancing tumor).
    """
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            # merge labels 1, 2 and 3 to construct WC
            result = np.logical_or(np.logical_or(d[key] == 2, d[key] == 3), d[key] == 1)
            d[key] = result.astype(np.float32)
        return d
monai.config.print_config()
set_determinism(seed=0)
tempdir = tempfile.mkdtemp()
text_t1 = open(r'C:\Users\iBest\Documents\Jillian\Documents\BO-Aug-master\dataset\BraTS\filename_flair.txt', 'r')
train_images = text_t1.read().split('\n')
text_segs = open(r'C:\Users\iBest\Documents\Jillian\Documents\BO-Aug-master\dataset\BraTS\filename_seg.txt', 'r')
train_labels = text_segs.read().split('\n')
data_dicts = [{'image': image_name, 'label': label_name}
              for image_name, label_name in zip(train_images, train_labels)]
transforms = Compose([
    LoadNiftid(keys=['image', 'label']),
    AddChanneld(keys=['image', 'label']),
    ConvertLabel(keys='label'),
    #Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), mode=('bilinear', 'nearest')),
    Orientationd(keys=['image', 'label'], axcodes='RAS'),
    Resized(keys=['image', 'label'], spatial_size=(160, 160, 72)),
    RandSpatialCropSamplesd(keys=['image', 'label'], roi_size=[160, 160, 1], num_samples=72, random_center=True, random_size=False),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    SqueezeDimd(keys=['image', 'label'], dim=-1),
    #ScaleIntensityd(keys='image', minv=0.0, maxv=1.0),
    RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
    RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ToNumpyd(keys=['image', 'label'])
])
transforms_val = Compose([
    LoadNiftid(keys=['image', 'label']),
    AddChanneld(keys=['image', 'label']),
    ConvertLabel(keys='label'),
    #Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), mode=('bilinear', 'nearest')),
    Orientationd(keys=['image', 'label'], axcodes='RAS'),
    Resized(keys=['image', 'label'], spatial_size=(160, 160, 72)),
    RandSpatialCropSamplesd(keys=['image', 'label'], roi_size=[160, 160, 1], num_samples=72, random_center=True, random_size=False),
    SqueezeDimd(keys=['image', 'label'], dim=-1),
    #ScaleIntensityd(keys='image', minv=0.0, maxv=1.0),
    ToNumpyd(keys=['image', 'label'])
])
train_files, val_files, test_files = data_dicts[0:171], data_dicts[171:228], data_dicts[228:285]
train_ds = monai.data.ArrayDataset(train_files, img_transform = transforms)
train_load = monai.data.DataLoader(train_ds)
val_ds = monai.data.ArrayDataset(val_files, img_transform = transforms_val)
val_load = monai.data.DataLoader(val_ds)
test_ds = monai.data.ArrayDataset(test_files, img_transform = transforms_val)
test_load = monai.data.DataLoader(test_ds)
flat_test = [item for sublist in test_load.dataset for item in sublist]
test_data = [sub['image'] for sub in flat_test]
test_data = [np.squeeze(im, 0) for im in test_data]
test_labels = [sub['label'] for sub in flat_test]
test_labels = [np.squeeze(lab, 0) for lab in test_labels]
flat_train = [item for sublist in train_load.dataset for item in sublist]
train_images = [sub['image'] for sub in flat_train]
train_images = [np.squeeze(im, 0) for im in train_images]
train_labels = [sub['label'] for sub in flat_train]
train_labels = [np.squeeze(lab, 0) for lab in train_labels]
flat_val = [item for sublist in val_load.dataset for item in sublist]
val_images = [sub['image'] for sub in flat_val]
val_images = [np.squeeze(im, 0) for im in val_images]
val_labels = [sub['label'] for sub in flat_val]
val_labels = [np.squeeze(lab, 0) for lab in val_labels]
dict_train = [{'image': image_data, 'label': label_data}
              for image_data, label_data in zip(train_images, train_labels)]
dict_val = [{'image': image_data, 'label': label_data}
            for image_data, label_data in zip(val_images, val_labels)]
dict_test = [{'image': image_data, 'label': label_data}
             for image_data, label_data in zip(test_data, test_labels)]
transforms = Compose([
    AddChanneld(keys=['image', 'label']),
    ToNumpyd(keys=['image', 'label'])
])
train_data = monai.data.ArrayDataset(dict_train, img_transform = transforms)
train_loader = monai.data.DataLoader(train_data, batch_size=3, shuffle=True, num_workers=0, multiprocessing_context = None)
val_data = monai.data.ArrayDataset(dict_val, img_transform = transforms)
val_loader = monai.data.DataLoader(val_data, batch_size=1, shuffle=True, num_workers=0, multiprocessing_context = None)
test_data = monai.data.ArrayDataset(dict_test, img_transform = transforms)
test_loader = monai.data.DataLoader(test_data, batch_size=1, shuffle=True, num_workers=0, multiprocessing_context = None)
device = torch.device('cuda:0')
model = monai.networks.nets.UNet(dimensions=2, in_channels=1, out_channels=1, channels=(16, 32, 64),
                                 strides=(2, 2), act='elu').to(device)
loss_function = monai.losses.DiceLoss(to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, to_onehot_y=False, sigmoid = True, reduction="mean")
num_epochs = 100
# TRAINING
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()
for epoch in range(num_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{num_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        #inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        inputs, labels = batch_data['image'].to(device), batch_data['label'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs.type(torch.cuda.FloatTensor))
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_data) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        #writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                #val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_images, val_labels = val_data['image'].to(device), val_data['label'].to(device)
                roi_size = (160, 160)
                sw_batch_size = 4
                #val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = model(val_images)
                value = dice_metric(y_pred=val_outputs, y=val_labels)
                metric_count += len(value)
                metric_sum += value.item() * len(value)
            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            #writer.add_scalar("val_mean_dice", metric, epoch + 1)
            #plot the last model output as GIF image in TensorBoard with the corresponding image and label
#shutil.rmtree(tempdir)
The network parameters, optimizer, and loss function seem to work fine with the 3D data fed from ArrayDataset.
When I perform sampling, the images seem to look fine, but the segmentation masks come out black and the peak DSC is around 0.2.
I wondered if there was an issue with array management/sampling, so I tried running a 2D UNet directly with the RandSpatialCropSamplesd transform, but this causes a batch-size error on the validation set. The validation batch size of 1, combined with RandSpatialCropSamplesd, creates a batch of size 1 x num_samples, which throws an error in the validation portion.
The 2D numpy array sampling portion is something I need to do for my own data augmentation experiments so unfortunately I cannot omit it.
Any advice would be appreciated :) Thanks Jillian
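For the validation side of the batch-size problem described above, one option (not something proposed in this thread) is to slice each volume deterministically instead of sampling random crops, so every slice is seen exactly once and each item is a plain 2D array. A minimal NumPy sketch, assuming a channel-first volume shaped (C, H, W, D) as produced after adding the channel and resizing; the helper name `volume_to_slices` is hypothetical:

```python
import numpy as np


def volume_to_slices(volume):
    """Split a channel-first volume (C, H, W, D) into D slices shaped (C, H, W)."""
    # Indexing the last axis drops it, so no separate squeeze step is needed.
    return [volume[..., k] for k in range(volume.shape[-1])]


# Dummy volume with the spatial size used in the post above.
image = np.zeros((1, 160, 160, 72), dtype=np.float32)
slices = volume_to_slices(image)
print(len(slices), slices[0].shape)  # 72 (1, 160, 160)
```

The resulting list can be wrapped in dicts and batched like any other 2D dataset, which keeps the validation batch size at whatever the loader specifies.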
Hi @jillianlee ,
2 points I found here may be helpful:
1. Use monai.data.Dataset or monai.data.CacheDataset instead of ArrayDataset.
2. The label values may be 1, 2, 4; you need to double confirm it.
Thanks.
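To double check the label values, it can help to print the unique values in a label volume and build the whole-tumor mask from an explicit set of values rather than hard-coded comparisons. A minimal sketch, assuming NumPy label arrays; the helper name `whole_tumor_mask` and its default values are illustrative only, and the toy array is made up:

```python
import numpy as np


def whole_tumor_mask(label, tumor_values=(1, 2, 4)):
    """Return a float32 mask that is 1 wherever `label` holds one of `tumor_values`."""
    return np.isin(label, tumor_values).astype(np.float32)


# Toy 2x3 label map, not real BraTS data.
toy = np.array([[0, 1, 2], [4, 0, 3]])
print(np.unique(toy))         # shows which label values actually occur
print(whole_tumor_mask(toy))  # the value 3 is not in (1, 2, 4), so it maps to 0
```

If the labels turn out to use 4 but the transform compares against 3, the corresponding region is silently dropped from the mask, so checking `np.unique` on a real label file first is cheap insurance.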
Hi @Nic-Ma
Oh right, my apologies. Previously I had used array data with the ArrayDataset, and I tried switching to dict type data to debug.
I switched to Dataset and saw no improvement.
I also just tried labels 1, 2, and 4 for BraTS, and that did not fix it. The masks still all come out black.
Essentially this entire code works perfectly fine with the 3D data, but as soon as it's sliced into 2D images, it breaks. The 3D U-Net obtains a reasonable DSC in the high 0.7s.
The 2D code is doing something, but not much. Below are the graphs for the accuracy and DSC over epochs. It just gets stuck and keeps predicting black masks, even with GeneralizedDiceLoss and alternative activation functions to compensate for the class imbalance. I have also tried the sliding window inference predictor, and it had a similar output graph, only slightly smoother.
I'm also getting the same behavior on the 2D HighResNet and VNet.
From what I've read on other forums, it's a common problem for class-imbalanced data. The suggested fixes include changing the activation functions to tanh or elu, which I have tried, or changing the loss function to Generalized Dice Loss, Tversky Loss, or Masked Dice Loss. These adjustments haven't helped either.
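One way to quantify the class imbalance mentioned above is to measure how much foreground each 2D slice contains, since slices cut from a 3D volume may contain no tumor at all. The sketch below is a diagnostic idea only, not something confirmed in this thread; the helper names are hypothetical and the toy masks are made up:

```python
import numpy as np


def foreground_fraction(mask):
    """Fraction of pixels in `mask` that are non-zero."""
    return float(np.count_nonzero(mask)) / mask.size


def keep_nonempty(slices, min_fraction=0.0):
    """Keep only the slices whose foreground fraction exceeds `min_fraction`."""
    return [s for s in slices if foreground_fraction(s) > min_fraction]


# Three toy 4x4 label slices: empty, one foreground pixel, four foreground pixels.
empty = np.zeros((4, 4), dtype=np.float32)
one_pixel = empty.copy()
one_pixel[0, 0] = 1.0
four_pixels = empty.copy()
four_pixels[:2, :2] = 1.0
toy_slices = [empty, one_pixel, four_pixels]

print([foreground_fraction(s) for s in toy_slices])
print(len(keep_nonempty(toy_slices)))
```

If a large share of the training slices turns out to be empty, that alone could push a Dice-based loss toward all-background predictions, so it is worth checking before changing the network.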
I am working on something similar. Could you share the whole code for ConvertLabel?
I think this would also be a nice addition to MONAI.
Hi @kissievendor
class ConvertLabel(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WC (Whole tumor)*** we want WC because flair shows whole tumor
    and ET (Enhancing tumor).
    """
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            # merge labels 1, 2 and 3 to construct WC
            result = np.logical_or(np.logical_or(d[key] == 2, d[key] == 3), d[key] == 1)
            d[key] = result.astype(np.float32)
        return d
Here is the whole code for ConvertLabel. It's modified from the BraTS example that they had previously posted in their examples folder (I believe it's been taken down now?). I set it up this way to output only 1 binary label for 1 output channel. This combination of labels is mentioned in the BraTS reference paper. It is the combination that is visible in the FLAIR data.
Thanks! Where does the MapTransform come from? It gives an error for me.
Have you had more luck with your training / loss?
Oh, you have to import MapTransform from monai.transforms. Here are all my imports (although some are unused):
import monai
from monai.transforms import (
    Compose, LoadNiftid, LoadNifti, AddChanneld, AddChannel, ScaleIntensityRanged, CropForegroundd,
    Resized, ToNumpyd, SqueezeDimd, RandSpatialCropd, RandSpatialCropSamplesd, MapTransform, Spacingd,
    Orientationd, ScaleIntensityd, ToTensor, RandFlipd, RandScaleIntensityd, ToNumpy, RandShiftIntensityd,
)
No, I've had no luck on 2D. Works fine on 3D though with all the same network and loss settings, so I'm not sure what's going on.
Oh, of course! Thanks. I am working with brain data and very small annotations. Could it be that 3D gives the network more information from the other slices? Try training on different orientations (as in QuickNAT). Good luck!
@kissievendor
So I made some mistakes in my code, and it is performing better now. Not the greatest, but still better!
Try increasing the channels to like (64, 128, 256) or higher.
I cropped out some of the black space and centered more around the brain.
I had made an error in my figure plots in the testing loop, so they were only outputting black. The new code is:
plt.figure("check", (18, 6))
plt.subplot(1, 3, 1)
plt.title("image")
plt.imshow(test_data['image'][0, 0, :, :], cmap="gray")
plt.subplot(1, 3, 2)
plt.title("label")
plt.imshow(test_data['label'][0, 0, :, :])
plt.subplot(1, 3, 3)
plt.title("output")
plt.imshow((val_outputs).detach().cpu()[0, 0, :, :])
plt.show()
This is just inside the model evaluation loop so you can see all of your outputs.
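Since the loss in the original post is configured with sigmoid=True, the network output is presumably raw logits, so it may be clearer to convert it to a binary mask before plotting. A minimal NumPy sketch of that step; the helper name `logits_to_mask` and the 0.5 threshold are assumptions, and the toy logits are made up:

```python
import numpy as np


def logits_to_mask(logits, threshold=0.5):
    """Apply a sigmoid to raw logits and binarize at `threshold`."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    return (probs >= threshold).astype(np.float32)


# Toy 2x2 logit map standing in for one output slice.
logits = np.array([[-3.0, 0.2], [4.0, -0.1]])
print(logits_to_mask(logits))
```

With a real model, the same function could be applied to something like `val_outputs.detach().cpu().numpy()[0, 0]` before passing it to `plt.imshow`.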
Try Tversky Loss with alpha = 0.3 and beta = 0.7.
My accuracy is around 0.50 now with Tversky Loss. I'm going to play around with some other ones and hope for more improvement.
Best of luck :)
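For readers unfamiliar with the Tversky suggestion above: the Tversky index generalizes Dice by weighting false positives and false negatives separately, TI = TP / (TP + alpha * FP + beta * FN), and the loss is 1 - TI. The NumPy sketch below assumes the convention that alpha weights false positives and beta weights false negatives, which is worth confirming against the documentation of whichever loss implementation is used; the function name and toy arrays are illustrative:

```python
import numpy as np


def tversky_index(pred, target, alpha=0.3, beta=0.7, eps=1e-7):
    """Tversky index for binary masks; alpha weights FP, beta weights FN (assumed convention)."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    tp = np.sum(pred * target)          # true positives
    fp = np.sum(pred * (1.0 - target))  # false positives
    fn = np.sum((1.0 - pred) * target)  # false negatives
    return (tp + eps) / (tp + alpha * fp + beta * fn + eps)


# Toy under-segmentation: one of three foreground pixels is found, no false positives.
pred = np.array([1, 0, 0, 0])
target = np.array([1, 1, 1, 0])
print(tversky_index(pred, target, alpha=0.5, beta=0.5))  # equals the Dice score for this case
print(tversky_index(pred, target, alpha=0.3, beta=0.7))  # lower, because misses are penalized more
```

With beta larger than alpha, missed foreground pixels cost more than spurious ones, which is why this setting is often tried when the foreground is small.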
@jillianlee
Hey, I am facing the issue that the model is not training; the loss is stuck.
Are the model parameters and config correct?
loss_function = DiceLoss(to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
dice_metric = DiceMetric(include_background=True, reduction="mean")
post_transforms = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
model = SegResNet(
    spatial_dims=2,
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=3,
    out_channels=1,
    dropout_prob=0.2,
    use_conv_final=True,
    act="RELU",
    norm=Norm.BATCH,
).to(device)
What has been done wrong?
Thanks
Has there been any fix on this? Thank you.
Hi There,
Do you plan on posting examples or workbooks for any 2D segmentation problems?
I'm having some issues trying to get predictions to work with 2D input data. The same data given in 3D to the same network works just fine, but when using the 2D data, all my masks come out black. I've tried changing quite a few of the parameters and optimizer settings, but nothing has helped. I think an example may be helpful if you have time for that!
Thanks Jillian