diff --git a/kaping/entity_injection.py b/kaping/entity_injection.py index 48978d7..0c9e928 100644 --- a/kaping/entity_injection.py +++ b/kaping/entity_injection.py @@ -31,7 +31,7 @@ def sentence_embedding(self, texts: list): """ return self.model.encode(texts) - def top_k_triple_extractor(self, question: np.ndarray, triples: np.ndarray, k=10, random=False): + def top_k_triple_extractor(self, question: np.ndarray, triples_emb: np.ndarray, triples: list, k=10, random=False): """ Retrieve the top k triples of KGs used as context for the question @@ -42,14 +42,14 @@ def top_k_triple_extractor(self, question: np.ndarray, triples: np.ndarray, k=10 :return: list of triples """ # in case number of triples is fewer than k - if len(triples) < k: - k = len(triples) + if len(triples_emb) < k: + k = len(triples_emb) if random: - return random.sample(infos, k) + return random.sample(triples, k) # if not the baseline but the top k most similar - similarities = cosine_similarity(question, triples) + similarities = cosine_similarity(question, triples_emb) top_k_indices = np.argsort(similarities[0])[-k:][::-1] return [triples[index] for index in top_k_indices] @@ -90,8 +90,9 @@ def __call__(self, question: list, triples: list, k=10, random=False, no_knowled emb_triples = self.sentence_embedding(triples) # retrieve the top k triples - top_k_triples = self.top_k_triple_extractor(emb_question, emb_triples, k=k, random=random) + top_k_triples = self.top_k_triple_extractor(emb_question, emb_triples, triples, k=k, random=random) # create prompt as input - return self.injection(question, top_k_triples) + # when injecting, the question should be string + return self.injection(question[0], top_k_triples) diff --git a/kaping/entity_verbalization.py b/kaping/entity_verbalization.py index 06d676b..eb74c42 100644 --- a/kaping/entity_verbalization.py +++ b/kaping/entity_verbalization.py @@ -28,34 +28,34 @@ def _extract_triplets(self, text: str): """ triplets = [] - relation, subject, relation, 
object_ = '', '', '', '' - text = text.strip() - current = 'x' - for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split(): - if token == "<triplet>": - current = 't' - if relation != '': - triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()}) - relation = '' - subject = '' - elif token == "<subj>": - current = 's' - if relation != '': - triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()}) - object_ = '' - elif token == "<obj>": - current = 'o' - relation = '' - else: - if current == 't': - subject += ' ' + token - elif current == 's': - object_ += ' ' + token - elif current == 'o': - relation += ' ' + token - if subject != '' and relation != '' and object_ != '': - triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()}) - return triplets + relation, subject, relation, object_ = '', '', '', '' + text = text.strip() + current = 'x' + for token in text.replace("<s>", "").replace("<pad>", "").replace("</s>", "").split(): + if token == "<triplet>": + current = 't' + if relation != '': + triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()}) + relation = '' + subject = '' + elif token == "<subj>": + current = 's' + if relation != '': + triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()}) + object_ = '' + elif token == "<obj>": + current = 'o' + relation = '' + else: + if current == 't': + subject += ' ' + token + elif current == 's': + object_ += ' ' + token + elif current == 'o': + relation += ' ' + token + if subject != '' and relation != '' and object_ != '': + triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()}) + return triplets def text_relation(self, text: str): diff --git a/kaping/model.py b/kaping/model.py index 0f64445..19d2fd6 100644 --- a/kaping/model.py +++ b/kaping/model.py @@ -30,7 +30,8 @@ def pipeline(config, question: str, device=-1):
knowledge_triples.extend(verbalizer(entity, entity_title)) # entity injection as final prompt as input - prompt = injector(question, knowledge_triples, k=config.k, random=config.random, no_knowledge=config.no_knowledge) + # question should be a list of strings + prompt = injector([question], knowledge_triples, k=config.k, random=config.random, no_knowledge=config.no_knowledge) return prompt diff --git a/run.py b/run.py index 789707a..dce4492 100644 --- a/run.py +++ b/run.py @@ -42,7 +42,7 @@ def main(): results.append(qa_pair) # evaluate to calculate the accuracy - evaluated.append(evaluated(qa_pair.answer, predicted_answer)) + evaluated.append(evaluate(qa_pair.answer, predicted_answer)) msg = "" if args.no_knowledge: