Gaurav-Pande / AES_DL

Automated Essay Scoring using BERT
http://www.gauravpande.in/AES/

Can you please add a script to test the BERT-trained model? #4

Open ankitkr3 opened 3 years ago

ankitkr3 commented 3 years ago

Thanks for this amazing work. Can you please add a script for testing the saved LSTM model with the BERT featurizer? @Gaurav-Pande

ankitkr3 commented 3 years ago

To clarify with an example: I have saved the model set_count_1.h5 for the first set, and I want to test more data from the same set with this model, but I have not been able to produce a script.

ankitkr3 commented 3 years ago

I have worked on a script; let me know if it makes sense:

```python
import warnings

import numpy as np
import torch
import transformers as ppb

warnings.filterwarnings('ignore')

# For DistilBERT:
model_class, tokenizer_class, pretrained_weights = (
    ppb.DistilBertModel, ppb.DistilBertTokenizer, 'distilbert-base-uncased')
tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
model = model_class.from_pretrained(pretrained_weights)
model.eval()

# demo_df is assumed to be a pandas DataFrame with an 'essay' column
test_essays = demo_df['essay']

# Tokenize each essay, truncating to at most 200 tokens
tokenized_test = test_essays.apply(
    lambda x: tokenizer.encode(x, add_special_tokens=True, max_length=200, truncation=True))

# Pad every sequence with 0 up to the longest one and mask out the padding
max_len = max(len(i) for i in tokenized_test.values)
padded_test = np.array([i + [0] * (max_len - len(i)) for i in tokenized_test.values])
attention_mask_test = np.where(padded_test != 0, 1, 0)
test_input_ids = torch.tensor(padded_test)
test_attention_mask = torch.tensor(attention_mask_test)

with torch.no_grad():
    last_hidden_states_test = model(test_input_ids, attention_mask=test_attention_mask)

# Use the hidden state of the first token of each essay as its feature vector
test_features = last_hidden_states_test[0][:, 0, :].numpy()

# Reshape to (samples, 1, features) for the LSTM
test_x, test_y = test_features.shape
testDataVectors = np.reshape(test_features, (test_x, 1, test_y))

# lstm_model is assumed to be already built with the architecture used in training
lstm_model.load_weights("./model_weights/final_lstm1.h5")
preds = lstm_model.predict(testDataVectors)
# Round each predicted score; int() on the whole array fails for more than one essay
print(np.around(preds).astype(int))
```
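To sanity-check the padding, attention-mask, and reshape steps without loading BERT or the LSTM, here is a minimal NumPy-only sketch on made-up token IDs. The helper name `pad_and_mask` and the toy values are hypothetical and not part of the repository; in the real script the IDs come from the DistilBERT tokenizer and the features from the model output.

```python
import numpy as np


def pad_and_mask(token_ids, pad_id=0):
    """Right-pad each ID list to the longest length and build a 0/1 attention mask."""
    max_len = max(len(ids) for ids in token_ids)
    padded = np.array([ids + [pad_id] * (max_len - len(ids)) for ids in token_ids])
    mask = np.where(padded != pad_id, 1, 0)
    return padded, mask


# Toy token IDs standing in for tokenizer output (hypothetical values)
toy_ids = [[101, 7592, 102], [101, 2023, 2003, 1037, 102]]
padded, mask = pad_and_mask(toy_ids)

# Stand-in for the (n_essays, hidden_size) feature matrix; 768 is the
# hidden size of distilbert-base-uncased
toy_features = np.zeros((len(toy_ids), 768))

# Insert a time-step axis of length 1, giving (samples, timesteps, features)
lstm_input = toy_features.reshape(toy_features.shape[0], 1, toy_features.shape[1])

print(padded.shape, mask.sum(axis=1), lstm_input.shape)
```

If the shapes printed here match what the full script produces for `padded_test` and `testDataVectors` on the same number of essays, the preprocessing is probably wired up correctly, and any remaining problem is more likely in how `lstm_model` is built or loaded.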