sassoftware / python-dlpy

The SAS Deep Learning Python (DLPy) package provides the high-level Python APIs to deep learning methods in SAS Visual Data Mining and Machine Learning. It allows users to build deep learning models using friendly Keras-like APIs.
Apache License 2.0

How to get "BERT embedding" of each document using DLPy? #384

Closed: riow1983 closed this issue 1 year ago

riow1983 commented 1 year ago

I want to confirm how to get a "BERT embedding" for each document using DLPy.

If I understand correctly, a BERT model compiled by DLPy inherits from DLPy's Model class, and that class has a get_features method. get_features returns the outputs of an intermediate layer (https://sassoftware.github.io/python-dlpy/generated/dlpy.model.Model.get_features.html), so it should also be possible to get BERT embeddings with the same method:

```python
import numpy as np
import swat
from dlpy.transformers.bert_model import BERT_Model
from dlpy.transformers.bert_utils import bert_prepare_data

conn = swat.CAS(...)

bert = BERT_Model(conn, ...)
bert.compile(...)
bert.fit(...)  # train the BERT model here

# val points to the CAS table that contains the sentence of each document I want to vectorize.
_, val = bert_prepare_data(conn, ...)
x, y = bert.get_features(data=val, dense_layer='bert_pooling', target='_target_0_',
                         textParms=bert.get_text_parameters())
# I need BERT's last hidden states, so I choose the 'bert_pooling' layer,
# which is the last layer just before the task-specific head.
# Here, x is a (num_documents, num_tokens * num_embedding_dimensions) matrix.
x = x.reshape(-1, 512, 768)  # -1 infers num_documents; 512 is num_tokens (max sequence length), 768 is num_embedding_dimensions.
x = np.mean(x, axis=1)  # mean across tokens, i.e. mean pooling
# Here, x is a (num_documents, num_embedding_dimensions) matrix.
```
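The reshape-and-average step above can be checked on synthetic data without a CAS server. This is a minimal sketch, assuming (as the comments above state) that get_features returns one row per document with the per-token vectors laid out token by token in row-major order; that layout is an assumption, not something confirmed by the DLPy docs. The hypothetical helper `mean_pool_flat_features` passes the document count explicitly so that the role of each reshape dimension is unambiguous.

```python
import numpy as np


def mean_pool_flat_features(x_flat, num_tokens, hidden_size):
    """Reshape flattened per-token features and average over the token axis.

    Assumes x_flat has shape (num_documents, num_tokens * hidden_size),
    stored token by token in row-major order (hypothetical layout).
    """
    num_docs = x_flat.shape[0]
    if x_flat.shape[1] != num_tokens * hidden_size:
        raise ValueError("feature width must equal num_tokens * hidden_size")
    # (num_documents, num_tokens, hidden_size)
    tokens = x_flat.reshape(num_docs, num_tokens, hidden_size)
    # Plain mean pooling: every token, including [PAD], gets equal weight.
    return tokens.mean(axis=1)


# Synthetic stand-in for get_features output: 3 documents, 4 tokens, 2 dims.
rng = np.random.default_rng(0)
per_token = rng.normal(size=(3, 4, 2))
x_flat = per_token.reshape(3, -1)  # shape (3, 8), one flat row per document
doc_emb = mean_pool_flat_features(x_flat, num_tokens=4, hidden_size=2)
print(doc_emb.shape)  # (3, 2)
```

With the real model, num_tokens and hidden_size would be 512 and 768 as in the snippet above.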

I assume x can then be interpreted as the BERT embedding of each document. Please correct me if I'm missing something.

LipingCai commented 1 year ago

I think your assumption is right, but keep in mind that get_features was originally designed for image data, while the BERT model was added later and has not been updated since 2019/2020, so there is no guarantee that everything works as expected.

riow1983 commented 1 year ago

I understand. Thank you for your reply. I'd like to open a feature request for an officially supported method. Above, I simply wrote x = np.mean(x, axis=1) to get a mean-pooled embedding. To be precise, however, attention masks are needed to exclude [PAD] tokens from the mean. In PyTorch, for example, you would do this:

```python
x = (x * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(1, keepdim=True)
```

I hope an equivalent of this, along with a "transformer-friendly" method for getting intermediate embeddings, will be implemented in the near future.
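Because get_features returns NumPy arrays, the PyTorch masked mean-pooling line above can be ported to NumPy directly. This is a sketch under stated assumptions: x has already been reshaped to (num_documents, num_tokens, hidden_size), and an attention mask of shape (num_documents, num_tokens) is available. How to recover that mask from DLPy's prepared tables is not covered in this thread, so building it from per-document token lengths (as done here) is a hypothetical stand-in. The small clip on the token count is an addition not present in the PyTorch line; it avoids division by zero for an all-padding row.

```python
import numpy as np


def masked_mean_pool(x, attention_mask):
    """NumPy version of (x * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True).

    x: (num_documents, num_tokens, hidden_size)
    attention_mask: (num_documents, num_tokens), 1 for real tokens, 0 for [PAD]
    """
    mask = attention_mask.astype(x.dtype)[:, :, None]  # (docs, tokens, 1)
    summed = (x * mask).sum(axis=1)                    # (docs, hidden)
    counts = mask.sum(axis=1)                          # (docs, 1), like keepdim=True
    # Guard against dividing by zero when a row is entirely padding.
    return summed / np.clip(counts, 1e-9, None)


# Hypothetical example: 2 documents padded to 4 tokens, 3 dims each.
x = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)
lengths = np.array([2, 3])  # assumed number of real (non-[PAD]) tokens per document
attention_mask = (np.arange(4)[None, :] < lengths[:, None]).astype(int)
emb = masked_mean_pool(x, attention_mask)
print(emb)
```

Here the first document averages only its first two token vectors and the second only its first three, so the trailing [PAD] positions do not dilute the embedding the way the plain np.mean(x, axis=1) does.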

dxq77dxq commented 1 year ago

Thank you for the suggestion. We're currently switching the underlying engine. Once that is complete, we'll consider which new features to support.