jatinvinkumar opened this issue 1 year ago
I added the following code to the main function of run.py:
```python
# Imports used by this snippet (run.py already provides most of these).
import argparse

import torch
from tqdm import tqdm

training_data, valid_data, test_data, vocab, args, logger, saved_vocab_path = prepare_data()
print("{}.{}.{}".format(len(training_data[0]), len(valid_data[0]), len(test_data[0])))

valid_dataset = TextDataset(valid_data, vocab, max_seq_length=4000,
                            min_seq_length=-1, sort=True, multilabel=False)
print("valid_dataset: ", valid_dataset)
valid_dataloader = TextDataLoader(dataset=valid_dataset, vocab=vocab, batch_size=8)
print("valid_dataloader: ", valid_dataloader)

best_model_path = ("checkpoints/mimic-iii_2_full/"
                   "RNN_LSTM_1_512.static.label.0.001.0.3_72df8e44d8921dd19f07bab290d6a868/"
                   "best_model.pkl")
print("=> loading best model '{}'".format(best_model_path))
parser = argparse.ArgumentParser()
# add arguments for the RNN module
parser.add_argument('--rnn_model', type=str, default='LSTM')
parser.add_argument('--mode', type=str, default='static')
parser.add_argument('--use_last_hidden_state', type=int, default=0)
parser.add_argument('--n_layers', type=int, default=1)
parser.add_argument('--bidirectional', type=int, default=1)
parser.add_argument('--dropout', type=float, default=0.3)
parser.add_argument('--hidden_size', type=int, default=512)
parser.add_argument('--attention_mode', type=str, default='label')
parser.add_argument('--embedding_size', type=int, default=100)
parser.add_argument("--embedding_file", type=str, default='data/embeddings/word2vec_sg0_100.model')
parser.add_argument("--embedding_mode", type=str, default="word2vec")
parser.add_argument("--joint_mode", type=str, default="hierarchical")
parser.add_argument("--level_projection_size", type=int, default=128)
parser.add_argument("--d_a", type=int, default=512)
parser.add_argument("--r", type=int, help="The number of hops for self attention", default=-1)
parser.add_argument("--use_regularisation", action='store_true', default=False)
parser.add_argument("--penalisation_coeff", type=float, default=0.01)
args, _ = parser.parse_known_args()
model = RNN(vocab, args)
device = vocab.device  # device handle referenced throughout the loop below
model.to(device)
checkpoint = torch.load(best_model_path)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
pred_probs = [[] for _ in range(vocab.n_level())]
true_labels = [[] for _ in range(vocab.n_level())]
ids = []
for text_batch, label_batch, length_batch, id_batch in \
        tqdm(valid_dataloader, unit="batches", desc="Evaluating"):
    text_batch = text_batch.to(device)
    for idx in range(len(label_batch)):
        label_batch[idx] = label_batch[idx].to(device)
    if isinstance(length_batch, list):
        for i in range(len(length_batch)):
            length_batch[i] = length_batch[i].to(device)
    else:
        length_batch = length_batch.to(device)

    true_label_batch = []
    for idx in range(len(label_batch)):
        true_label_batch.append(label_batch[idx].cpu().numpy())
    true_labels.extend(true_label_batch)
    ids.extend(id_batch)

    with torch.no_grad():
        output, attn_weights = model(text_batch, length_batch)
        print("output: ", output)
        pred_labels = [None] * len(output)
        for label_lvl in range(len(output)):
            output[label_lvl] = torch.softmax(output[label_lvl], 1)
            top_k_values, top_k_indices = torch.topk(output[label_lvl], 1, dim=1)
            top_k_indices = top_k_indices.detach().cpu().numpy()
            # Convert indices to labels
            batch_pred_labels = []
            for index in top_k_indices:
                batch_pred_labels.append(vocab.index2word[index.item()])
            pred_labels[label_lvl] = batch_pred_labels
            pred_probs[label_lvl].extend(output[label_lvl].tolist())
        print("predicted_labels: ", pred_labels)
    input("Press Enter to continue...")
```
But the output looks like the following:

```
predicted_labels:  [['0pt', '0955j', '1007f', '0x20mm', '07assessment', '0units', '0955j', '0final'], ['1ary', '140and', '10fr', '1ary', '1every', '18micrograms', '140and', '15lnrb']]
```

These predictions do not seem to correspond to actual ICD codes. Could you please advise?
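The printed values look like word tokens from the clinical notes rather than codes, so I suspect `vocab.index2word` maps word indices, not label indices. For clarity, this is the kind of decoding I am trying to achieve. It is only a sketch under my own assumptions: `index2label` is a placeholder for whatever index-to-code mapping the repo actually exposes, and since ICD coding is multilabel I have guessed at a per-label sigmoid plus threshold rather than the softmax/top-1 I used above:

```python
import torch

def decode_multilabel(logits, index2label, threshold=0.5):
    """Map one level's raw model outputs to lists of label strings.

    index2label: placeholder dict {label_index: code_string}; I am not sure
    which attribute of `vocab` (if any) provides this in the repo.
    """
    probs = torch.sigmoid(logits)  # multilabel: independent probability per label
    batch_labels = []
    for row in probs:
        picked = (row >= threshold).nonzero(as_tuple=True)[0].tolist()
        batch_labels.append([index2label[i] for i in picked])
    return batch_labels
```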
Here are my model stats:
```
[tqdm progress output for training and evaluation trimmed]
07:37:41 INFO Loss on Train at epoch #14: 0.00541, micro_f1 on Valid: 0.5789
07:37:41 INFO [CURRENT BEST] (level_1) micro_f1 on Valid set: 0.58033
07:37:41 INFO Early stopping: 6/6
07:37:47 WARNING Early stopped on Valid set!
07:37:47 INFO =================== BEST ===================
07:37:47 INFO Results on Valid set at epoch #8 with Averaged Loss 0.00728
07:37:47 INFO ======== Results at level_0 ========
07:37:47 INFO Results on Valid set at epoch #8 with Loss 0.02295:
    [MICRO] accuracy: 0.53089 auc: 0.98804 precision: 0.73734 recall: 0.65471 f1: 0.69357 P@1: 0 P@5: 0 P@8: 0 P@10: 0 P@15: 0
    [MACRO] accuracy: 0.19122 auc: 0.9377 precision: 0.27899 recall: 0.24661 f1: 0.2618 P@1: 0.97241 P@5: 0.89442 P@8: 0.82342 P@10: …
07:37:47 INFO ======== Results at level_1 ========
07:37:47 INFO Results on Valid set at epoch #8 with Loss 0.00519:
    [MICRO] accuracy: 0.40878 auc: 0.98861 precision: 0.64014 recall: 0.53074 f1: 0.58033 P@1: 0 P@5: 0 P@8: 0 P@10: 0 P@15: 0
    [MACRO] accuracy: 0.05843 auc: 0.92904 precision: 0.08146 recall: 0.07965 f1: 0.08055 P@1: 0.91478 P@5: 0.80809 P@8: 0.73896 P@10: …
07:37:47 INFO Results on Test set at epoch #8 with Averaged Loss 0.00751
07:37:47 INFO ======== Results at level_0 ========
07:37:47 INFO Results on Test set at epoch #8 with Loss 0.02344:
    [MICRO] accuracy: 0.52845 auc: 0.98808 precision: 0.73451 recall: 0.65322 f1: 0.69148 P@1: 0 P@5: 0 P@8: 0 P@10: 0 P@15: 0
    [MACRO] accuracy: 0.19495 auc: 0.93349 precision: 0.28736 recall: 0.25757 f1: 0.27165 P@1: 0.96263 P@5: 0.89306 P@8: 0.82514 P@10: …
07:37:47 INFO ======== Results at level_1 ========
07:37:47 INFO Results on Test set at epoch #8 with Loss 0.00539:
    [MICRO] accuracy: 0.40081 auc: 0.98835 precision: 0.63183 recall: 0.52295 f1: 0.57226 P@1: 0 P@5: 0 P@8: 0 P@10: 0 P@15: 0
    [MACRO] accuracy: 0.06157 auc: 0.92049 precision: 0.0927 recall: 0.08892 f1: 0.09077 P@1: 0.91133 P@5: 0.80623 P@8: 0.73528 P@10: …
07:37:48 INFO => loading best model 'checkpoints/mimic-iii_2_full/RNN_LSTM_1_512.static.label.0.001.0.3_72df8e44d8921dd19f07bab290d6a868/best_model.pkl'
```
Thank you for sharing this work! I was able to train the model. Could you please advise on the steps for performing inference so that the output shows the labels as strings?
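In case it helps frame the question, this is roughly how I imagine the end of the inference loop looking, reusing the hypothetical `decode_multilabel` sketch from above (again, `index2label` is a placeholder, here assumed to be one mapping per label level):

```python
with torch.no_grad():
    output, attn_weights = model(text_batch, length_batch)
for level, logits in enumerate(output):
    # index2label[level] would be the level-specific index -> ICD-code mapping
    print("level {} codes:".format(level), decode_multilabel(logits, index2label[level]))
```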