Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug fixing for MPNetEntityInjector and run file #5

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions kaping/entity_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def sentence_embedding(self, texts: list):
"""
return self.model.encode(texts)

def top_k_triple_extractor(self, question: np.ndarray, triples: np.ndarray, k=10, random=False):
def top_k_triple_extractor(self, question: np.ndarray, triples_emb: np.ndarray, triples: list, k=10, random=False):
"""
Retrieve the top k triples of KGs used as context for the question

Expand All @@ -42,14 +42,14 @@ def top_k_triple_extractor(self, question: np.ndarray, triples: np.ndarray, k=10
:return: list of triples
"""
# in case number of triples is fewer than k
if len(triples) < k:
k = len(triples)
if len(triples_emb) < k:
k = len(triples_emb)

if random:
return random.sample(infos, k)
return random.sample(triples, k)

# otherwise (not the random baseline), return the top k most similar triples
similarities = cosine_similarity(question, triples)
similarities = cosine_similarity(question, triples_emb)
top_k_indices = np.argsort(similarities[0])[-k:][::-1]

return [triples[index] for index in top_k_indices]
Expand Down Expand Up @@ -90,8 +90,9 @@ def __call__(self, question: list, triples: list, k=10, random=False, no_knowled
emb_triples = self.sentence_embedding(triples)

# retrieve the top k triples
top_k_triples = self.top_k_triple_extractor(emb_question, emb_triples, k=k, random=random)
top_k_triples = self.top_k_triple_extractor(emb_question, emb_triples, triples, k=k, random=random)

# create prompt as input
return self.injection(question, top_k_triples)
# when injecting, the question should be a string
return self.injection(question[0], top_k_triples)

56 changes: 28 additions & 28 deletions kaping/entity_verbalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,34 +28,34 @@ def _extract_triplets(self, text: str):
"""

triplets = []
relation, subject, relation, object_ = '', '', '', ''
text = text.strip()
current = 'x'
for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
if token == "<triplet>":
current = 't'
if relation != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
relation = ''
subject = ''
elif token == "<subj>":
current = 's'
if relation != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
object_ = ''
elif token == "<obj>":
current = 'o'
relation = ''
else:
if current == 't':
subject += ' ' + token
elif current == 's':
object_ += ' ' + token
elif current == 'o':
relation += ' ' + token
if subject != '' and relation != '' and object_ != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
return triplets
relation, subject, relation, object_ = '', '', '', ''
text = text.strip()
current = 'x'
for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
if token == "<triplet>":
current = 't'
if relation != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
relation = ''
subject = ''
elif token == "<subj>":
current = 's'
if relation != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
object_ = ''
elif token == "<obj>":
current = 'o'
relation = ''
else:
if current == 't':
subject += ' ' + token
elif current == 's':
object_ += ' ' + token
elif current == 'o':
relation += ' ' + token
if subject != '' and relation != '' and object_ != '':
triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
return triplets


def text_relation(self, text: str):
Expand Down
3 changes: 2 additions & 1 deletion kaping/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def pipeline(config, question: str, device=-1):
knowledge_triples.extend(verbalizer(entity, entity_title))

# entity injection as final prompt as input
prompt = injector(question, knowledge_triples, k=config.k, random=config.random, no_knowledge=config.no_knowledge)
# question should be a list of strings
prompt = injector([question], knowledge_triples, k=config.k, random=config.random, no_knowledge=config.no_knowledge)

return prompt

Expand Down
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main():
results.append(qa_pair)

# evaluate to calculate the accuracy
evaluated.append(evaluated(qa_pair.answer, predicted_answer))
evaluated.append(evaluate(qa_pair.answer, predicted_answer))

msg = ""
if args.no_knowledge:
Expand Down