BinWang28 / SBERT-WK-Sentence-Embedding

IEEE/ACM TASLP 2020: SBERT-WK: A Sentence Embedding Method By Dissecting BERT-based Word Models
Apache License 2.0

Vectorize reshape operation #7

Closed JohnGiorgi closed 4 years ago

JohnGiorgi commented 4 years ago

Hi, thanks for making your code available! I am very interested in trying out the pooling method in my own sentence encoding project.

I got a little hung up on this code in sen_emb.py:

https://github.com/BinWang28/SBERT-WK-Sentence-Embedding/blob/ca65d43cbed637f2a7129669110782e9a6f7fe3b/sen_emb.py#L109-L112

It seems this is basically just a reshape operation that swaps the dimension representing the number of layers with the dimension representing the number of sentences. That can be vectorized and replaced with the following one-liner:

all_layer_embedding = torch.stack(features).permute(1, 0, 2, 3).cpu().numpy()

This is easier to read and should be slightly faster, since it replaces a nested Python for loop with vectorized tensor operations. I checked that the two approaches are equivalent:


import numpy as np
import torch

# `features` is the list of per-layer hidden states returned by the model:
# one tensor of shape (num_sentences, num_tokens, hidden_dim) per layer.

# PROPOSED APPROACH
all_layer_embedding_approach_2 = torch.stack(features).permute(1, 0, 2, 3).cpu().numpy()

# CURRENT APPROACH
features = [layer_emb.cpu().numpy() for layer_emb in features]
all_layer_embedding = []
for i in range(features[0].shape[0]):
    all_layer_embedding.append(np.array([layer_emb[i] for layer_emb in features]))

# CHECK THAT THEY ARE EQUAL
assert np.array_equal(all_layer_embedding, all_layer_embedding_approach_2)
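
In case it helps, here is a minimal standalone sketch with dummy tensors showing what the stack + permute does to the dimensions. The layer count and shapes below are illustrative assumptions, not values taken from sen_emb.py:

import torch

# Assume 13 layers (embedding layer + 12 transformer layers), a batch of
# 2 sentences, 8 tokens per sentence, and a hidden size of 768.
features = [torch.randn(2, 8, 768) for _ in range(13)]

# torch.stack gives shape (layers, sentences, tokens, hidden) = (13, 2, 8, 768);
# permute(1, 0, 2, 3) moves the sentence dimension first: (2, 13, 8, 768).
all_layer_embedding = torch.stack(features).permute(1, 0, 2, 3).cpu().numpy()
print(all_layer_embedding.shape)  # (2, 13, 8, 768)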

This PR just implements that change. I also black-formatted sen_emb.py :smile:

BinWang28 commented 4 years ago

Thanks so much for helping improve the code! I have also double-checked it.