huBioinfo / CytoCommunity

A spatial omics data analysis tool that enables both unsupervised and supervised discovery of complex tissue cellular neighborhoods from cell phenotypes.
MIT License
20 stars, 10 forks

Potential bugfix: Concatenation error in supervised learning file #10

Open: edridgedsouza opened this issue 3 months ago

edridgedsouza commented 3 months ago

I got the following error when running the Step2 script in the supervised learning set of scripts: `the array at index 0 has size 4 and the array at index 1 has size 5`. My data has 16 images with 2 graph node categories. The hyperparameters were as follows:

```python
## Hyperparameters
Num_TCN = 15
Num_Times = 20
Num_Folds = 8
Num_Epoch = 100
Embedding_Dimension = 512
LearningRate = 0.0001
MiniBatchSize = 8
beta = 0.9
```

When running, I received the error, which I traced to the `test()` function. It initializes `pr_Table` as `np.zeros([1,4])`, but when `pr_Table` is concatenated with `pred_info`, `pred_info` has 5 elements instead of 4. I searched every place where `pr_Table` is used after this point in the script, and as far as I can tell it is only serialized to a file and never used again. I fixed the issue by changing the initialization to `np.zeros([1,5])`, and the script now seems to run without issue. Is this bugfix valid, or does it cause some downstream effect that I have not foreseen?
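The mismatch can be reproduced in isolation with NumPy. This is a minimal sketch, not code from the repository: `pred_info` is a stand-in row of zeros (only its shape matters), and it assumes rows are appended with `np.concatenate` along axis 0, since the issue mentions concatenation but not the exact call. Concatenating a `(1, 4)` array with a `(1, 5)` array along axis 0 raises a `ValueError` whose message includes the text reported above.

```python
import numpy as np

# Table pre-allocated with 4 columns, as in the original test() function.
pr_Table = np.zeros([1, 4])
# Stand-in for one prediction row with 5 entries (values are placeholders).
pred_info = np.zeros([1, 5])

try:
    pr_Table = np.concatenate((pr_Table, pred_info))  # axis=0 by default
except ValueError as err:
    # Column counts differ (4 vs 5), so stacking the rows fails.
    print("concatenate failed:", err)

# The reporter's fix: allocate 5 columns so the row widths agree.
pr_Table = np.zeros([1, 5])
pr_Table = np.concatenate((pr_Table, pred_info))
print(pr_Table.shape)  # (2, 5)
```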

edridgedsouza commented 3 months ago

After experimenting further, it seems that the size of this array should actually be dynamic, because the width of the model output depends on the `out_channels` parameter passed when you create your `Net`. Also worth noting: `out_channels` is set from `dataset.num_classes`, which is inherited from the `InMemoryDataset.num_classes` property. So instead of `[1,5]`, perhaps it should be changed to `[1, data.num_classes + 2]`.
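A dynamically sized table could look roughly like the sketch below. It assumes, following the comment above, that each prediction row has `num_classes + 2` entries; the helpers `init_pr_table` and `append_prediction` are hypothetical names, not functions from CytoCommunity, and in the actual script the class count would presumably come from `dataset.num_classes`.

```python
import numpy as np


def init_pr_table(num_classes: int) -> np.ndarray:
    """Hypothetical helper: allocate a one-row table whose width tracks
    the model output (assumed to be num_classes + 2 columns)."""
    return np.zeros([1, num_classes + 2])


def append_prediction(pr_table: np.ndarray, pred_info) -> np.ndarray:
    """Hypothetical helper: append one prediction row.

    np.concatenate still raises ValueError if the widths disagree,
    so a wrong num_classes fails loudly instead of silently.
    """
    row = np.asarray(pred_info, dtype=float).reshape(1, -1)
    return np.concatenate((pr_table, row))


# 3 outputs, as happens when the labels are 1 and 2 (max label + 1 = 3).
table = init_pr_table(num_classes=3)
table = append_prediction(table, [0.0, 0.1, 0.2, 0.7, 2.0])  # placeholder values
print(table.shape)  # (2, 5)
```

With two classes the same helper yields the original 4-column layout, so the width follows the model instead of being hard-coded.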

Also worth noting: by default, `num_classes` is computed as the maximum label value plus 1. In other words, it makes a real difference whether you use a zero-indexed or one-indexed numbering system when you create your GraphLabel files! I had been getting suspiciously low accuracy, only to find that it was an unintended consequence of labeling my conditions 1 and 2: that gave the `Net` 3 outputs instead of the 2 it would have had with labels 0 and 1. Perhaps a short note in the README could clarify this point.
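The arithmetic behind this pitfall, and one way to normalize labels before writing them out, can be sketched as follows. `default_num_classes` is a hypothetical helper that only mirrors the "maximum label value plus 1" rule described above for integer labels; the GraphLabel file format itself is not shown here, so this illustrates the label values only, not how the files are read.

```python
import numpy as np


def default_num_classes(labels) -> int:
    """Hypothetical helper mirroring the default rule: max label + 1."""
    return int(np.max(labels)) + 1


# Conditions labeled 1 and 2 (one-indexed).
one_indexed = np.array([1, 2, 2, 1])
print(default_num_classes(one_indexed))  # 3: class 0 is never a target

# Remap any set of integer labels to 0..k-1 via np.unique's inverse indices.
_, zero_indexed = np.unique(one_indexed, return_inverse=True)
print(zero_indexed)  # [0 1 1 0]
print(default_num_classes(zero_indexed))  # 2
```

Using `np.unique` rather than subtracting 1 also handles non-contiguous labels (for example 1 and 3), which would otherwise still leave an unused output.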