iyerkrithika21 / mesh2SSM_2023

Implementation of "Mesh2ssm: From surface meshes to statistical shape models of anatomy".
MIT License

Hyperparameters to reproduce results of pancreas dataset #3

Closed: NafieAmrani closed this issue 4 months ago.

NafieAmrani commented 7 months ago

Hi,

Thank you for sharing the code base of your interesting work!

I am trying to reproduce the results on the pancreas dataset. I have followed the steps you shared on this issue to create a good starting template. However, when I run the method with the hyperparameters below, the output of the SP-VAE is very similar to the template, with only a few variations.

python train_geodesic.py --exp_name pancreas_run --dataset pancreas_mesh2ssm --batch_size 10 --epochs 1000 --lr 0.01 --vae_lr 0.0009 --emb_dims 128 --nf 128 --data_directory /data/pancreas_mesh2ssm/ --dropout=0.5 --latent_dim=128

Here are some results:

| | Input Shape | Reconstruction of M-AE | Output of IM-NET | Output of SP-VAE |
|---|---|---|---|---|
| Pancreas 421 | [image] | [image] | [image] | [image] |

The learned template looks as follows:

| Starting template | Template_400 | Template_1200 | Final template |
|---|---|---|---|
| [image] | [image] | [image] | [image] |

Here are two samples generated using the SP-VAE:

| Sample | View 1 (side) | View 2 (bottom) |
|---|---|---|
| 1 | [image] | [image] |
| 2 | [image] | [image] |

I have reviewed the hyperparameters I used, but I couldn't figure out where I might have made a mistake that causes the SP-VAE to produce such similar outputs. My hypothesis is that the IM-NET is not doing what it is supposed to do with the current parameters.

Any help or feedback from the authors would be much appreciated! Thanks in advance for the help!

Kind regards, Nafie

iyerkrithika21 commented 7 months ago

Hello, thanks for reaching out.

A few questions and suggestions:

1. Are you using the geodesic neighbors for the mesh AE? You can modify the code as follows to speed up the calculation when generating the pickle file:

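import torch
from torch_geometric.utils import geodesic_distance

def geodescis(pos, face, k, max_gdist):
    # pos: (N, 3) vertex coordinates; face: (F, 3) triangle vertex indices
    pos = torch.as_tensor(pos, dtype=torch.float)
    face = torch.as_tensor(face, dtype=torch.long)
    # Negate the geodesic distances so topk returns the k closest vertices;
    # num_workers=-1 uses all available CPU cores
    dist = -1 * geodesic_distance(pos, face.t(), norm=False, num_workers=-1, max_distance=max_gdist)
    idx = dist.topk(k=k, dim=-1)[1]  # (N, k) neighbor indices per vertex
    return idx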


2. I would suggest training the mesh branch first, without the SP-VAE, to make sure everything works before including the VAE in end-to-end training. I have noticed that the VAE is very tricky to train, and small bugs can cause a lot of instability. Moreover, for reconstruction and performance analysis, it is better to use the mesh branch correspondences.

3. Make sure that `idx` is passed correctly when you generate the mesh branch correspondences. I noticed a small bug in my code [here](https://github.com/iyerkrithika21/mesh2SSM_2023/blob/adb403c8ba2f299930d7e89c198e645796081475/train_geodesic.py#L179): the model is not passed `idx` along with the vertices.
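To see what the `idx` array holds, or to sanity-check a cached neighbor file without installing `gdist`, here is a dependency-light sketch. It is an approximation, not the method used in this repository: it swaps the exact surface geodesic computed by `geodesic_distance` for the shortest path along mesh edges (Dijkstra's algorithm), which can only overestimate the true geodesic. The function name `edge_graph_knn` is hypothetical. Because Dijkstra settles vertices in order of increasing distance, the first k settled vertices are the k nearest, and the first entry of each row is the vertex itself.

```python
import heapq
import numpy as np


def edge_graph_knn(pos, faces, k):
    """Hypothetical helper: k nearest vertices along mesh edges, nearest first.

    Approximation only: edge-path length >= true surface geodesic distance.
    Returns an (N, k) int64 array whose first column is the vertex itself.
    """
    pos = np.asarray(pos, dtype=float)
    n = len(pos)
    # Build a weighted adjacency list from the triangle edges
    adj = [dict() for _ in range(n)]
    for tri in faces:
        a, b, c = (int(t) for t in tri)
        for u, v in ((a, b), (b, c), (c, a)):
            w = float(np.linalg.norm(pos[u] - pos[v]))
            adj[u][v] = w
            adj[v][u] = w

    idx = np.empty((n, k), dtype=np.int64)
    for s in range(n):
        best = {s: 0.0}
        heap = [(0.0, s)]
        settled, done = [], set()
        while heap and len(settled) < k:
            d, u = heapq.heappop(heap)
            if u in done:
                continue  # stale heap entry
            done.add(u)
            settled.append(u)  # vertices are settled in distance order
            for v, w in adj[u].items():
                nd = d + w
                if nd < best.get(v, float("inf")):
                    best[v] = nd
                    heapq.heappush(heap, (nd, v))
        if len(settled) < k:
            raise ValueError(f"vertex {s} reaches fewer than k={k} vertices")
        idx[s] = settled
    return idx


# Unit square split into two triangles along the 0-2 diagonal
pos = [[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]]
faces = [[0, 1, 2], [0, 2, 3]]
idx = edge_graph_knn(pos, faces, k=3)
print(idx)
```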
iyerkrithika21 commented 7 months ago

Could you also share your other hyperparameters, such as k, mse_weight, and the number of particles in your template point cloud?

iyerkrithika21 commented 7 months ago
exp_name='new_pancreas_learned_template_256', batch_size=10, test_batch_size=10, epochs=1000, use_sgd=False, lr=0.01, vae_lr=0.001, momentum=0.9, no_cuda=False, seed=42, eval=False, dropout=0.5, emb_dims=64, nf=16, k=15, data_directory='pancreas/', model_type='autoencoder', mse_weight=0.01, template='mean_256_template', extention='.ply', gpuid=0, vae_mse_weight=10, latent_dim=64

First template update at 400, with an update frequency of 200. You could try these settings as well to see if they help.
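Read literally, "first template update at 400, update frequency 200" gives the schedule below. This is a sketch under assumptions: the helper name `is_template_update` is hypothetical, and the thread does not say whether the counter is epochs or iterations or whether counting starts at 0. Under those assumptions, a 1000-epoch run would refresh the template three times.

```python
def is_template_update(step, first=400, every=200):
    """Hypothetical helper: True if the template is refreshed at `step`."""
    # Refresh once at `first`, then every `every` steps after it
    return step >= first and (step - first) % every == 0


# With 1000 training steps, numbered 0..999
updates = [s for s in range(1000) if is_template_update(s)]
print(updates)  # [400, 600, 800]
```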

NafieAmrani commented 6 months ago

Hi,

Thank you for your answer, and sorry for my late response.

> Are you using the geodesic neighbors for the mesh AE? You can modify the code to make the calculation faster for generating the pickle file

Yes, I used the geodesic neighbors and computed them with the provided code. I didn't change anything in that step.

> I would suggest trying to train the mesh branch first without the SP-VAE to make sure everything is okay before including it in the end to end training. I have noticed that the VAE is very tricky to train and small bugs can cause a lot of instability. Moreover, for reconstruction and performance analysis, it is better to use the mesh branch correspondences.

I tried this, but unfortunately I couldn't find the sweet spot that prevents the VAE from collapsing.

> Make sure that when you are generating the mesh branch correspondences, the idx is passed correctly. I noticed a small bug in my code here as the model is not passed idx with the vertices.

I have fixed this; thanks for checking the code.

> Could you also share your other hyperparameters like: k, mse_weight and the number of particles in your template point cloud?

I used k=10 and mse_weight=0.01, and the template has 256 particles. Here are all the hyperparameters I used to get the results in my original question:

python train_geodesic.py --exp_name pancreas_run --dataset pancreas_mesh2ssm --batch_size 10 --test_batch_size 10 --epochs 1000 --use_sgd false  --momentum 0.9  --dropout 0.5 --lr 0.01 --vae_lr 0.0009 --emb_dims 128 --nf 128 --k 10 --data_directory /data/pancreas_mesh2ssm/ --model_type 'autoencoder'  --mse_weight 0.01  --template "template" --extention ".ply"  --vae_mse_weight 10 --latent_dim 128
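One detail worth double-checking in the command above is `--use_sgd false`. If the script declares that flag with `type=bool` (an assumption, not verified here; check how `train_geodesic.py` defines it), argparse passes the string through `bool()`, and every non-empty string, including "false", becomes `True`, which would silently switch the optimizer. The sketch shows the pitfall next to an explicit converter; `str2bool` and `--naive_flag` are illustrative names, not part of this repository.

```python
import argparse


def str2bool(value):
    """Map common true/false spellings to a real bool for argparse."""
    if isinstance(value, bool):
        return value
    v = value.strip().lower()
    if v in ("true", "t", "yes", "y", "1"):
        return True
    if v in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    parser = argparse.ArgumentParser()
    # Pitfall: type=bool calls bool("false"), and any non-empty string is True
    parser.add_argument("--naive_flag", type=bool, default=False)
    # Safer: parse the string explicitly
    parser.add_argument("--use_sgd", type=str2bool, default=False)
    return parser


args = build_parser().parse_args(["--naive_flag", "false", "--use_sgd", "false"])
print(args.naive_flag, args.use_sgd)  # True False
```

On Python 3.9+, `argparse.BooleanOptionalAction` is another option; it turns the flag into a `--use_sgd` / `--no-use_sgd` switch pair instead of taking a value.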

I also tried the hyperparameters you shared but got a similar outcome; see the screenshots below:

| | Input Shape | Reconstruction of M-AE | Output of IM-NET | Output of SP-VAE |
|---|---|---|---|---|
| Pancreas 421 | [image] | [image] | [image] | [image] |

Would you by any chance be able to share the weights of the trained network?

Thank you again for your help!

Kind regards, Nafie

iyerkrithika21 commented 6 months ago
> exp_name='new_pancreas_learned_template_256', batch_size=10, test_batch_size=10, epochs=1000, use_sgd=False, lr=0.01, vae_lr=0.001, momentum=0.9, no_cuda=False, seed=42, eval=False, dropout=0.5, emb_dims=64, nf=16, k=15, data_directory='pancreas/', model_type='autoencoder', mse_weight=0.01, template='mean_256_template', extention='.ply', gpuid=0, vae_mse_weight=10, latent_dim=64
>
> First template update at 400 and update frequency 200. You could try these as well to see if it helps

Have you tried these hyperparameters? I am really surprised by your M-AE output; I have never seen it this poor. The M-AE usually learns to reconstruct the mesh vertices fairly early in training.

I will try to find the network weights and share them.

> I tried this but I, unfortunately, couldn't find the sweet spot to avoid the collapse of the VAE.

Yes, this is a drawback of relying on a VAE with a small dataset (~250 samples), and we are working on methods to avoid this problem in the future.
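Regarding the VAE collapse discussed in this thread: a common way to check for posterior collapse is to measure how far each latent dimension's approximate posterior departs from the standard-normal prior, using the closed-form KL divergence 0.5 * (mu^2 + sigma^2 - 1 - log sigma^2). The sketch assumes the SP-VAE encoder produces a per-sample mean `mu` and log-variance `logvar` (not confirmed here); `per_dim_kl`, `active_dims`, and the 0.01 threshold are illustrative choices. Dimensions with KL near zero carry no information about the input, so a collapsed SP-VAE whose samples all resemble the template may show few or no active dimensions.

```python
import numpy as np


def per_dim_kl(mu, logvar):
    """Batch-averaged KL(N(mu, exp(logvar)) || N(0, 1)) for each latent dimension."""
    mu = np.asarray(mu, dtype=float)
    logvar = np.asarray(logvar, dtype=float)
    kl = 0.5 * (mu ** 2 + np.exp(logvar) - 1.0 - logvar)  # shape (batch, latent_dim)
    return kl.mean(axis=0)


def active_dims(mu, logvar, threshold=0.01):
    """Indices of latent dimensions whose KL exceeds a heuristic threshold."""
    return np.flatnonzero(per_dim_kl(mu, logvar) > threshold)


# Toy batch: posterior equals the prior except in the first 3 dimensions
batch, latent_dim = 10, 64
mu = np.zeros((batch, latent_dim))
logvar = np.zeros((batch, latent_dim))
mu[:, :3] = np.linspace(-1.0, 1.0, batch)[:, None]
print(active_dims(mu, logvar))  # [0 1 2]
```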