ntustison / ANTsXNetTraining

Apache License 2.0

Super Resolution Pretrained model (allen_sr_weights.h5) #1

Closed: CarloAmodeo21 closed this issue 1 year ago

CarloAmodeo21 commented 1 year ago

Hi everyone,

Thank you for your effort in uploading all of this. It is very useful. I have a quick question about the train_sr_model.py script: it seems to point to a pretrained network (allen_sr_weights.h5) that I cannot find in the repo. Is this .h5 file publicly available? Also, what are the dimensions of the files you use for training? It looks to me like 256x256x3. Is that correct? Do the NIfTI files contain only 3 slices?

Best, Carlo

ntustison commented 1 year ago

There are a couple of “allen” weights I haven’t uploaded yet. I plan to upload them soon.

CarloAmodeo21 commented 1 year ago

Awesome, thank you. Would you mind sharing some tips on how to run the training?

I loaded the architecture (and the weights of the pretrained model) available through antspynet.get_pretrained_network("mriSuperResolution"), and loaded some of my low-resolution/high-resolution image pairs.

Everything runs smoothly until I get this error:

Graph execution error:
Detected at node 'combined_loss_fixed/model_3/lambda/sub_1'
    processed = Lambda(lambda x: (x - vgg_mean) / vgg_std)(vgg_input_image)
Node: 'combined_loss_fixed/model_3/lambda/sub_1'
Incompatible shapes: [2,128,96,32] vs. [3]

Here 2 is my batch size and 128x96x32 is the size of my high-resolution NIfTI images, but I don't understand what the 3 refers to.
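The "[3]" is most likely the shape of vgg_mean and vgg_std in the Lambda layer shown in the error: one statistic per RGB channel. Element-wise arithmetic broadcasts those vectors against the last axis of the input, so that axis must have length 3 (or 1). A 2-D RGB patch has 3 channels there; a 128x96x32 scalar volume has 32 slices there instead. The minimal NumPy sketch below reproduces the same mismatch. The mean/std values are the common ImageNet statistics and are an assumption; the actual values in the training script may differ.

```python
import numpy as np

# Assumed per-channel statistics: one value per RGB channel, i.e. shape (3,),
# which matches the "[3]" in the error message. Actual values may differ.
vgg_mean = np.array([0.485, 0.456, 0.406])
vgg_std = np.array([0.229, 0.224, 0.225])

def vgg_normalize(x):
    """Mirror of the Lambda layer: (x - vgg_mean) / vgg_std.

    The (3,) vectors broadcast against the LAST axis of x, so that
    axis must have length 3 (or 1) for the operation to succeed.
    """
    return (x - vgg_mean) / vgg_std

# Batch of 2-D RGB patches: last axis holds the 3 colour channels, so this works.
rgb_batch = np.zeros((2, 256, 256, 3))
print(vgg_normalize(rgb_batch).shape)  # (2, 256, 256, 3)

# Batch of 3-D scalar volumes: last axis holds 32 slices, so shapes are incompatible.
mr_batch = np.zeros((2, 128, 96, 32))
try:
    vgg_normalize(mr_batch)
except ValueError as err:
    print("broadcast error:", err)
```

TensorFlow applies the same broadcasting rules, which is why the graph fails at the subtraction node (sub_1).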

Any suggestions would be much appreciated.

ntustison commented 1 year ago

Why would you load "mriSuperResolution"?

CarloAmodeo21 commented 1 year ago

Because it is the only pretrained network I could find in the antspynet repo, and I wanted to fine-tune it further on my data. However, the problem I am facing is related to the loss function, as the error message makes fairly clear. When I use a different loss function, e.g. loss_total = tf.keras.losses.MeanSquaredError(), training runs smoothly. I'll keep trying to debug the issue with the original loss. Again, if you happen to have any suggestions, that would be awesome. Thanks
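Swapping to mean squared error avoids the crash because MSE has no per-channel constants. It compares prediction and target element by element and averages, so it accepts any pair of matching shapes, including [2,128,96,32]. The minimal NumPy stand-in below illustrates this. It is a sketch, not the Keras implementation, although for equal-shaped inputs it gives the same scalar as tf.keras.losses.MeanSquaredError with its default reduction.

```python
import numpy as np

def mse(y_true, y_pred):
    # Squared difference averaged over every element. There is no
    # per-channel statistic, so the last axis can have any length.
    return float(np.mean((y_true - y_pred) ** 2))

# Same shape as the high-resolution batch in the error message.
y_true = np.zeros((2, 128, 96, 32))
y_pred = np.full((2, 128, 96, 32), 0.5)
print(mse(y_true, y_pred))  # 0.25
```

The trade-off is that a plain pixel-wise loss drops the perceptual (VGG-feature) term that the combined loss presumably adds.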

ntustison commented 1 year ago

Thanks. It's just puzzling to me that, in a thread about super-resolution of 2-D histology RGB images, you would try to load weights trained on 3-D scalar MR images.

I don't have any suggestions other than what's in the training script. In fact, this repo is more of a personal storage place for my own training scripts as they don't really fit within ANTsXNet proper. And especially with any custom loss functions I use, you're pretty much on your own for figuring out how to apply them to your data.