ricomnl opened this issue 1 year ago
Hi Rico,
firstly, thank you so much for commenting and looking over the preprint! It is really nice to have someone else look into this independently and compare notes :)
You are right, the code in the repo didn't reflect what I was doing while plotting (I plotted the figures in R by importing the raw rankings). While updating the code with one other change, I also fixed the scaled rankings. See the code change here and the corresponding changes in the notebook.
Relatedly, I changed the way I look at those rankings, and the better interpretation is actually that once you provide all genes, Geneformer is quite good at reconstructing the ranking. This however does not translate to good performance with cell embeddings. We will update the preprint shortly. This is how this figure will look in the current evaluation setup.
Now, coming back to your point: two things can be true. The model can be more uncertain about its calls and still be relatively more correct in those cases. In a way, it should be easier to output the position of the highly expressed gene as it is probably observed in data more frequently.
Another way to think about those rankings, and the errors in them, is to ask how often genes are missing from the output ranking completely. We see that highly expressed genes (based on their mean input ranking position) are missing from the output less frequently than lowly expressed ones. This is expected, as less expressed genes are less represented in the data (i.e. they do not appear as frequently as highly expressed genes).
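A minimal sketch of this "missing from the output" check, assuming `in_rankings` and `out_rankings` are (n_cells, seq_len) integer arrays of token IDs where 0 is padding; the function name is hypothetical and not taken from the repo.

```python
import numpy as np

def missing_rate_by_input_rank(in_rankings, out_rankings, pad_token=0):
    """For each gene seen in the input, return its mean 1-based input position
    and the fraction of those cells in which it is absent from the output."""
    stats = {}  # token -> (list of input positions, list of "missing from output" flags)
    for inp, out in zip(in_rankings, out_rankings):
        out_tokens = set(out[out != pad_token].tolist())
        for pos, tok in enumerate(inp.tolist(), start=1):
            if tok == pad_token:
                continue
            positions, missing = stats.setdefault(tok, ([], []))
            positions.append(pos)
            missing.append(tok not in out_tokens)
    tokens = sorted(stats)
    mean_pos = np.array([np.mean(stats[t][0]) for t in tokens])
    miss_rate = np.array([np.mean(stats[t][1]) for t in tokens])
    return np.array(tokens), mean_pos, miss_rate

# Toy example: two cells, 0 = padding
in_rankings = np.array([[5, 7, 9, 0], [5, 9, 7, 0]])
out_rankings = np.array([[5, 7, 7, 0], [5, 9, 9, 0]])
tokens, mean_pos, miss_rate = missing_rate_by_input_rank(in_rankings, out_rankings)
```

Binning genes by `mean_pos` (for example with `np.digitize`) and averaging `miss_rate` within each bin gives a curve of missing frequency against mean input position.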
I think it will be very interesting to look into the biological interpretation of this, for example: which genes does Geneformer most often put in place of the true gene?
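One way to start on that question, sketched under the assumption that the output is aligned position by position with the input (as for a masked-LM head) and that 0 is padding: count (true token, predicted token) pairs wherever they differ. The helper name is hypothetical.

```python
from collections import Counter

import numpy as np

def substitution_counts(in_rankings, out_rankings, pad_token=0):
    """Count (true token, predicted token) pairs at non-padding positions
    where the predicted token differs from the input token."""
    counts = Counter()
    for inp, out in zip(in_rankings, out_rankings):
        keep = inp != pad_token
        for true_tok, pred_tok in zip(inp[keep].tolist(), out[keep].tolist()):
            if true_tok != pred_tok:
                counts[(true_tok, pred_tok)] += 1
    return counts

# Toy example; counts.most_common(k) lists the most frequent substitutions
counts = substitution_counts(
    np.array([[5, 7, 9, 0], [5, 9, 7, 0]]),
    np.array([[5, 7, 7, 0], [5, 9, 9, 0]]),
)
```

Mapping the token IDs back to gene names would then let you check whether the substituted genes share expression programs with the true ones.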
Also, would you mind giving me a bit more background on how you set up your experiments? Did you pass full rankings or did you mask some values (if so are you evaluating the performance on masked values only)?
And in "Here I'm plotting the sliding window accuracy and f1 score (50 genes at a time) from the top of the list to the bottom. It also suggests that the model does better for the highly expressed genes." do you mean lowly?
Thanks for the prompt response! Regarding your first point, I saw that in the code you now remove all the tokens that are absent in either input or output (code); I think I'd only remove the genes that are absent from the input, to penalize the model for not predicting them, but I tried this and the results don't change much.
Interestingly, in the results in your notebook the predictions are strongly correlated but not "perfect". I independently reproduced the results with slightly adapted code on a PBMC dataset, and the predictions are pretty much perfect (Spearman correlation of 0.98). I also changed how the mean rank is calculated: for a given gene, I compute it only over the cells where the gene is expressed (see my adaptation of your loop below). With that change, this nonzero median rank performs a lot better than before.
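For reference, Spearman's correlation is the Pearson correlation of the tie-averaged ranks (the same quantity `scipy.stats.spearmanr` reports). A NumPy-only sketch, with hypothetical helper names; which entries to flatten and compare (for example the input-present genes only) is a choice left to the analysis.

```python
import numpy as np

def average_ranks(x):
    """1-based ranks of x, with tied values replaced by their average rank."""
    x = np.asarray(x, dtype=float)
    order = np.argsort(x, kind="mergesort")
    ranks = np.empty(len(x))
    ranks[order] = np.arange(1, len(x) + 1)
    # Average the ranks within each group of tied values
    _, inverse = np.unique(x, return_inverse=True)
    sums = np.bincount(inverse, weights=ranks)
    counts = np.bincount(inverse)
    return (sums / counts)[inverse]

def spearman(a, b):
    """Spearman correlation = Pearson correlation of the average ranks."""
    return np.corrcoef(average_ranks(a), average_ranks(b))[0, 1]
```

Applied to the arrays produced by the loop below, one option would be `spearman(in_ranks[in_mask], out_ranks[in_mask])`.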
This however does not translate to good performance with cell embeddings.
I would say this is not unexpected, as works like Sentence-BERT and Detlefsen et al. 2022 have shown: BERT models do poorly when you just average their embeddings.
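"Just averaging" usually means mean pooling the token embeddings over the non-padding positions. A generic NumPy sketch of that technique, assuming (batch, seq_len, dim) embeddings and a 0/1 attention mask; it is not necessarily the exact pooling used in the preprint or in Geneformer's own embedding extraction.

```python
import numpy as np

def masked_mean_pool(token_embeddings, attention_mask):
    """Average token embeddings over positions where attention_mask == 1.

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len).
    """
    mask = attention_mask[..., None].astype(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1, None)  # avoid division by zero for empty rows
    return summed / counts
```

Sentence-BERT's point is that such pooled vectors are only useful for similarity once the encoder has been fine-tuned with an objective that shapes the pooled space; the pooling itself stays this simple.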
In a way, it should be easier to output the position of the highly expressed gene as it is probably observed in data more frequently.
That's what I initially thought too, so the results surprised me. Another view is that the model predicts something only slightly better than each gene's average rank: it would then still do really well on all the lowly ranked genes and poorly on the highly ranked ones (90% of nonzero values in a cell are 1s, and since we divide them by a global scaling factor, their relative order will always be the same).
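A toy illustration of the tied-counts argument. It assumes a Geneformer-style encoding in which each cell's counts are divided by the cell total and by a fixed, corpus-wide per-gene factor, then sorted in descending order; `rank_genes` and the factor values are hypothetical. Genes with a count of 1 can then only be ordered by their fixed factors, so their relative order is the same in every cell.

```python
import numpy as np

def rank_genes(counts, gene_factors):
    """Return expressed gene indices ordered by count / cell_total / gene_factor,
    highest first (hypothetical Geneformer-style rank value encoding)."""
    counts = np.asarray(counts, dtype=float)
    gene_factors = np.asarray(gene_factors, dtype=float)
    norm = counts / counts.sum() / gene_factors
    expressed = np.flatnonzero(counts)
    return expressed[np.argsort(-norm[expressed], kind="stable")]

gene_factors = np.array([0.5, 2.0, 1.0, 4.0])  # made-up per-gene factors
cell_a = rank_genes([1, 1, 1, 1], gene_factors)  # all genes tied at count 1
cell_b = rank_genes([0, 1, 1, 1], gene_factors)  # gene 0 not expressed
# Genes 2, 1, 3 appear in the same relative order in both cells
```

Because the order of the count-1 genes is fixed, a model that has memorized each gene's typical position would already score well on the low end of the ranking.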
Did you pass full rankings or did you mask some values (if so are you evaluating the performance on masked values only)?
I'm passing full rankings with `max_length=2048` and then masking out the predictions with the same attention mask that I fed into the model, so that only predictions for genes, not for padding tokens, are evaluated.
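A minimal sketch of that evaluation step, assuming `logits` of shape (batch, seq_len, vocab) and an attention mask with 1 for real gene tokens and 0 for padding; the function name is hypothetical. With full rankings, `seq_len` would be 2048.

```python
import numpy as np

def masked_token_accuracy(logits, input_ids, attention_mask):
    """Fraction of non-padding positions where the argmax prediction
    equals the input token."""
    pred_ids = logits.argmax(axis=-1)
    keep = attention_mask.astype(bool)  # same mask that was fed into the model
    return float((pred_ids[keep] == input_ids[keep]).mean())
```

Reusing the model's own attention mask guarantees that padding positions, which are trivially predictable, cannot inflate the score.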
And in "Here I'm plotting the sliding window accuracy and f1 score (50 genes at a time) from the top of the list to the bottom. It also suggests that the model does better for the highly expressed genes." do you mean lowly?
yes, good catch
Adapted loop, starting from here:
```python
import numpy as np

# in_rankings / out_rankings: (n_cells, seq_len) arrays of token IDs for the
# input and predicted rankings, with 0 used as the padding token
n_cells = in_rankings.shape[0]
# Tokens that occur in both the input and the output rankings (across all cells)
unique_tokens = np.intersect1d(in_rankings, out_rankings)
unique_tokens = unique_tokens[unique_tokens != 0]  # remove padding
n_tokens = len(unique_tokens)
# Initialize (n_tokens, n_cells) arrays with zeros; 0 means "gene absent in this cell"
in_ranks = np.zeros((n_tokens, n_cells))
out_ranks = np.zeros((n_tokens, n_cells))
for j in range(n_cells):
    for i, token in enumerate(unique_tokens):
        pos_in_in = np.where(in_rankings[j] == token)[0]
        # the question here is whether the token appears multiple times
        # (run a notebook and look at the outputs)
        pos_in_out = np.where(out_rankings[j] == token)[0]
        if pos_in_in.shape[0] > 0:
            in_ranks[i, j] = pos_in_in[0] + 1  # 1-based index
        if pos_in_out.shape[0] > 0:
            # if the token is predicted more than once, use its rounded mean position
            out_ranks[i, j] = np.rint(np.mean(pos_in_out)) + 1  # 1-based index
in_mask = in_ranks != 0
out_mask = out_ranks != 0
# Flip so that the lowest-ranked gene is 0 and higher-ranked genes get larger values;
# note that absent genes (0) become the per-cell maximum here, so they must be
# filtered out downstream with in_mask / out_mask
in_ranks = in_ranks.max(axis=0) - in_ranks
out_ranks = out_ranks.max(axis=0) - out_ranks
# For the per-gene reference rank, don't count cells where the gene is absent
mean_ranks_nan = in_ranks.copy()
mean_ranks_nan[~in_mask] = np.nan
mean_ranks = np.nanmedian(mean_ranks_nan, axis=1)  # nonzero median, despite the name
mean_ranks = mean_ranks[:, None].repeat(in_ranks.shape[1], axis=1)
# Get "percent_rank" equivalent: scale each cell (column) to [0, 1]
in_ranks /= in_ranks.max(axis=0)
out_ranks /= out_ranks.max(axis=0)
mean_ranks /= mean_ranks.max(axis=0)
```
I will take a closer look at your response in the coming days, but just from skimming through it:
regarding your first point, I saw that in the code you now remove all the tokens that are absent in either input or output (code);
I use `torch.logical_or(in_ranks > 0, out_ranks > 0)`, which filters out only the cases where the gene is missing from both the input and the output; i.e. I keep both the predictions of genes not present in the input ranking and the genes missing from the output.
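The difference between the two filters discussed here can be sketched with NumPy's `np.logical_or`, which behaves like the `torch.logical_or` call above; the rank values are made up, with 0 meaning "absent".

```python
import numpy as np

in_ranks = np.array([3, 0, 2, 0])   # made-up values; 0 = gene absent from the input ranking
out_ranks = np.array([3, 1, 0, 0])  # made-up values; 0 = gene absent from the output ranking

# Keep a gene if it appears in either ranking (drops only genes missing from both)
keep_either = np.logical_or(in_ranks > 0, out_ranks > 0)
# Alternative: keep only genes present in the input ranking
keep_input_only = in_ranks > 0
```

The second gene (predicted but not in the input) is kept by `keep_either` and dropped by `keep_input_only`; the third gene (in the input but not predicted) is kept by both.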
ah, of course, good catch!
Quick question about Figure 6c (above). From the paper, it seems that highly expressed genes have rank 1.0 and lowly expressed genes have rank 0.0. This would be in line with the results of scGPT: the model does well at predicting highly expressed genes and poorly at predicting lowly expressed ones. Now, as far as I understand, in Geneformer the highly expressed genes are at the front of the list (source), and in your code you eventually divide the rank by the maximum rank (source). So wouldn't that mean that Geneformer does better at predicting lowly expressed genes and worse at predicting highly expressed ones?
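A toy illustration of the concern, assuming 1-based positions in a ranked list where position 1 holds the most highly expressed gene (the list length is made up). Dividing positions by the maximum maps the top gene to the smallest value; flipping first is one way to make 1.0 correspond to the most highly expressed gene.

```python
import numpy as np

positions = np.arange(1, 6)  # positions 1..5; position 1 = most highly expressed gene

# Dividing by the maximum rank: the most expressed gene gets the smallest value
scaled = positions / positions.max()
# -> [0.2, 0.4, 0.6, 0.8, 1.0]

# Flipping before scaling: the most expressed gene maps to 1.0, the least to 0.0
flipped = (positions.max() - positions) / (positions.max() - 1)
# -> [1.0, 0.75, 0.5, 0.25, 0.0]
```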
Here are also two plots that I created independently that would confirm my assessment:
Here I'm taking the cross entropy with `reduction='none'` and then plotting the mean cross entropy across a batch of cells. It suggests that the model is less confident for higher-ranked genes.
Here I'm plotting the sliding window accuracy and f1 score (50 genes at a time) from the top of the list to the bottom. It also suggests that the model does better for the highly expressed genes.
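Both curves can be sketched in NumPy: per-position cross entropy (what `reduction='none'` returns in PyTorch's `cross_entropy`), averaged over cells, and a 50-position sliding-window accuracy from the top of the ranking down. The function names are hypothetical, padding is ignored for brevity, and the F1 score is omitted.

```python
import numpy as np

def per_position_cross_entropy(logits, targets):
    """Cross entropy at every (cell, position), with no reduction.

    logits: (n_cells, seq_len, vocab) floats; targets: (n_cells, seq_len) integer token IDs.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]

def sliding_window_accuracy(pred_ids, targets, window=50):
    """Accuracy over consecutive windows of positions, pooled across cells."""
    correct = (pred_ids == targets).mean(axis=0)  # per-position accuracy across cells
    return np.convolve(correct, np.ones(window) / window, mode="valid")

# Mean cross entropy per position across a batch of cells:
# ce_curve = per_position_cross_entropy(logits, input_ids).mean(axis=0)
```

Plotting `ce_curve` and the window accuracies against position makes it easy to see which end of the ranking the model handles better.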
Can you confirm that you didn't alter the way Geneformer ranks genes in your assessment?