raphaelsty / mkb

Knowledge Base Embedding By Cooperative Knowledge Distillation

Inference on `SentenceTransformer` #8

Closed utsavdarlami closed 1 year ago

utsavdarlami commented 1 year ago

From your example, I was able to train a SentenceTransformer. I used this dummy data:

```python
train = [
    ("jaguar", "cousin", "cat"),
    ("tiger", "cousin", "cat"),
    ("dog", "cousin", "wolf"),
    ("dog", "angry_against", "cat"),
    ("wolf", "angry_against", "jaguar"),
]

valid = [
    ("cat", "cousin", "jaguar"),
    ("cat", "cousin", "tiger"),
    ("dog", "angry_against", "tiger"),
]

test = [
    ("wolf", "angry_against", "tiger"),
    ("wolf", "angry_against", "cat"),
]
```

Given a new head node "big cat", which is unseen, I wanted to know if we can make an inference like:

```python
get_top_tails(
    k=2,
    model=model,
    head="big cat",
    relation="cousin",
)
```

where `k` is the number of closest (top-ranked) tails to return. Roughly, I imagine it doing something like the sketch below.
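
Here is a rough, hypothetical version of what I have in mind, assuming the trained encoder plus a precomputed relation vector and an explicit list of candidate tails (none of these come from the library itself):

```python
import numpy as np


def get_top_tails(k, model, head, relation_vector, candidate_tails):
    """Rank candidate tails by how close encode(head) + relation lands to them.

    `relation_vector` and `candidate_tails` are hypothetical inputs, not MKB API.
    """
    query = model.encode(head) + relation_vector     # TransE: tail ≈ head + relation
    tail_embeddings = model.encode(candidate_tails)  # shape (n_candidates, dim)
    distances = np.linalg.norm(tail_embeddings - query, axis=1)
    return [candidate_tails[i] for i in np.argsort(distances)[:k]]
```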

Does this make sense, or have I completely misunderstood the SentenceTransformer model?

raphaelsty commented 1 year ago

Hi @utsavdarlami,

You are right; this is the objective of the BLP model (i.e., a SentenceTransformer used as a TransE).
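
In the TransE view, a triple (head, relation, tail) is plausible when the head embedding translated by the relation embedding lands close to the tail embedding, so ranking tails for (head, relation, ?) amounts to a nearest-neighbor search around encode(head) + relation. A minimal sketch of the score, with placeholder vectors:

```python
import numpy as np


def transe_score(head: np.ndarray, relation: np.ndarray, tail: np.ndarray) -> float:
    # Higher (less negative) means a more plausible triple: head + relation ≈ tail.
    return -np.linalg.norm(head + relation - tail)
```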

Since the objective of MKB is only to train the model, it will be necessary to use a third-party library to perform the inference.

Cherche will encode the query by adding the embedding of the entity text to the embedding of the relation, and will search for the nearest neighbors among the set of documents (entities) using a Faiss index.

Here is pseudo-code with Cherche (`pip install cherche`) to perform inference on unseen entities using the BLP model. You will need a SentenceTransformer model trained with MKB and the embedding of the relation (cousin) as a numerical vector.

```python
import typing

import numpy as np
from cherche import retrieve
from sentence_transformers import SentenceTransformer

device = "cpu"

# Candidate tail entities to index.
documents = [
    {"id": 0, "label": "jaguar"},
    {"id": 1, "label": "tiger"},
    {"id": 2, "label": "dog"},
    {"id": 3, "label": "wolf"},
    {"id": 4, "label": "cat"},
]


# BLP wrapper
class BLP:
    def __init__(self, encoder, relation: typing.Optional[np.ndarray] = None) -> None:
        """CKB embeddings: a text encoder plus a fixed relation embedding."""
        self.encoder = encoder
        self.relation = relation

    def encode(self, sentence: typing.Union[str, list]) -> np.ndarray:
        """Encode documents (entity labels)."""
        return self.encoder.encode(sentence)

    def encode_query(self, sentence: typing.Union[str, list]) -> np.ndarray:
        """Encode the query: head embedding translated by the relation (TransE)."""
        return self.encode(sentence=sentence) + self.relation


# Embedding of the relation "cousin", available from the MKB model.
relation = np.array([1.2, ..., 9.4, 10.2, 100.2])

# SentenceTransformer available from the MKB model ("model" is its path).
model = BLP(
    encoder=SentenceTransformer("model", device=device),
    relation=relation,
)

retriever = retrieve.DPR(
    key="id",
    on=["label"],
    encoder=model.encode,
    query_encoder=model.encode_query,
    k=5,
)

# Pre-compute and index the embedding of every entity.
retriever.add(documents)

# Retrieve entities most likely to complete the triple (big cat, cousin, ?).
retriever("big cat")
```
utsavdarlami commented 1 year ago

Thank you very much.