KrishnaswamyLab / AAnet

Archetypal Analysis network (AAnet)
https://arxiv.org/abs/1901.09078

predicting archetype membership of new points? #6

Open orbidder opened 3 years ago

orbidder commented 3 years ago

Thanks for publishing your implementation; I've really enjoyed your work. One addition to the tutorial notebook that I'd love to see is applying the trained model to new data to predict each point's membership in each of the identified archetypes. I'm not sure how to do this, and it would be really helpful to see it demonstrated. My current idea is to use the get distance function you've included to measure the distance between a new data point and each archetype, then express membership as a ratio: the numerator is the distance between the new point and a given archetype, and the denominator is the sum of the distances between the new point and every archetype. That way, each point's memberships across all archetypes sum to 1. Is this the correct approach?
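A small sketch of how the proposed ratio behaves, using made-up 2-D archetypes and plain NumPy rather than AAnet's own distance helper. Dividing each distance by the sum of distances does sum to 1, but the nearest archetype then receives the smallest share, so the value reads as dissimilarity rather than membership. Normalizing inverse distances, shown second, is just one possible workaround and is an assumption of this example, not something the library prescribes.

```python
import numpy as np

# Hypothetical archetypes (3 archetypes in 2-D) and one new point.
# All values are illustrative only.
archetypes = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 3.0]])
new_point = np.array([0.5, 0.5])

# Euclidean distance from the new point to each archetype.
dists = np.linalg.norm(archetypes - new_point, axis=1)

# Ratio as proposed: distance / sum of distances.
# Sums to 1, but the closest archetype gets the smallest value.
dist_share = dists / dists.sum()

# One alternative (an assumption of this sketch): normalize inverse
# distances so that a closer archetype gets a larger share.
eps = 1e-12  # guards against division by zero if the point sits on an archetype
inv = 1.0 / (dists + eps)
inv_share = inv / inv.sum()

print("distance share:        ", dist_share)
print("inverse-distance share:", inv_share)
```

Either way, this is a distance-based heuristic; the coordinates learned by the model itself may be a more natural measure.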

EDIT: I took a look at the AAnet.py script itself and, in fairness to you, your annotations were enough for me to figure out a prediction method, which I'll paste below. Is this the correct procedure? If so, perhaps it could be added to the end of the tutorial, along with a little more explanation?

# get testing data and apply model to predict membership
data_test = mnist.test.images
data_test = (data_test * 2) - 1 # norm for tanh

# get only digit 4, single digit
digit = 4
idx_digit = mnist.test.labels == digit
data_test = data_test[idx_digit]

# predict membership
new_archetypal_coords = model.data2at(data_test)

# get the index of best membership for each point
labels = np.argmax(new_archetypal_coords, axis=1)

# plot points in PCA space, colored by archetype
model.plot_pca_ats_data(data_test, c=labels)

dburkhardt commented 3 years ago

Hi @orbidder! Thanks for your interest in AAnet. I'm a little confused by your question. What's the goal here? I don't think it makes sense to "predict the membership of each point"; the goal of archetypal analysis is to have a continuously varying measure of how similar a point is to each archetype. By assigning each point to its closest archetype, I think you're going back to a discretized framework that's somewhat the opposite of what archetypal analysis is about.

orbidder commented 3 years ago

You're right, and I intend to keep the degree-of-membership information. I took the argmax only for plotting. My goal was to use the fitted model to predict membership in each of the k archetypes, and I think data2at does that, correct?
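To make the distinction between hard labels and continuous membership concrete, here is a minimal sketch. It does not call AAnet; the array `coords` is a random stand-in for whatever `model.data2at` returns, assumed here to hold one row per point and one weight per archetype. Whether the real output is non-negative and sums to 1 per row is an assumption of this example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the output of model.data2at(data_test).
# Dirichlet rows are non-negative and sum to 1 (an assumption about
# the real output, made only for illustration).
n_points, n_archetypes = 5, 3
coords = rng.dirichlet(alpha=np.ones(n_archetypes), size=n_points)

# Hard label per point: convenient for coloring a plot, but it
# discards how strongly the point leans toward that archetype.
labels = np.argmax(coords, axis=1)

# Continuous alternatives that keep the degree of membership:
weight_on_first = coords[:, 0]   # weight on one chosen archetype
max_weight = coords.max(axis=1)  # how "pure" each point is

print(labels)
print(np.round(coords, 3))
```

Passing a continuous vector such as `weight_on_first` as the color argument, once per archetype, would show gradients across the data instead of discrete groups.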

As an aside, to find the number of archetypes, your paper suggests plotting MSE and using the elbow method to find the inflection point. When applied to the MNIST dataset, it's difficult to find an elbow as obvious as the one you illustrate in Figure 5. Is there an alternative method, or an alternative to MSE, that could be used here?
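When the elbow is hard to judge by eye, one generic heuristic is to pick the point on the loss curve that lies farthest from the straight line joining its first and last points. The sketch below implements that with NumPy. It is not taken from the AAnet paper or code; the function name `elbow_by_chord_distance` and the MSE values are made up for illustration.

```python
import numpy as np

def elbow_by_chord_distance(ks, losses):
    """Return the k whose (k, loss) point is farthest from the chord
    joining the first and last points of the curve."""
    ks = np.asarray(ks, dtype=float)
    losses = np.asarray(losses, dtype=float)

    # Rescale both axes to [0, 1] so neither dominates the distance.
    x = (ks - ks.min()) / (ks.max() - ks.min())
    y = (losses - losses.min()) / (losses.max() - losses.min())

    # Direction of the chord from the first to the last point.
    dx, dy = x[-1] - x[0], y[-1] - y[0]

    # Perpendicular distance of every point to that chord.
    dist = np.abs(dy * (x - x[0]) - dx * (y - y[0])) / np.hypot(dx, dy)
    return int(ks[np.argmax(dist)])

# Made-up reconstruction errors for k = 2..8 archetypes.
ks = [2, 3, 4, 5, 6, 7, 8]
mse = [0.120, 0.080, 0.050, 0.046, 0.043, 0.041, 0.040]

best_k = elbow_by_chord_distance(ks, mse)
print(best_k)
```

On a curve with no real bend, this still returns some k, so the result is a hint to inspect rather than a decision rule.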

dburkhardt commented 3 years ago

Just to be clear, are you trying to run on all of MNIST? I would only apply it to a single digit. I also think that it might be difficult to characterize the "correct" number of archetypes for MNIST. Image data flattened into a vector represents a very strange space, and whether there are 4 vs. 5 archetypes is difficult to determine. The elbow method was mostly a suggestion and may not work for all datasets.

orbidder commented 3 years ago

Thanks for your reply. I'm not trying to run on all MNIST digits, just the digit "4" as you do in your tutorial (see code above). In the paper, Fig 5 panel c) shows some clear elbows when applied to MNIST. The green line in the elbow plot is for the digit "4", right? I wasn't able to find such clear elbows when I ran AAnet myself. I think I know why, though. I was plotting raw MSE for each value of k, whereas I see now that you plot the Relative MSE loss. That could be it, depending on how you calculated the 'Relative MSE loss'. Did you just min/max transform the MSE values across the k runs? Is that how you scaled them between 0 and 1? Also, even though a lot of the examples you show for the digit "4" have 3 archetypes, looking at Figure 5 the optimum k was 4, correct? Sorry to get into the minutiae of a paper you wrote so long ago; it's just relevant to what I'm currently doing.
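For reference, here are two plausible ways to turn raw MSE values into a relative loss. Which one, if either, the paper used is not confirmed in this thread, and the MSE values below are made up. Both are positive rescalings of the same curve, so they change the axis and make curves from different digits comparable, but they do not move the elbow of a single curve.

```python
import numpy as np

# Made-up MSE values, one per number of archetypes k = 2..8.
mse = np.array([0.120, 0.080, 0.050, 0.046, 0.043, 0.041, 0.040])

# Option 1: min-max scaling. The largest value maps to 1, the smallest to 0.
minmax = (mse - mse.min()) / (mse.max() - mse.min())

# Option 2: divide by the largest value. The maximum maps to 1,
# while the minimum stays above 0.
rel_to_max = mse / mse.max()

# The sharpest bend (largest second difference) falls at the same
# index for the raw and both rescaled curves.
bend_raw = int(np.argmax(np.diff(mse, n=2)))
bend_minmax = int(np.argmax(np.diff(minmax, n=2)))
bend_rel = int(np.argmax(np.diff(rel_to_max, n=2)))

print(np.round(minmax, 3))
print(np.round(rel_to_max, 3))
print(bend_raw, bend_minmax, bend_rel)
```

If rescaling alone does not reveal an elbow, the difference from the paper's figure is more likely to come from training variability or settings than from the normalization.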