-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
39 lines (31 loc) · 1.38 KB
/
test.py
File metadata and controls
39 lines (31 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from transformers import BertModel, BertConfig
from transformers import BertTokenizer
import numpy as np
from scipy.spatial.distance import cosine
def normalize(x):
    """Return *x* scaled to unit L2 norm.

    A zero vector is returned unchanged instead of being divided by
    zero, which would emit a RuntimeWarning and produce NaNs.
    """
    norm = np.linalg.norm(x)
    if norm == 0:
        return x
    return x / norm
def cosine_similarity(x, y):
    """Return the cosine similarity between vectors *x* and *y*.

    Computed as one minus the SciPy cosine *distance*.
    """
    distance = cosine(x, y)
    return 1.0 - distance
# Load the fine-tuned BERT checkpoint and compare sentence embeddings
# by cosine similarity of the (L2-normalized) pooler outputs.
model = BertModel.from_pretrained("output/checkpoint-2000")
tokenizer = BertTokenizer.from_pretrained("output/checkpoint-2000")
model.eval()

def _sentence_embedding(text):
    """Tokenize *text* and return its normalized BERT pooler embedding."""
    encoded = tokenizer(text, return_tensors='pt')
    pooled = model(**encoded).pooler_output
    return normalize(pooled.squeeze(0).detach().numpy())

emb_a = _sentence_embedding("一种分布式光储系统")
emb_b = _sentence_embedding("一种分布式系统")
emb_c = _sentence_embedding("一种洗涤剂系统")
print(cosine_similarity(emb_a, emb_b))
print(cosine_similarity(emb_a, emb_c))
# Alternative: load via sentence_transformers
# from sentence_transformers import models,SentenceTransformer
# bert = models.Transformer('output/checkpoint-2000')
# pooler = models.Pooling(bert.get_word_embedding_dimension())
# normalize = models.Normalize()
# model = SentenceTransformer(modules=[bert, pooler, normalize])
# model.eval()
# s1 = model.encode("一种分布式光储系统")
# s2 = model.encode("一种分布式系统")
# s3 = model.encode("一种洗涤剂系统")
# print(cosine_similarity(s1, s2))
# print(cosine_similarity(s1, s3))