Bitbol-Lab / ProtMamba-ssm

ProtMamba: a homology-aware but alignment-free protein state space model
https://www.biorxiv.org/content/10.1101/2024.05.24.595730v1
Apache License 2.0

Batch embeddings differ from individual processing #9

Closed · david-arredondo closed this 1 month ago

david-arredondo commented 1 month ago

I am evaluating the use of the last hidden layer's vector at the last token position as an embedding for a given input sequence.

I noticed that if I pass multiple sequences in a batch, I get a different embedding than if I pass them in one at a time.

For example:

tokens = tokenizer(['MEEP', 'MLEP'], concatenate=False)  # tokenize the two sequences separately
pos_ids = [torch.tensor([0, 1, 2, 3, 4]) for i in range(2)]

tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=AA_TO_ID["<pad>"]).to('cuda')
pos_ids = torch.stack(pos_ids).to('cuda')

hidden_layers = model(input_ids=tokens, position_ids=pos_ids, save_layer=[16])[16]  # hidden states of layer 16
embeddings = hidden_layers[:, -1, :]  # last token of each sequence

This returns a different embedding for the first sequence than:

tokens = tokenizer(['MEEP'])
pos_ids = [torch.tensor([0, 1, 2, 3, 4])]

tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=AA_TO_ID["<pad>"]).to('cuda')
pos_ids = torch.stack(pos_ids).to('cuda')

hidden_layers = model(input_ids=tokens, position_ids=pos_ids, save_layer=[16])[16]
embeddings = hidden_layers[:, -1, :]  # last token of each sequence
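
A quick way to quantify the mismatch (a minimal sketch; embeddings_batch and embeddings_single are placeholder names for the embedding tensors produced by the two snippets above):

# Compare the first sequence's embedding from the batched run with the one
# from the single-sequence run. `embeddings_batch` and `embeddings_single`
# are placeholders for the tensors computed by the two snippets above.
print(torch.allclose(embeddings_batch[0], embeddings_single[0]))
print((embeddings_batch[0] - embeddings_single[0]).abs().max())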
CyrilMa commented 1 month ago

Hi David,

Thank you for pointing that out.

I ran some tests, and the propagation of numerical errors is most likely responsible for this behavior. With the model in torch.bfloat16 I was able to reproduce what you saw; with torch.float32 I did not encounter the same discrepancy.

I haven't found any reference to this issue on the Mamba forum, so I'm not completely sure why it happens.

You can try loading the model using:

model = load_model(checkpoint,
                   model_class=MambaLMHeadModelwithPosids,
                   device=device,
                   dtype=torch.float32,
                   checkpoint_mixer=False
                   ).eval()

instead of torch.bfloat16, and see whether that also works better for you.
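
Note that even in float32 you shouldn't expect bit-identical results between batched and single-sequence inference, since the batched kernels may reorder floating-point operations; a tolerance-based comparison is the safer check. A minimal sketch, where emb_batch and emb_single are placeholder names for the embeddings recomputed in float32 with your two snippets above:

# After reloading the model in float32, recompute the embeddings with the
# snippets from the original post, then compare with a small tolerance
# instead of exact equality. `emb_batch` and `emb_single` are placeholders
# for those recomputed tensors.
print(torch.allclose(emb_batch[0], emb_single[0], atol=1e-5, rtol=1e-4))
print((emb_batch[0] - emb_single[0]).abs().max())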