Skip to content

Commit

Permalink
Merge pull request stanfordnlp#707 from jettro/issue-699-upgrade-weav…
Browse files Browse the repository at this point in the history
…iate

fix for weaviate outdated client version. Issue 699
  • Loading branch information
arnavsinghvi11 authored Mar 25, 2024
2 parents a8b6e40 + 9499dfa commit 2dacce4
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 19 deletions.
30 changes: 21 additions & 9 deletions dspy/retrieve/weaviate_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

try:
import weaviate
import weaviate.classes as wvc
from weaviate.collections.classes.grpc import HybridFusion
except ImportError:
raise ImportError(
"The 'weaviate' extra is required to use WeaviateRM. Install it with `pip install dspy-ai[weaviate]`",
Expand All @@ -22,6 +24,9 @@ class WeaviateRM(dspy.Retrieve):
weaviate_collection_name (str): The name of the Weaviate collection.
weaviate_client (WeaviateClient): An instance of the Weaviate client.
k (int, optional): The default number of top passages to retrieve. Defaults to 3.
weaviate_collection_text_key (str, optional): The key in the collection that holds the passage text. Defaults to "content".
weaviate_alpha (float, optional): The alpha value for the hybrid query. Defaults to 0.5.
weaviate_fusion_type (HybridFusion, optional): The fusion type for the hybrid query. Defaults to HybridFusion.RELATIVE_SCORE.
Examples:
Below is a code snippet that shows how to use Weaviate as the default retriever:
Expand All @@ -44,16 +49,20 @@ class WeaviateRM(dspy.Retrieve):

def __init__(self,
weaviate_collection_name: str,
weaviate_client: weaviate.Client,
weaviate_client: weaviate.WeaviateClient,
k: int = 3,
weaviate_collection_text_key: Optional[str] = "content",
weaviate_alpha: Optional[float] = 0.5,
weaviate_fusion_type: Optional[HybridFusion] = HybridFusion.RELATIVE_SCORE,
):
self._weaviate_collection_name = weaviate_collection_name
self._weaviate_client = weaviate_client
self._weaviate_collection_text_key = weaviate_collection_text_key
self._weaviate_alpha = weaviate_alpha
self._weaviate_fusion_type = weaviate_fusion_type
super().__init__(k=k)

def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> dspy.Prediction:
def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction:
"""Search with Weaviate for self.k top passages for query
Args:
Expand All @@ -72,14 +81,17 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) ->
queries = [q for q in queries if q]
passages = []
for query in queries:
results = self._weaviate_client.query\
.get(self._weaviate_collection_name, [self._weaviate_collection_text_key])\
.with_hybrid(query=query)\
.with_limit(k)\
.do()
collection = self._weaviate_client.collections.get(self._weaviate_collection_name)
results = collection.query.hybrid(query=query,
limit=k,
alpha=self._weaviate_alpha,
fusion_type=self._weaviate_fusion_type,
return_metadata=wvc.query.MetadataQuery(
distance=True, score=True),
)

results = results["data"]["Get"][self._weaviate_collection_name]
parsed_results = [result[self._weaviate_collection_text_key] for result in results]
parsed_results = [result.properties[self._weaviate_collection_text_key] for result in results.objects]
passages.extend(dotdict({"long_text": d}) for d in parsed_results)

# NOTE: the signature is annotated as returning dspy.Prediction, but a plain list of passages is returned. The return value is left unchanged because other code would break if it were changed.
return passages
35 changes: 27 additions & 8 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ chromadb = ["chromadb~=0.4.14"]
qdrant = ["qdrant-client~=1.6.2", "fastembed~=0.1.0"]
marqo = ["marqo"]
pinecone = ["pinecone-client~=2.2.4"]
weaviate = ["weaviate-client~=3.26.1"]
weaviate = ["weaviate-client~=4.5.4"]
docs = [
"sphinx>=4.3.0",
"furo>=2023.3.27",
Expand Down Expand Up @@ -96,7 +96,7 @@ fastembed = { version = "^0.1.0", optional = true }
marqo = { version = "*", optional = true }
qdrant-client = { version = "^1.6.2", optional = true }
pinecone-client = { version = "^2.2.4", optional = true }
weaviate-client = { version = "^3.26.1", optional = true }
weaviate-client = { version = "^4.5.4", optional = true }
sphinx = { version = ">=4.3.0", optional = true }
furo = { version = ">=2023.3.27", optional = true }
docutils = { version = "<0.17", optional = true }
Expand Down

0 comments on commit 2dacce4

Please sign in to comment.