From fbf03e9ce8c1bf31636e8b30264a9ef317246b63 Mon Sep 17 00:00:00 2001 From: Nikhil Talreja Date: Fri, 23 Aug 2024 17:26:40 +0200 Subject: [PATCH] 615 - Add n_results for WeaviateDatabase --- src/vanna/weaviate/weaviate_vector.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/vanna/weaviate/weaviate_vector.py b/src/vanna/weaviate/weaviate_vector.py index 589aeca3..54fc8cca 100644 --- a/src/vanna/weaviate/weaviate_vector.py +++ b/src/vanna/weaviate/weaviate_vector.py @@ -26,6 +26,7 @@ def __init__(self, config=None): if config is None: raise ValueError("config is required") + self.n_results = config.get("n_results", 3) self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5") self.weaviate_api_key = config.get("weaviate_api_key") self.weaviate_url = config.get("weaviate_url") @@ -120,12 +121,12 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str: response = self._insert_data('sql', data_object, self.generate_embedding(question)) return f'{response}-sql' - def _query_collection(self, cluster_key: str, vector_input: list, return_properties: list, limit: int = 3) -> list: + def _query_collection(self, cluster_key: str, vector_input: list, return_properties: list) -> list: self.weaviate_client.connect() collection = self.weaviate_client.collections.get(self.training_data_cluster[cluster_key]) response = collection.query.near_vector( near_vector=vector_input, - limit=limit, + limit=self.n_results, return_properties=return_properties ) response_list = [item.properties for item in response.objects]