# prompts.py
# Prompt templates for text encoding: helpers that format
# <instruct>/<query>/<response> strings, and per-dataset Prompt subclasses.
import pandas as pd
import torch
def load_raw_texts(dataset):
    """Load the serialized raw texts of a dataset.

    Args:
        dataset: Dataset name; used as the sub-directory under ``dataset/``.

    Returns:
        The object that ``torch.load`` deserializes from
        ``dataset/<dataset>/raw_texts.pt``.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects -- only point
    # this at files from a trusted source.
    return torch.load(f'dataset/{dataset}/raw_texts.pt')
def get_detailed_instruct(task_description: str, query: str) -> str:
    """Format a task description and a query as an instruct prompt.

    Args:
        task_description: Text placed after the ``<instruct>`` tag.
        query: Text placed after the ``<query>`` tag.

    Returns:
        ``<instruct>{task_description}`` and ``<query>{query}`` separated by
        a single newline.
    """
    template = '<instruct>{}\n<query>{}'
    return template.format(task_description, query)
def get_detailed_example(task_description: str, query: str, response: str) -> str:
    """Format a complete example (task, query and response) as one string.

    Args:
        task_description: Text placed after the ``<instruct>`` tag.
        query: Text placed after the ``<query>`` tag.
        response: Text placed after the ``<response>`` tag.

    Returns:
        The three tagged parts, in that order, separated by newlines.
    """
    tagged_parts = (
        f'<instruct>{task_description}',
        f'<query>{query}',
        f'<response>{response}',
    )
    return '\n'.join(tagged_parts)
class Prompt():
    """Base class that turns dataset texts into instruct-style prompts.

    Subclasses provide the dataset-specific wording by overriding
    ``get_primary_task`` and ``get_class_aware_task``.

    Attributes set by ``__init__``:
        texts2encode: The texts to encode.
        labels: The class names, used by the ``class_aware`` prompt.
        num_texts: ``len(texts2encode)`` at construction time.
        num_labels: ``len(labels)`` at construction time.

    Attributes set by ``prepare_prompts``:
        task: The task description placed after ``<instruct>``.
        examples_prefix: Few-shot prefix; always the empty string here
            (no examples are used).
        queries: One prompt string per entry of ``texts2encode``.
    """

    def __init__(self, texts2encode, labels):
        """Store the texts and labels and cache their lengths.

        Args:
            texts2encode: Sized collection of texts to encode.
            labels: Sized collection of class-name strings.
        """
        self.texts2encode = texts2encode
        self.labels = labels
        self.num_texts = len(self.texts2encode)
        self.num_labels = len(self.labels)

    def prepare_prompts(self, version):
        """Build ``self.task``, ``self.examples_prefix`` and ``self.queries``.

        Args:
            version: ``'primary'`` for a plain encoding instruction, or
                ``'class_aware'`` for an instruction followed by the class
                names, one per line.

        Raises:
            ValueError: If ``version`` is neither ``'primary'`` nor
                ``'class_aware'``.
            NotImplementedError: If the subclass does not override the task
                getter required by ``version``.
        """
        if version == 'primary':
            task = self.get_primary_task()
        elif version == 'class_aware':
            # Iterate the current ``labels`` rather than the cached
            # ``num_labels`` so a subclass that swaps the label list after
            # ``__init__`` cannot cause an index mismatch.
            task = self.get_class_aware_task() + ''.join(
                label + '\n' for label in self.labels
            )
        else:
            # Fail loudly instead of leaving ``task``/``queries`` unset.
            raise ValueError(
                f"Unknown prompt version {version!r}; "
                "expected 'primary' or 'class_aware'"
            )

        self.task = task
        self.examples_prefix = ''
        # The instruct header is identical for every text, so build it once.
        header = get_detailed_instruct(task, '') + '\n'
        self.queries = [header + text for text in self.texts2encode]

    def get_class_aware_task(self):
        """Return the classification instruction; subclasses must override."""
        raise NotImplementedError

    def get_primary_task(self):
        """Return the plain encoding instruction; subclasses must override."""
        raise NotImplementedError
class Prompt_citeseer(Prompt):
    """Task wording for the citeseer dataset (scientific publications)."""

    # Construction is inherited unchanged from ``Prompt``.

    def get_primary_task(self):
        """Return the plain encoding instruction."""
        primary = 'Encode the description or opening text of scientific publications: \n'
        return primary

    def get_class_aware_task(self):
        """Return the 6-class classification instruction."""
        class_aware = 'Given the description or opening text of scientific publications, classify it into one of the following 6 classes: \n'
        return class_aware
class Prompt_cora(Prompt):
    """Task wording for the cora dataset (machine learning papers)."""

    # Construction is inherited unchanged from ``Prompt``.

    def get_primary_task(self):
        """Return the plain encoding instruction."""
        primary = 'Encode the text of machine learning papers: \n'
        return primary

    def get_class_aware_task(self):
        """Return the 7-class classification instruction."""
        class_aware = 'Given the opening text of machine learning papers, classify it into one of the following 7 classes: \n'
        return class_aware
class Prompt_pubmed(Prompt):
    """Task wording for the pubmed dataset (diabetes-related publications).

    The class names are fixed to three descriptive labels defined here; the
    ``labels`` constructor argument is accepted for interface compatibility
    but not used.
    """

    def __init__(self, texts2encode, labels):
        """Initialise with the given texts and the built-in descriptive labels.

        Args:
            texts2encode: Sized collection of texts to encode.
            labels: Ignored; replaced by the descriptive labels below.
        """
        descriptive_labels = [
            'Diabetes Mellitus Experimental (animal models, cell-based experiments)',
            'Diabetes Mellitus Type 1 (autoimmune condition causing absolute insulin deficiency)',
            'Diabetes Mellitus Type 2 (insulin resistance and a progressive decline in insulin production)']
        # Hand the replacement labels to the base class so ``num_labels`` is
        # counted from the list actually stored in ``self.labels`` instead of
        # from the discarded argument.
        super().__init__(texts2encode, descriptive_labels)

    def get_class_aware_task(self):
        """Return the 3-class classification instruction."""
        return 'Given the title and abstract of scientific publications, classify it into one of the following 3 classes: \n'

    def get_primary_task(self):
        """Return the plain encoding instruction."""
        return 'Encode the title and abstract of scientific publications: \n'
class Prompt_wikics(Prompt):
    """Task wording for the wikics dataset (wikipedia entries)."""

    def __init__(self, texts2encode, labels):
        """Initialise the base class, then shorten the field markers in each text.

        Args:
            texts2encode: Sized collection of text strings to encode.
            labels: Sized collection of class-name strings.
        """
        super().__init__(texts2encode, labels)
        # Pre-process the texts: swap the verbose field markers for short
        # ones, applied in this order.
        marker_swaps = (
            ('feature node. wikipedia entry name:', 'entry:'),
            ('entry content:', 'content:'),
        )
        shortened = []
        for text in self.texts2encode:
            for verbose, concise in marker_swaps:
                text = text.replace(verbose, concise)
            shortened.append(text)
        self.texts2encode = shortened

    def get_primary_task(self):
        """Return the plain encoding instruction."""
        return 'Encode the entry and content of wikipedia: \n'

    def get_class_aware_task(self):
        """Return the 10-class classification instruction."""
        return 'Given the entry and content of wikipedia, classify it into one of the following 10 classes: \n'
class Prompt_bookhis(Prompt):
    """Task wording for the bookhis dataset (books)."""

    # Construction is inherited unchanged from ``Prompt``.

    def get_primary_task(self):
        """Return the plain encoding instruction."""
        primary = 'Encode the description or title of the book: \n'
        return primary

    def get_class_aware_task(self):
        """Return the 12-class classification instruction."""
        class_aware = 'Given the description or title of the book, classify it into one of the following 12 classes: \n'
        return class_aware
class Prompt_bookchild(Prompt):
    """Task wording for the bookchild dataset (child literature)."""

    # Construction is inherited unchanged from ``Prompt``.

    def get_primary_task(self):
        """Return the plain encoding instruction."""
        primary = 'Encode the description or title of the child literature: \n'
        return primary

    def get_class_aware_task(self):
        """Return the 24-class classification instruction."""
        class_aware = 'Given the description or title of the child literature, classify it into one of the following 24 classes: \n'
        return class_aware
class Prompt_sportsfit(Prompt):
    """Task wording for the sportsfit dataset (sports & fitness goods)."""

    def __init__(self, texts2encode, labels):
        """Initialise the base class, then shorten the title lead-in of each text.

        Args:
            texts2encode: Sized collection of text strings to encode.
            labels: Sized collection of class-name strings.
        """
        super().__init__(texts2encode, labels)
        verbose_lead_in = 'The title of the item in this Sports & Fitness category is'
        short_lead_in = 'The title is'
        shortened = []
        for text in self.texts2encode:
            shortened.append(text.replace(verbose_lead_in, short_lead_in))
        self.texts2encode = shortened

    def get_primary_task(self):
        """Return the plain encoding instruction."""
        return 'Encode the title of a good in sports & fitness: \n'

    def get_class_aware_task(self):
        """Return the 13-class classification instruction."""
        return 'Given the title of a good in sports & fitness, classify it into one of the following 13 classes: \n'
class Prompt_cornell(Prompt):
    """Task wording for the cornell dataset (webpage texts)."""

    # Construction is inherited unchanged from ``Prompt``.

    def get_primary_task(self):
        """Return the plain encoding instruction."""
        primary = 'Encode the webpage text: \n'
        return primary

    def get_class_aware_task(self):
        """Return the 5-class classification instruction."""
        class_aware = 'Given a webpage text, classify it into one of the following 5 classes: \n'
        return class_aware
class Prompt_wisconsin(Prompt):
    """Task wording for the wisconsin dataset (webpage texts)."""

    # Construction is inherited unchanged from ``Prompt``.

    def get_primary_task(self):
        """Return the plain encoding instruction."""
        primary = 'Encode the webpage text: \n'
        return primary

    def get_class_aware_task(self):
        """Return the 5-class classification instruction."""
        class_aware = 'Given a webpage text, classify it into one of the following 5 classes: \n'
        return class_aware
class Prompt_washington(Prompt):
    """Task wording for the washington dataset (webpage texts)."""

    # Construction is inherited unchanged from ``Prompt``.

    def get_primary_task(self):
        """Return the plain encoding instruction."""
        primary = 'Encode the webpage text: \n'
        return primary

    def get_class_aware_task(self):
        """Return the 5-class classification instruction."""
        class_aware = 'Given a webpage text, classify it into one of the following 5 classes: \n'
        return class_aware
class Prompt_texas(Prompt):
    """Task wording for the texas dataset (webpage texts)."""

    # Construction is inherited unchanged from ``Prompt``.

    def get_primary_task(self):
        """Return the plain encoding instruction."""
        primary = 'Encode the webpage text: \n'
        return primary

    def get_class_aware_task(self):
        """Return the 5-class classification instruction."""
        class_aware = 'Given a webpage text, classify it into one of the following 5 classes: \n'
        return class_aware