diff --git a/kaping/entity_injection.py b/kaping/entity_injection.py
index 48978d7..0c9e928 100644
--- a/kaping/entity_injection.py
+++ b/kaping/entity_injection.py
@@ -31,7 +31,7 @@ def sentence_embedding(self, texts: list):
"""
return self.model.encode(texts)
- def top_k_triple_extractor(self, question: np.ndarray, triples: np.ndarray, k=10, random=False):
+ def top_k_triple_extractor(self, question: np.ndarray, triples_emb: np.ndarray, triples: list, k=10, random=False):
"""
Retrieve the top k triples of KGs used as context for the question
@@ -42,14 +42,14 @@ def top_k_triple_extractor(self, question: np.ndarray, triples: np.ndarray, k=10
:return: list of triples
"""
# in case number of triples is fewer than k
- if len(triples) < k:
- k = len(triples)
+ if len(triples_emb) < k:
+ k = len(triples_emb)
if random:
- return random.sample(infos, k)
+ return random.sample(triples, k)
# if not the baseline but the top k most similar
- similarities = cosine_similarity(question, triples)
+ similarities = cosine_similarity(question, triples_emb)
top_k_indices = np.argsort(similarities[0])[-k:][::-1]
return [triples[index] for index in top_k_indices]
@@ -90,8 +90,9 @@ def __call__(self, question: list, triples: list, k=10, random=False, no_knowled
emb_triples = self.sentence_embedding(triples)
# retrieve the top k triples
- top_k_triples = self.top_k_triple_extractor(emb_question, emb_triples, k=k, random=random)
+ top_k_triples = self.top_k_triple_extractor(emb_question, emb_triples, triples, k=k, random=random)
# create prompt as input
- return self.injection(question, top_k_triples)
+ # when injecting, the question should be a string
+ return self.injection(question[0], top_k_triples)
diff --git a/kaping/entity_verbalization.py b/kaping/entity_verbalization.py
index 06d676b..eb74c42 100644
--- a/kaping/entity_verbalization.py
+++ b/kaping/entity_verbalization.py
@@ -28,34 +28,34 @@ def _extract_triplets(self, text: str):
"""
triplets = []
- relation, subject, relation, object_ = '', '', '', ''
- text = text.strip()
- current = 'x'
- for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
- if token == "<triplet>":
- current = 't'
- if relation != '':
- triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
- relation = ''
- subject = ''
- elif token == "<subj>":
- current = 's'
- if relation != '':
- triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
- object_ = ''
- elif token == "<obj>":
- current = 'o'
- relation = ''
- else:
- if current == 't':
- subject += ' ' + token
- elif current == 's':
- object_ += ' ' + token
- elif current == 'o':
- relation += ' ' + token
- if subject != '' and relation != '' and object_ != '':
- triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
- return triplets
+ relation, subject, relation, object_ = '', '', '', ''
+ text = text.strip()
+ current = 'x'
+ for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split():
+ if token == "<triplet>":
+ current = 't'
+ if relation != '':
+ triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
+ relation = ''
+ subject = ''
+ elif token == "<subj>":
+ current = 's'
+ if relation != '':
+ triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
+ object_ = ''
+ elif token == "<obj>":
+ current = 'o'
+ relation = ''
+ else:
+ if current == 't':
+ subject += ' ' + token
+ elif current == 's':
+ object_ += ' ' + token
+ elif current == 'o':
+ relation += ' ' + token
+ if subject != '' and relation != '' and object_ != '':
+ triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
+ return triplets
def text_relation(self, text: str):
diff --git a/kaping/model.py b/kaping/model.py
index 0f64445..19d2fd6 100644
--- a/kaping/model.py
+++ b/kaping/model.py
@@ -30,7 +30,8 @@ def pipeline(config, question: str, device=-1):
knowledge_triples.extend(verbalizer(entity, entity_title))
# entity injection as final prompt as input
- prompt = injector(question, knowledge_triples, k=config.k, random=config.random, no_knowledge=config.no_knowledge)
+ # question should be a list of strings
+ prompt = injector([question], knowledge_triples, k=config.k, random=config.random, no_knowledge=config.no_knowledge)
return prompt
diff --git a/run.py b/run.py
index 789707a..dce4492 100644
--- a/run.py
+++ b/run.py
@@ -42,7 +42,7 @@ def main():
results.append(qa_pair)
# evaluate to calculate the accuracy
- evaluated.append(evaluated(qa_pair.answer, predicted_answer))
+ evaluated.append(evaluate(qa_pair.answer, predicted_answer))
msg = ""
if args.no_knowledge: