nmwsharp / diffusion-net

Pytorch implementation of DiffusionNet for fast and robust learning on 3D surfaces like meshes or point clouds.
https://arxiv.org/abs/2012.00888
MIT License

registered meshes as input #18

Open gattia opened 2 years ago

gattia commented 2 years ago

Thanks for creating and sharing DiffusionNet! It's been really easy to work with so far, so thank you.

I've only started playing around with DiffusionNet - I'm hoping to use it in general pipelines to analyze bone mesh data. My use cases are: 1. predicting clinical or demographic outcomes from the bone geometry, or 2. generative modeling. I'm starting with (1) as it "should" be easier.

To give some background - typically we use PCA applied to a pool of data from multiple participants (100s). The rows of the input to PCA are vectors of the vertex positions, vector = [x1, x2, ... xn, y1, y2, ... yn, z1, z2, ... zn], one per participant. After applying PCA, we project a new mesh's vertices onto the PCA latent space to get a set of scalars that we can use for prediction. This approach works reasonably well; for example, we can predict sex with >90% accuracy (tested on a hold-out set).
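For context, that baseline looks roughly like this (a sketch only; the array names and number of components are placeholders, and I'm using scikit-learn here just for illustration):

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

def flatten(verts):
    # verts: (n_subjects, n_vertices, 3), with vertex correspondence across subjects
    # -> one row per subject: [x1..xn, y1..yn, z1..zn]
    return np.concatenate([verts[:, :, 0], verts[:, :, 1], verts[:, :, 2]], axis=1)

X_train = flatten(train_verts)                        # placeholder arrays
pca = PCA(n_components=50).fit(X_train)               # component count is illustrative
clf = LogisticRegression(max_iter=1000).fit(pca.transform(X_train), train_labels)

# Project held-out meshes onto the PCA latent space, then predict (e.g. sex)
pred = clf.predict(pca.transform(flatten(test_verts)))
```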

My hope is to replace the PCA model with DiffusionNet to better learn from the surface mesh. However, in my first attempts to predict sex, DiffusionNet only predicts one outcome (either only male or only female). I've tried 2 methods of using DiffusionNet:

  1. Stock - the same way as used in the Shrec classification, except there are only 2 labels
    • This pretty much produces the same result for each input mesh - not just the binary outcome, but the actual predictions are nearly identical
  2. DiffusionNet as a feature extractor + dense layers: output global_mean as a vector of length 128 to 512 without activation, then feed this vector into a series of progressively smaller dense layers that ultimately predict my output (rough sketch below).
    • This has a similar effect to (1)
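To be concrete, option (2) looks roughly like this (a sketch only; the descriptor size and dense-layer widths are placeholders, and the forward call assumes the operators are precomputed as in the shrec11 example):

```python
import torch
import diffusion_net

# DiffusionNet backbone pooled to one global descriptor per mesh (no final activation)
backbone = diffusion_net.layers.DiffusionNet(
    C_in=3,                     # XYZ features
    C_out=256,                  # descriptor length (placeholder; I tried 128-512)
    C_width=128,
    N_block=4,
    outputs_at='global_mean',
    last_activation=None)

# Progressively smaller dense layers ending in 2 logits (male/female)
head = torch.nn.Sequential(
    torch.nn.Linear(256, 128), torch.nn.ReLU(),
    torch.nn.Linear(128, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 2))

# One mesh per forward pass (batch size 1); mass/L/evals/evecs/gradX/gradY as in shrec11
desc = backbone(features, mass, L=L, evals=evals, evecs=evecs,
                gradX=gradX, gradY=gradY, faces=faces)
logits = head(desc)
```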

The consistently identical results are making me wonder if the structure of my meshes is having a big effect. What I mean by that is: technically, every mesh I am inputting has the exact same face ordering; it is only the vertex x/y/z positions that change. This is because, to make PCA work, we have to have vertex correspondence. So, this set of meshes was created by non-rigidly registering a template mesh to every other mesh - wherever the template mesh's vertices land on the new mesh (triangulated by the 3 nearest vertices) becomes the new x/y/z position for the vertices, but they keep the same face connectivity from the template.

Do you think that this structural consistency between all of the meshes could be why diffusion-net is predicting a single outcome for every example?

Thanks in advance!

Anthony

gattia commented 2 years ago

I've used the original (unregistered) meshes and it's now learning something, which is awesome. At 200 epochs (and a slower LR) it's still oscillating and hasn't reached the accuracy of the PCA-based model, but that could/likely will be fixed with hyper-parameter optimization.

Does this logic make sense - that the same graph structure (faces) for all meshes would reduce the power of the diffusion network to do global classification?

nmwsharp commented 2 years ago

Hi!

This sounds like a really awesome problem, and one that DiffusionNet should be well-suited for, I'm happy to try to help out.

I'm surprised to hear that DiffusionNet isn't working on the original template meshes; having a consistent face/vertex ordering and connectivity shouldn't be an issue---DiffusionNet mostly cannot "see" these things because of how it is built on diffusion.

Is it possible that something else is going on? Are you using a batch size? And HKS or XYZ features (or something else)? Also, is the model producing only one output on the training set, the testing set, or both?

nmwsharp commented 2 years ago

Also, are your classes balanced? (E.g. roughly a 50/50 distribution?) If it's very unbalanced, networks sometimes get trapped in a local minimum where they prefer the majority class. If so, re-weighting the examples to be roughly equal weight can help. Another useful trick is the label-smoothing loss: https://github.com/nmwsharp/diffusion-net/blob/d20210a60570af946a113a0ea7ea00ced71ae727/experiments/classification_shrec11/classification_shrec11.py#L146
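For example, something along these lines with plain PyTorch (the class counts and smoothing value are placeholders; the built-in label_smoothing argument needs a reasonably recent PyTorch, otherwise the label-smoothing loss linked above does the same job):

```python
import torch
import torch.nn.functional as F

# Inverse-frequency class weights from the training-set counts (placeholder values)
class_counts = torch.tensor([n_female, n_male], dtype=torch.float)
class_weights = class_counts.sum() / (2.0 * class_counts)

# logits: (N, 2), labels: (N,) int64
loss = F.cross_entropy(logits, labels,
                       weight=class_weights,
                       label_smoothing=0.1)   # requires PyTorch >= 1.10
```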

gattia commented 2 years ago

Thanks for that feedback, and for your interest/enthusiasm to help out!

I didn't do an extensive hyperparameter optimization, but I did try roughly the same parameters with the registered and non-registered meshes - just a few different learning rates, and HKS as well as XYZ. I'm not using a batch size - I set it to None; this is correct, right? E.g.,

train_loader = DataLoader(train_bone_dataset, batch_size=None, shuffle=True)

I got the impression that DiffusionNet only takes batch sizes of 1.

The predictions change during training, but once it's done an epoch of training and runs through the validation set, it produces the exact same result for all meshes in the validation data. When it gets back to the validation set at the end of the next epoch, it might have a slightly different prediction than the previous epoch, but it's still consistent across all of the examples. Based on your experiments all using consistent hyperparameters and doing well, I naively assumed things should be really "easy" and that I might just be doing something silly - from your comments, I'm guessing a good hyperparameter optimization will go a long way.

The labels are roughly balanced - 60:40 at worst, and even that imbalance is likely due to random splitting. And yes, I was using the smoothed log loss - I used the shrec11.py example as the template for my first runs.

-> I can move the question below to another issue if you think that is better.

A question about the potential for generative modelling: do you think this is feasible with DiffusionNet? I was thinking about something like an autoencoder (or VAE) where DiffusionNet is used as the encoder, maybe followed by a few dense layers that lead to a latent space of a fixed size. Then, the simplest decoder would just use dense layers to re-create the x/y/z positions of the nodes; alternatively, I wondered whether there could be another set of DiffusionNet layers on the decoder side that leads to the final x/y/z positions of the nodes.
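Concretely, something like this is what I have in mind (a very rough sketch; the latent size, layer widths, and variable names are placeholders, and it assumes the registered template so n_verts is fixed):

```python
import torch
import diffusion_net

n_verts = 5000        # vertex count of the registered template (placeholder)
latent_dim = 64       # placeholder

# Encoder: DiffusionNet pooled to a global vector, then a dense bottleneck
encoder = diffusion_net.layers.DiffusionNet(
    C_in=3, C_out=128, C_width=128, N_block=4, outputs_at='global_mean')
to_latent = torch.nn.Linear(128, latent_dim)

# Simplest decoder: dense layers that regress all x/y/z positions at once
decoder = torch.nn.Sequential(
    torch.nn.Linear(latent_dim, 512), torch.nn.ReLU(),
    torch.nn.Linear(512, n_verts * 3))

z = to_latent(encoder(features, mass, L=L, evals=evals, evecs=evecs,
                      gradX=gradX, gradY=gradY, faces=faces))
recon = decoder(z).reshape(n_verts, 3)
loss = torch.nn.functional.mse_loss(recon, verts)   # verts: (n_verts, 3) target positions
```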

The comment at the end of the manuscript:

"Our networks are intentionally agnostic to local discretization, and thus may not be suited for tasks where one learns some property of the local discrete structure, such as denoising or mesh modification."

made me think DiffusionNet might be ill-suited to this particular autoencoder-style problem. Curious to hear your thoughts.

Thanks!

nmwsharp commented 2 years ago

hyperparameters

The settings you're using sound reasonable to me. Indeed, this code is configured for a batch size of 1 (so None in the dataloader). The main hyperparameter I tweaked during experiments was simply the network size, from a 32-width network for small problems to a 256-width network for large ones. And for the most part, the difference is just polishing off the last few percentage points of error, like "92% accuracy" to "94% accuracy", etc.

The predictions are changing during training time, but then...

I'm really surprised to hear about this behavior, because it's unlike anything I've experienced with this code. Is it possible that something funky is happening with dataloaders or caching that is causing the network to be applied to the same input data each time on the validation set? I'm happy to check out some code here if there's anything that's easy to share. This sounds like a problem that DiffusionNet should totally work on, so I'm interested to understand what is going on.
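One quick, throwaway sanity check would be to confirm that the inputs really differ across the validation loader, something like (assuming the vertex positions come first in whatever your dataset returns):

```python
# Throwaway check: are the validation inputs actually distinct meshes?
seen = set()
for batch in val_loader:
    verts = batch[0]   # assumption: vertex positions are the first item per sample
    seen.add((tuple(verts.shape), round(float(verts.sum()), 6)))
print(f"{len(seen)} distinct inputs out of {len(val_loader)} validation meshes")
```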

Generative modeling

This is also super interesting! I have pondered these problems a lot, but not actually done many experiments :) What you suggest with the dense layers would certainly be a very reasonable starting point. Using DiffusionNet blocks as a decoder is a bit trickier, because you have to pick a geometric domain upon which to evaluate diffusion etc. Perhaps you could just pick some template or base domain if your shapes are all variations on a base shape... I'm curious to see how this would work out.

The discretization-agnostic properties are great for an encoder, but on the decoder side there are indeed still some missing pieces for building totally general discretization-agnostic decoders. I'd say this is very much an open problem. But for cases where there is a reasonable "base" shape, perhaps you can use DiffusionNet block decoders over the base domain and everything works just fine - I'm not sure!

gattia commented 2 years ago

Thanks for the feedback! I've done a fair amount of HP optimization and some playing around and I have a few things to report re: predicting sex on bones.

Take-home

Normalization is HUGE! I did tons of HP testing/optimization with all possible combinations (without normalizing). The max accuracy I could get for predicting sex was ~80% (logistic regression on PCA latent representations is ~90%). And I don't even think that was reproducible - I think it was a "best run". Then I added normalization and it crushed it: ~95% accuracy (this is a small test set, so it might change on a bigger one, but I still think this is a good sign).

Normalization also made HKS work - without normalization, HKS was stagnant and didn't learn anything. It might improve to better than chance for 1-2 epochs at the start, then flatline (even if trained for 1k epochs). With normalization it does do some learning. However, HKS seems pickier than XYZ as input, in my case. For XYZ, LRs of 1e-3 and 1e-4 both obtain test accuracy > 85%. HKS can get ~70% accuracy with 1e-4, and pretty much doesn't learn at 1e-3 or 1e-5. Also, even at 70% it overfits much more than XYZ (train accuracy is around 95% for a test accuracy of 70%). This might be because I did most of my HP optimization (n_blocks, width, dropout, etc.) using XYZ.
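(For reference, I'm building the two input feature types roughly the same way as the shrec11 script; verts/evals/evecs come from the precomputed operators:)

```python
import diffusion_net

# Roughly as in experiments/classification_shrec11/classification_shrec11.py
if input_features == 'xyz':
    features = verts                                                             # C_in = 3
elif input_features == 'hks':
    features = diffusion_net.geometry.compute_hks_autoscale(evals, evecs, 16)    # C_in = 16
```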

So far, the best results are from parameters:

I haven't re-run extensive HP optimization on n_blocks, C_width, label_smooth, or features since starting to use normalization. So, that might eke out a bit more performance, but given the small test set I am using right now, I wouldn't lean too much on any improvement from here.

Using DiffusionNet blocks as a decoder is a bit trickier,

I'm going to try the generative modeling next. I'll report back when I have some sense of how it works. For my usual workflow I register a template to all other meshes anyway, so using the topology of that template mesh was my first thought. I'll try this as an option as I go along.

Note for others

One thing I found different from many other models is that DiffusionNet really does take some time to "warm up". No matter the case, it doesn't look like much, if anything, is being learned for ~50 epochs (my dataset is ~220 examples, so roughly 10k iterations). This might indicate that fine-tuning the learning rate, schedule, etc. could help eke out every grain of benefit.
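(For what it's worth, I'm just using a step decay along the lines of the shrec11 script; something equivalent with a built-in scheduler would be, with placeholder values:)

```python
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Step decay, equivalent in spirit to the manual "lr *= decay_rate every decay_every epochs"
# loop in the shrec11 script (step_size/gamma here are placeholders)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

for epoch in range(n_epochs):
    train_one_epoch()   # placeholder for the usual train/validation loop
    scheduler.step()
```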

nmwsharp commented 2 years ago

Awesome, thanks for sharing these results!

Just to double-check: when you say normalizing, do you mean centering/unit-scaling the shapes before processing? If so, what you say makes sense; it has been my experience as well. It can make a huge difference! In fact, I have encountered cases (with DiffusionNet as well as other methods) where even slight differences in how normalization is performed (mean-shift translation vs. bounding-box translation) can make a big difference. It would be great to come up with representations that are more robust to scaling, etc.
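To spell out what I mean, per-shape normalization is something like the sketch below, and the choice of center and scale is exactly where the subtle differences creep in (I believe the geometry utilities in this repo do the mean-center / max-radius variant):

```python
import torch

def normalize_shape(verts, center='mean'):
    # verts: (V, 3) vertex positions for one shape
    if center == 'mean':
        verts = verts - verts.mean(dim=0, keepdim=True)    # mean-shift translation
    else:
        bb_min = verts.min(dim=0).values
        bb_max = verts.max(dim=0).values
        verts = verts - 0.5 * (bb_min + bb_max)             # bounding-box-center translation
    scale = verts.norm(dim=1).max()                         # unit max-radius; area-based scaling is another option
    return verts / scale
```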

time to "warmup"

Interesting! Thanks for sharing. One possible explanation here is that it is taking a long time to gradually push the diffusion times to the right range. Hard to say though.

kitchell commented 1 year ago
(screenshot attached)

Hi, I wanted to share that I am seeing the same 'warm up' issue as well. I'm trying to use DiffusionNet with 3D point clouds of the brain surface. Initially I'm just doing a basic sex prediction (M vs. F) to get a sense of the parameters to use, and it seems to sit and not do much for over 200 epochs before finally learning.

The point clouds have ~40k vertices (downsampled from over 120k) and I have ~900 samples for training, fairly evenly weighted.

I also found that an LR of 1e-3 did not work at all, while 1e-4 does work, and a decay rate of 0.8 works better than 0.5. 300 eigenvalues works better than 128, but with 600 eigenvalues nothing was learned at all.
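(For reference, the eigenvalue count here is the k_eig argument when precomputing the operators, roughly like this as I understand the geometry utilities; adjust if your setup differs:)

```python
import diffusion_net

# Precompute operators with 300 eigenpairs; for point clouds I pass empty/None faces
frames, mass, L, evals, evecs, gradX, gradY = diffusion_net.geometry.get_operators(
    verts, faces, k_eig=300, op_cache_dir='op_cache')
```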

gattia commented 1 year ago

Thanks for sharing! @kitchell Did you find XYZ or HKS to make much of a difference for you?