
Commit 675bfa1

Update retrieval based classification README.md (PaddlePaddle#3322)

* Update retrieval based classification README.md
* Revert predict.py
* Update cpu predict script
* Restore gpu config
1 parent c64ed99 commit 675bfa1

File tree

21 files changed: +336 −154 lines


applications/text_classification/hierarchical/retrieval_based/README.md

Lines changed: 2 additions & 3 deletions
@@ -37,7 +37,6 @@
 |—— base_model.py # semantic index model base class
 |—— train.py # main training script for the In-batch Negatives strategy
 |—— model.py # core network structure of the In-batch Negatives strategy
-|—— ann_util.py # helper functions for building the ANN index
 
 |—— recall.py # recalls texts similar to a given text from the corpus, based on the trained semantic index model
 |—— evaluate.py # computes evaluation metrics from the recall results and the evaluation set
@@ -167,7 +166,7 @@ unzip baike_qa_category.zip
 
 ### Single-machine single-GPU / multi-GPU training
 
-This trains with multiple GPUs on a single machine; the command below assigns GPUs 0, 1, 2 and 3. For single-GPU training, simply set the `--gpus` argument to the ID of that card.
+This trains with multiple GPUs on a single machine; the command below assigns GPUs 0 and 1. For single-GPU training, simply set the `--gpus` argument to the ID of that card.
 
 To train on CPU, remove the `--gpus` argument and set `device` to cpu; see the training settings in train.sh for details.
 
@@ -176,7 +175,7 @@ unzip baike_qa_category.zip
 ```
 root_path=inbatch
 data_path=data
-python -u -m paddle.distributed.launch --gpus "0,1,2,3" \
+python -u -m paddle.distributed.launch --gpus "0,1" \
     train.py \
     --device gpu \
     --save_dir ./checkpoints/${root_path} \

applications/text_classification/hierarchical/retrieval_based/ann_util.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

applications/text_classification/hierarchical/retrieval_based/data.py

Lines changed: 39 additions & 0 deletions
@@ -13,10 +13,49 @@
 # limitations under the License.
 
 import os
+
+import hnswlib
+import numpy as np
 import paddle
 from paddlenlp.utils.log import logger
 
 
+def build_index(corpus_data_loader, model, output_emb_size, hnsw_max_elements,
+                hnsw_ef, hnsw_m):
+
+    index = hnswlib.Index(space='ip',
+                          dim=output_emb_size if output_emb_size > 0 else 768)
+
+    # Initializing index
+    # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded
+    # during insertion of an element.
+    # The capacity can be increased by saving/loading the index, see below.
+    #
+    # ef_construction - controls index search speed/build speed tradeoff
+    #
+    # M - is tightly connected with internal dimensionality of the data. Strongly affects memory consumption (~M)
+    # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction
+    index.init_index(max_elements=hnsw_max_elements,
+                     ef_construction=hnsw_ef,
+                     M=hnsw_m)
+
+    # Controlling the recall by setting ef:
+    # higher ef leads to better accuracy, but slower search
+    index.set_ef(hnsw_ef)
+
+    # Set number of threads used during batch search/construction
+    # By default using all available cores
+    index.set_num_threads(16)
+    logger.info("start build index..........")
+    all_embeddings = []
+    for text_embeddings in model.get_semantic_embedding(corpus_data_loader):
+        all_embeddings.append(text_embeddings.numpy())
+    all_embeddings = np.concatenate(all_embeddings, axis=0)
+    index.add_items(all_embeddings)
+    logger.info("Total index number:{}".format(index.get_current_count()))
+    return index
+
+
 def create_dataloader(dataset,
                       mode='train',
                       batch_size=1,
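For context on what the relocated build_index produces, here is a minimal self-contained sketch (not part of this commit) of building and then querying an hnswlib index; the random embeddings below are toy stand-ins for the encoder outputs that model.get_semantic_embedding would return:

```
import hnswlib
import numpy as np

# Toy stand-ins for the corpus and query embeddings the encoder would produce.
dim = 768
corpus_embeddings = np.random.rand(1000, dim).astype("float32")
query_embeddings = np.random.rand(4, dim).astype("float32")

index = hnswlib.Index(space="ip", dim=dim)
index.init_index(max_elements=1000, ef_construction=100, M=16)
index.add_items(corpus_embeddings)
index.set_ef(100)

# knn_query returns (labels, distances); labels are the insertion-order ids
# of the items passed to add_items, which recall.py maps back to corpus text.
labels, distances = index.knn_query(query_embeddings, k=10)
print(labels.shape, distances.shape)  # (4, 10) (4, 10)
```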

applications/text_classification/hierarchical/retrieval_based/export_model.py

Lines changed: 7 additions & 5 deletions
@@ -32,23 +32,25 @@
                     help="The path of model parameter in static graph to be saved.")
 parser.add_argument("--output_emb_size", default=0,
                     type=int, help="output_embedding_size")
+parser.add_argument("--model_name_or_path", default='rocketqa-zh-dureader-query-encoder',
+                    type=str, help='The pretrained model used for training')
 args = parser.parse_args()
 # yapf: enable
 
 if __name__ == "__main__":
     # If you want to use the ernie1.0 model, please uncomment the following code
-    pretrained_model = AutoModel.from_pretrained(
-        "rocketqa-zh-dureader-query-encoder")
-    tokenizer = AutoTokenizer.from_pretrained(
-        "rocketqa-zh-dureader-query-encoder")
+    pretrained_model = AutoModel.from_pretrained(args.model_name_or_path)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
     model = SemanticIndexBaseStatic(pretrained_model,
                                     output_emb_size=args.output_emb_size)
 
     if args.params_path and os.path.isfile(args.params_path):
         state_dict = paddle.load(args.params_path)
         model.set_dict(state_dict)
         print("Loaded parameters from %s" % args.params_path)
-
+    else:
+        raise ValueError(
+            "Please set --params_path with correct pretrained model file")
     model.eval()
     # Convert to static graph with specific input description
     model = paddle.jit.to_static(
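As a usage note, the static graph saved by export_model.py can be reloaded with paddle.jit.load. The sketch below is an illustration rather than part of this commit; the path prefix and the two int64 inputs are assumptions about the usual to_static input spec for these encoders:

```
import paddle

# Hypothetical output prefix -- substitute whatever --output_path was set to.
static_model = paddle.jit.load("./output/static_model")
static_model.eval()

# Assumed inputs: token ids and segment ids, both int64.
input_ids = paddle.to_tensor([[1, 557, 43, 2]], dtype="int64")
token_type_ids = paddle.zeros_like(input_ids)
embedding = static_model(input_ids, token_type_ids)
print(embedding.shape)
```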

applications/text_classification/hierarchical/retrieval_based/predict.py

Lines changed: 4 additions & 4 deletions
@@ -45,6 +45,8 @@
                     help="Select which device to train model, defaults to gpu.")
 parser.add_argument("--pad_to_max_seq_len", action="store_true",
                     help="Whether to pad to max seq length.")
+parser.add_argument("--model_name_or_path", default='rocketqa-zh-dureader-query-encoder',
+                    type=str, help='The pretrained model used for training')
 args = parser.parse_args()
 # yapf: enable
 
@@ -77,8 +79,7 @@ def predict(model, data_loader):
 if __name__ == "__main__":
     paddle.set_device(args.device)
 
-    tokenizer = AutoTokenizer.from_pretrained(
-        "rocketqa-zh-dureader-query-encoder")
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
     trans_func = partial(convert_example,
                          tokenizer=tokenizer,
                          max_seq_length=args.max_seq_length,
@@ -101,8 +102,7 @@ def predict(model, data_loader):
                                        batch_size=args.batch_size,
                                        batchify_fn=batchify_fn,
                                        trans_fn=trans_func)
-    pretrained_model = AutoModel.from_pretrained(
-        "rocketqa-zh-dureader-query-encoder")
+    pretrained_model = AutoModel.from_pretrained(args.model_name_or_path)
     model = SemanticIndexBase(pretrained_model,
                               output_emb_size=args.output_emb_size)
     if args.params_path and os.path.isfile(args.params_path):

applications/text_classification/hierarchical/retrieval_based/recall.py

Lines changed: 10 additions & 5 deletions
@@ -63,6 +63,8 @@
                     type=int, help="Recall number for each query from Ann index.")
 parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu",
                     help="Select which device to train model, defaults to gpu.")
+parser.add_argument("--model_name_or_path", default='rocketqa-zh-dureader-query-encoder',
+                    type=str, help='The pretrained model used for training')
 args = parser.parse_args()
 # yapf: enable
 
@@ -71,8 +73,7 @@
     rank = paddle.distributed.get_rank()
     if paddle.distributed.get_world_size() > 1:
         paddle.distributed.init_parallel_env()
-    tokenizer = AutoTokenizer.from_pretrained(
-        'rocketqa-zh-dureader-query-encoder')
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
     trans_func = partial(convert_corpus_example,
                          tokenizer=tokenizer,
                          max_seq_length=args.max_seq_length)
@@ -82,8 +83,7 @@
         Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64"
             ),  # text_segment
     ): [data for data in fn(samples)]
-    pretrained_model = AutoModel.from_pretrained(
-        "rocketqa-zh-dureader-query-encoder")
+    pretrained_model = AutoModel.from_pretrained(args.model_name_or_path)
     model = SemanticIndexBase(pretrained_model,
                               output_emb_size=args.output_emb_size)
     model = paddle.DataParallel(model)
@@ -106,7 +106,12 @@
                                           trans_fn=trans_func)
     # Need better way to get inner model of DataParallel
     inner_model = model._layers
-    final_index = build_index(args, corpus_data_loader, inner_model)
+    final_index = build_index(corpus_data_loader,
+                              inner_model,
+                              output_emb_size=args.output_emb_size,
+                              hnsw_max_elements=args.hnsw_max_elements,
+                              hnsw_ef=args.hnsw_ef,
+                              hnsw_m=args.hnsw_m)
     text_list, text2similar_text = gen_text_file(args.similar_text_pair_file)
     query_ds = MapDataset(text_list)
     query_data_loader = create_dataloader(query_ds,
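For orientation, here is a hedged sketch of the consumer side of this call: how recall.py-style code typically turns final_index into a recall result file. Everything about the arguments — including the assumption that each text_list entry is a dict with a "text" key — is inferred from the surrounding script, not prescribed by this commit:

```
def write_recall_results(final_index, inner_model, query_data_loader,
                         text_list, id2corpus, recall_num=10,
                         out_file="recall_result.txt"):
    # Query the HNSW index batch by batch and write one
    # "query<TAB>recalled_text<TAB>similarity" row per recalled item.
    text_index = 0
    with open(out_file, "w", encoding="utf-8") as f:
        for batch_query_embedding in inner_model.get_semantic_embedding(
                query_data_loader):
            recalled_idx, cosine_sims = final_index.knn_query(
                batch_query_embedding.numpy(), recall_num)
            for row_index in range(len(cosine_sims)):
                for idx, doc_idx in enumerate(recalled_idx[row_index]):
                    # With space='ip', hnswlib reports 1 - inner product;
                    # invert it back into a similarity score.
                    f.write("{}\t{}\t{}\n".format(
                        text_list[text_index]["text"], id2corpus[doc_idx],
                        1.0 - cosine_sims[row_index][idx]))
                text_index += 1
```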

applications/text_classification/hierarchical/retrieval_based/scripts/train.sh

Lines changed: 15 additions & 1 deletion
@@ -1,7 +1,21 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # GPU training
 root_path=inbatch
 data_path=data
-python -u -m paddle.distributed.launch --gpus "0,1,2,3" \
+python -u -m paddle.distributed.launch --gpus "0,1" \
     train.py \
     --device gpu \
     --save_dir ./checkpoints/${root_path} \

applications/text_classification/hierarchical/retrieval_based/train.py

Lines changed: 10 additions & 8 deletions
@@ -30,7 +30,7 @@
 from model import SemanticIndexBatchNeg
 from data import read_text_pair, convert_example, create_dataloader, gen_id2corpus, gen_text_file, convert_corpus_example
 from data import convert_label_example
-from ann_util import build_index
+from data import build_index
 
 # yapf: disable
 parser = argparse.ArgumentParser()
@@ -62,19 +62,16 @@
 parser.add_argument('--log_steps', type=int, default=10,
                     help="Interval steps to print log")
 parser.add_argument("--train_set_file", type=str,
-                    default='./recall/train.csv',
+                    default='./data/train.txt',
                     help="The full path of train_set_file.")
-parser.add_argument("--dev_set_file", type=str,
-                    default='./recall/dev.csv',
-                    help="The full path of dev_set_file.")
 parser.add_argument("--margin", default=0.2, type=float,
                     help="Margin between pos_sample and neg_samples")
 parser.add_argument("--scale", default=30, type=int,
                     help="Scale for pair-wise margin_rank_loss")
-parser.add_argument("--corpus_file", type=str, default='./recall/corpus.csv',
+parser.add_argument("--corpus_file", type=str, default='./data/label.txt',
                     help="The full path of input file")
 parser.add_argument("--similar_text_pair_file", type=str,
-                    default='./recall/dev.csv',
+                    default='./data/dev.txt',
                     help="The full path of similar text pair file")
 parser.add_argument("--recall_result_dir", type=str, default='./recall_result_dir',
                     help="The full path of recall result file to save")
@@ -113,7 +110,12 @@ def evaluate(model, corpus_data_loader, query_data_loader, recall_result_file,
              text_list, id2corpus):
     # Load pretrained semantic model
     inner_model = model._layers
-    final_index = build_index(args, corpus_data_loader, inner_model)
+    final_index = build_index(corpus_data_loader,
+                              inner_model,
+                              output_emb_size=args.output_emb_size,
+                              hnsw_max_elements=args.hnsw_max_elements,
+                              hnsw_ef=args.hnsw_ef,
+                              hnsw_m=args.hnsw_m)
     query_embedding = inner_model.get_semantic_embedding(query_data_loader)
     with open(recall_result_file, 'w', encoding='utf-8') as f:
         for batch_index, batch_query_embedding in enumerate(query_embedding):
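Downstream, evaluate.py scores these recall results. As a rough self-contained sketch (our illustration, not the repo's implementation), Recall@N over a result file of query/recalled_text/score tab-separated rows, written in consecutive blocks of recall_num rows per query, can be computed like this:

```
def recall_at_n(recall_result_file, text2similar, recall_num=10, n=1):
    # Fraction of queries whose labeled similar text appears in the top n.
    with open(recall_result_file, encoding="utf-8") as f:
        rows = [line.rstrip("\n").split("\t") for line in f]
    hits = total = 0
    for start in range(0, len(rows), recall_num):
        group = rows[start:start + recall_num]
        query = group[0][0]
        total += 1
        if any(text2similar.get(query) == recalled
               for _, recalled, _ in group[:n]):
            hits += 1
    return hits / max(total, 1)
```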

applications/text_classification/multi_class/retrieval_based/README.md

Lines changed: 3 additions & 4 deletions
@@ -29,7 +29,6 @@
 |—— base_model.py # semantic index model base class
 |—— train.py # main training script for the In-batch Negatives strategy
 |—— model.py # core network structure of the In-batch Negatives strategy
-|—— ann_util.py # helper functions for building the ANN index
 
 |—— recall.py # recalls texts similar to a given text from the corpus, based on the trained semantic index model
 |—— evaluate.py # computes evaluation metrics from the recall results and the evaluation set
@@ -147,7 +146,7 @@ unzip webtext2019zh_qa.zip
 
 ### Single-machine single-GPU / multi-GPU training
 
-This trains with multiple GPUs on a single machine; the command below assigns GPUs 0, 1, 2 and 3. For single-GPU training, simply set the `--gpus` argument to the ID of that card.
+This trains with multiple GPUs on a single machine; the command below assigns GPUs 0 and 1. For single-GPU training, simply set the `--gpus` argument to the ID of that card.
 
 To train on CPU, remove the `--gpus` argument and set `device` to cpu; see the training settings in train.sh for details.
 
@@ -156,7 +155,7 @@ unzip webtext2019zh_qa.zip
 ```
 root_path=inbatch
 data_path=data
-python -u -m paddle.distributed.launch --gpus "0,1,2,3" \
+python -u -m paddle.distributed.launch --gpus "0,1" \
     train.py \
     --device gpu \
     --save_dir ./checkpoints/${root_path} \
@@ -172,7 +171,7 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3" \
     --recall_result_file "recall_result.txt" \
     --train_set_file ${data_path}/train.txt \
     --corpus_file ${data_path}/label.txt \
-    --similar_text_pair ${data_path}/dev.txt \
+    --similar_text_pair_file ${data_path}/dev.txt \
     --evaluate True
 ```

applications/text_classification/multi_class/retrieval_based/ann_util.py

Lines changed: 0 additions & 55 deletions
This file was deleted.
