Skip to content

Commit 23b5682

Browse files
Christy BergmanChristy Bergman
authored andcommitted
Add .py utility functions
Signed-off-by: Christy Bergman <[email protected]>
1 parent e111949 commit 23b5682

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed

notebooks/text/imdb_utilities.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import numpy as np
2+
import torch
3+
from torch.nn import functional as F
4+
5+
# Output words instead of scores.
6+
def sentiment_score_to_name(score: float):
7+
if score > 0:
8+
return "Positive"
9+
elif score <= 0:
10+
return "Negative"
11+
12+
# Split data into train, valid, test.
13+
def partition_dataset(df_input, smoke_test=False):
14+
"""Splits data, assuming original, input dataframe contains 50K rows.
15+
16+
Args:
17+
df_input (pandas.DataFrame): input data frame
18+
smoke_test (boolean): if True, use smaller number of rows for testing
19+
20+
Returns:
21+
df_train, df_val, df_test (pandas.DataFrame): train, valid, test splits.
22+
"""
23+
24+
# Shuffle data and split into train/val/test.
25+
df_shuffled = df_input.sample(frac=1, random_state=1).reset_index()
26+
# Add a corpus index.
27+
columns = ['movie_index', 'text', 'label_int', 'label']
28+
df_shuffled.columns = columns
29+
30+
df_train = df_shuffled.iloc[:35_000]
31+
df_val = df_shuffled.iloc[35_000:40_000]
32+
df_test = df_shuffled.iloc[40_000:]
33+
34+
# Save train/val/test split data locally in separate files.
35+
df_train.to_csv("train.csv", index=False, encoding="utf-8")
36+
df_val.to_csv("val.csv", index=False, encoding="utf-8")
37+
df_test.to_csv("test.csv", index=False, encoding="utf-8")
38+
39+
return df_shuffled, df_train, df_val, df_test
40+
41+
# Take as input a user query and conduct semantic vector search using the query.
42+
def mc_search_imdb(query, retriever, milvus_collection, search_params, top_k,
43+
milvus_client=False, COLLECTION_NAME = 'movies'):
44+
45+
# Embed the query using same embedding model used to create the Milvus collection.
46+
query_embeddings = torch.tensor(retriever.encode(query))
47+
# Normalize embeddings to unit length.
48+
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
49+
# Quick check if embeddings are normalized.
50+
norms = np.linalg.norm(query_embeddings, axis=1)
51+
assert np.allclose(norms, 1.0, atol=1e-5) == True
52+
# Convert the embeddings to list of list of np.float32.
53+
query_embeddings = list(map(np.float32, query_embeddings))
54+
55+
# Run semantic vector search using your query and the vector database.
56+
# Assemble results.
57+
distances = []
58+
texts = []
59+
movie_indexes = []
60+
labels = []
61+
if milvus_client:
62+
# MilvusClient search API call slightly different.
63+
results = milvus_collection.search(
64+
COLLECTION_NAME,
65+
data=query_embeddings,
66+
search_params=search_params,
67+
output_fields=["movie_index", "chunk", "label"],
68+
limit=top_k,
69+
consistency_level="Eventually",
70+
)
71+
# Results returned from MilvusClient are in the form list of lists of dicts.
72+
for result in results[0]:
73+
distances.append(result['distance'])
74+
texts.append(result['entity']['chunk'])
75+
movie_indexes.append(result['entity']['movie_index'])
76+
labels.append(result['entity']['label'])
77+
else:
78+
# Milvus server search API call.
79+
results = milvus_collection.search(
80+
data=query_embeddings,
81+
anns_field="vector",
82+
param=search_params,
83+
output_fields=["movie_index", "chunk", "label"],
84+
limit=top_k,
85+
consistency_level="Eventually"
86+
)
87+
# Assemble results from Milvus server.
88+
distances = results[0].distances
89+
for result in results[0]:
90+
texts.append(result.entity.get("chunk"))
91+
movie_indexes.append(result.entity.get("movie_index"))
92+
labels.append(result.entity.get("label"))
93+
94+
# Assemble all the results in a zipped list.
95+
formatted_results = list(zip(distances, movie_indexes, texts, labels))
96+
97+
return formatted_results
98+

0 commit comments

Comments
 (0)