by Sunlly
2022.6.16
Paper: "Dense Passage Retrieval for Open-Domain Question Answering" (2020)
GitHub code: https://github.com/facebookresearch/DPR
- Create a new container:
docker run -itd -m 10g -v [host dir]:[container dir] --gpus all --name [container name] --shm-size="2g" pytorch/pytorch
- Create a directory and clone the project:
git clone [email protected]:facebookresearch/DPR.git
cd DPR
- Install dependencies:
pip install .
- Run:
python train_dense_encoder.py \
train_datasets=[list of train datasets, comma separated without spaces] \
dev_datasets=[list of dev datasets, comma separated without spaces] \
train=biencoder_local \
output_dir={path to checkpoints dir}
Using the NQ dataset as a test:
python train_dense_encoder.py \
train_datasets=nq-train.json \
dev_datasets=nq-dev.json \
train=biencoder_local \
output_dir=test_nq_20220616
After editing the args in train_dense_encoder.py, the script can be run and debugged directly.
- Run train_dense_encoder.py
Error: OSError: [E050] Can't find model 'en_core_web_sm'. It doesn't seem to be a Python package or a valid path to a data directory.
Cause: the package download failed.
Fix: python -m spacy download en_core_web_sm did not help;
installing the wheel directly did: pip install en_core_web_sm-3.3.0-py3-none-any.whl
Installation succeeded.
1. Process the WikiSQL train/test sets and train DPR
Process the dataset: for the previously filtered test set, keep only the tables that actually appear in it and delete the rest (wikisql_remove_tables_out100.py), producing a new test.tables.jsonl (test.tables_remove_out100.jsonl).
test.tables.jsonl total: 5230 tables; kept: 4628; removed: 602
Modify the wikisql_generatedata code so WikiSQL is converted into DPR's training-data format, which looks like this:
[
{
"dataset": "nq_dev_psgs_w100",
"question": "who sings does he love me with reba",
"answers": [
"Linda Davis"
],
"positive_ctxs": [
{
"title": "Does He Love You",
"text": "Does He Love You \"Does He Love You\" is a song written by Sandy Knox and Billy Stritch, and recorded as a duet by American country music artists Reba McEntire and Linda Davis. ",
"score": 13.394315,
"title_score": 0,
"passage_id": "11828866"
},
{
"title": "Red Sandy Spika dress of Reba McEntire",
"text": "Red Sandy Spika dress of Reba McEntire American recording artist Reba McEntire wore a sheer red dress to the 1993 Country Music Association Awards ceremony on September 29, 1993.",
"score": 12.924647,
"title_score": 0,
"passage_id": "15632586"
}
],
"negative_ctxs": [
{
"title": "Cormac McCarthy",
"text": "chores of the house, Lee was asked by Cormac to also get a day job so he could focus on his novel writing. ",
"score": 0,
"title_score": 0,
"passage_id": "2145653"
},
{
"title": "Pragmatic Sanction of 1549",
"text": "one heir, Charles effectively united the Netherlands as one entity. ",
"score": 0,
"title_score": 0,
"passage_id": "2271902"
}
]
},
{
"dataset": "nq_dev_psgs_w100",
"question": "who sings does he love me with reba",
"answers": [
"Linda Davis"
],
...
}
]
According to the project's GitHub page, score is not actually used by the model, but it is included in the dataset.
** Unclear how the BM25 negatives are folded into model training.
Wrote a conversion script (data_generate_dpr.py) that turns the original/filtered dataset into DPR's input format, with self-defined field names:
neg_tables = []
for i in range(neg_num):
    # pick a random table id that differs from the question's own table
    random_idx = random.randrange(len(tables_ids))
    random_table_id = tables_ids[random_idx]
    while random_table_id == origin_table_id:
        random_idx = random.randrange(len(tables_ids))
        random_table_id = tables_ids[random_idx]
    random_table_content = tables[random_table_id]
    neg_sample = {}
    neg_sample["table_id"] = random_table_id
    neg_sample["content"] = random_table_content
    neg_tables.append(neg_sample)
item_out = {}
item_out["dataset"] = "wikisql_" + phase
item_out["question"] = raw_sample["question"]
item_out["answers"] = [raw_sample["sql"]]  # DPR's format expects "answers"
item_out["positive_ctxs"] = [pos_sample]
item_out["negative_ctxs"] = neg_tables
# one JSON object per line keeps the output valid JSON Lines
f.write(json.dumps(item_out) + "\n")
For negative_ctxs, 20 tables are sampled at random, matching the training batch size.
** (It is still unclear whether the original DPR code trains in-batch as the paper describes, i.e. whether negatives come from each sample's negative_ctxs or from the other samples in the same batch.)
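For reference, the in-batch strategy from the paper can be sketched as below; this is an illustrative toy (function name and scores are made up), not the actual BiEncoderNllLoss implementation:

```python
import math

def in_batch_nll(scores, positive_idx=0):
    # scores: similarities between one question and every passage in the
    # batch: its own positive first, then the other questions' positives
    # (which act as "in-batch" negatives) plus any explicit negatives
    denom = sum(math.exp(s) for s in scores)
    return -math.log(math.exp(scores[positive_idx]) / denom)

# the question scores 5.0 against its own table and much lower against
# the other samples' positives, so the loss is near zero
loss = in_batch_nll([5.0, 1.0, 0.5])
```

The key point is that every other question's positive passage doubles as a free negative, so a batch of size B gives B-1 negatives per question at no extra encoding cost.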
Modify the field-name loading code in the original DPR source:
class JsonQADataset(Dataset):
...
def create_passage(ctx: dict):
return BiEncoderPassage(
## changed field names by Sunlly
# normalize_passage(ctx["text"]) if self.normalize else ctx["text"],
# ctx["title"],
normalize_passage(ctx["content"]) if self.normalize else ctx["content"],
ctx["table_id"],
)
Run train_dense_encoder.py to start training.
Problem: the table token sequences are too long, triggering a warning that the overflow will be truncated, keeping only the leading tokens:
** (The paper mentions this too: overlong passages are split into several segments, and training passages should ideally have equal length.)
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Fix: add at the top of the file:
import transformers
transformers.logging.set_verbosity_error()
This sets the verbosity to ERROR so that only errors are reported.
** Still, after truncation the table content is incomplete; will that hurt table retrieval? (The hope is the model learns that a question and a table sharing tokens should score as more similar.)
** Tip: DPR writes its training log via the logging module, which works better than nohup output; worth learning from later.
After training on WikiSQL, evaluation on the test set was poor:
testing only on the test pos/neg samples: correct prediction ratio 4541/13216 ~ 0.343599
2. Download the official model checkpoint and generate embeddings (encode passages as vectors)
Since the model above performed badly, first get the code running with the officially trained model.
Download the official checkpoint:
"checkpoint.retriever.single.nq.bert-base-encoder": {
"s3_url": "https://dl.fbaipublicfiles.com/dpr/checkpoint/retriever/single/nq/hf_bert_base.cp",
"original_ext": ".cp",
"compressed": False,
"desc": "Biencoder weights trained on NQ data and HF bert-base-uncased model",
},
Modify the config parameters in generate_dense_embeddings.py:
def main(cfg: DictConfig):
## add args by Sunlly
# cfg.model_file="/nlp_files/DPR/outputs/2022-06-16/08-59-29/test_nq_20220616/dpr_biencoder.3"
cfg.model_file="/nlp_files/DPR/model/hf_bert_base.cp"
# cfg.ctx_src="/nlp_files/DPR/nq-dev-small.json"
# cfg.ctx_src="/nlp_files/DPR/downloads/data/wikipedia_split/psgs_w100.tsv"
cfg.ctx_src="dpr_wiki"
cfg.out_file="/nlp_files/DPR/embeddings/nq"
Modify default_sources.yaml:
dpr_wiki:
_target_: dpr.data.retriever_data.CsvCtxSrc
# file: data.wikipedia_split.psgs_w100
file: "/nlp_files/DPR/downloads/data/wikipedia_split/psgs_w100.tsv"
id_prefix: 'wiki:'
dpr_nq:
## text form used for encoding
_target_: dpr.data.retriever_data.CsvQASrc
# file: /nlp_files/DPR/nq-dev-small.json
file: /nlp_files/DPR/nq-test.csv
# id_prefix: 'nq-small:'
Note: cfg.ctx_src cannot point to a path directly; the name must be defined in default_sources.yaml. With dpr_wiki and file: data.wikipedia_split.psgs_w100,
psgs_w100.tsv (12 GB) is downloaded automatically. The NQ dataset is not used here; this step just reproduces the official example.
TSV file example:
** Later, a tables dataset could be built in the same TSV format.
Also modified gen_embs.yaml, but that was of little use.
Changed end_idx in the original code:
# end_idx = start_idx + shard_size
end_idx = start_idx + 200
Only a quick example was needed, and the original dataset has far too many passages, so effectively only the first 200 passages get ctx vectors.
Run the code:
python generate_dense_embeddings.py
The resulting vectors land in embeddings/, as a binary file that cannot be opened directly.
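To my understanding the shard is plain Python pickle output, a list of (passage_id, vector) tuples, so it can be inspected like this (file name and toy vectors are illustrative; the real vectors are numpy arrays):

```python
import pickle

# mimic one embeddings shard: a pickled list of (id, vector) tuples
shard = [("wiki:1", [0.1, 0.2, 0.3]), ("wiki:2", [0.4, 0.5, 0.6])]
with open("embeddings_example_0", "wb") as f:
    pickle.dump(shard, f)

# load it back and inspect a few entries
with open("embeddings_example_0", "rb") as f:
    vectors = pickle.load(f)
first_id, first_vec = vectors[0]
```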
3. Encode the questions and retrieve against the generated ctx vectors
The code is in dense_retriever.py.
Modify the config:
@hydra.main(config_path="conf", config_name="dense_retriever")
def main(cfg: DictConfig):
cfg = setup_cfg_gpu(cfg)
## add args by Sunlly
cfg.model_file="/nlp_files/DPR/model/hf_bert_base.cp"
cfg.qa_dataset="nq_test" # defined in /nlp_files/DPR/conf/datasets/retriever_default.yaml
cfg.ctx_datatsets=["dpr_wiki"] ## must be a list
cfg.encoded_ctx_files=["/nlp_files/DPR/embeddings/nq_0"] ## must be a list
cfg.out_file="/nlp_files/DPR/retriever_validation"
retriever_default.yaml:
nq_test:
_target_: dpr.data.retriever_data.CsvQASrc
# file: data.retriever.qas.nq-test
file: "/nlp_files/DPR/nq-test.csv"
Problem encountered:
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
Fixed by changing the code:
# max_vector_len = max(q_t.size(1) for q_t in batch_tensors)
# min_vector_len = min(q_t.size(1) for q_t in batch_tensors)
max_vector_len = max(q_t.size(0) for q_t in batch_tensors)
min_vector_len = min(q_t.size(0) for q_t in batch_tensors)
run code:
python dense_retriever.py
Result:
the top 100 matching passages were found. Since this dataset and the questions do not actually correspond, accuracy cannot be evaluated.
Each retrieved hit carries a score, usable for later reranking.
The number of retrieved passages can be set in dense_retriever.yaml:
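For example (assuming the key is n_docs, as in DPR's conf/dense_retriever.yaml):

```yaml
# conf/dense_retriever.yaml
n_docs: 100        # number of passages to retrieve per question
```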
1. Download the nq-table model checkpoint and, from the generated ctx vectors, encode nq-table
Sample from the tables part of the nq_table dataset:
{
"columns": [{
"text": ""
}, {
"text": ""
}, {
"text": "Born"
}, {
"text": "Residence"
}, {
"text": "Occupation"
}, {
"text": "Years\u00a0active"
}, {
"text": "Height"
}, {
"text": "Television"
}, {
"text": "Children"
}],
"rows": [{
"cells": [{
"text": "Lesley Joseph"
}, {
"text": "Joseph in Cardiff, Wales, May 2011"
}, {
"text": "Lesley Diana Joseph 14 October 1945 (age\u00a072) Finsbury Park, Haringey, London, England"
}, {
"text": "Hampstead, North London"
}, {
"text": "Broadcaster, actress"
}, {
"text": "1969\u2013present"
}, {
"text": "5\u00a0ft 2\u00a0in (1.57\u00a0m)"
}, {
"text": "Birds of a Feather"
}, {
"text": "2"
}]
}],
"tableId": "Lesley Joseph_A1D55A57012E3362",
"documentTitle": "Lesley Joseph",
"documentUrl": "https://en.wikipedia.org//w/index.php?title=Lesley_Joseph&oldid=843506707"
}
Take 100 tables as a sample; modify conf/ctx_sources/table_sources.yaml:
nq_table_raw:
_target_: dpr.data.retriever_data.JsonlNQTablesCtxSrc
# file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables.jsonl"
file: "/nlp_files/nqt-retrieval/datasets/nq_table/tables_small.jsonl"
id_prefix: 'nqt:'
Modify the parameter config in generate_embedding.py:
def main(cfg: DictConfig):
## by Sunlly
cfg.model_file="/nlp_files/nqt-retrieval/checkpoint/retriever/single-adv-hn/nq/bert-base-encoder.cp"
cfg.ctx_src="nq_table_raw"
cfg.out_file="/nlp_files/nqt-retrieval/embeddings/nq_table"
run code:
python generate_embedding.py
Problem:
ctx_src = hydra.utils.instantiate(cfg.ctx_sources[cfg.ctx_src])
raised an error, even though ctx_src was correct on inspection.
Fix: comment out the offending imports:
from dpr.data.biencoder_data import (
BiEncoderTable,
get_nq_table_files,
get_processed_table,
# get_processed_table_wiki,
# get_processed_table_wqt,
)
Problem:
error at
if self.id_prefix: sample_id = self.id_prefix + sample['id']
because the sample has no 'id'. Fix: switch ctx_src from nq_table to nq_table_raw.
Problem in biencoder_data.py:
if max_cell_num: cell_list = cell_list[: max_cell_num]
Exception: TypeError: slice indices must be integers or None or have an __index__ method
It fails on the very first table.
Fix: comment out that line.
Problem:
error in table_data.py:
for cell in cell_list:
text = cell['text'].strip()
TypeError: tuple indices must be integers or slices, not str
Fix: use attribute access instead of indexing:
for cell in cell_list:
# text = cell['text'].strip()
text = cell.text.strip()
Not sure why this happens; possibly a Python version difference (the cells arrive as namedtuples rather than dicts). The next error below was fixed the same way.
# row_ids.extend([cell['row_idx'] for _ in cell_token_ids])
# col_ids.extend([cell['col_idx'] for _ in cell_token_ids])
row_ids.extend([cell.row_idx for _ in cell_token_ids])
col_ids.extend([cell.col_idx for _ in cell_token_ids])
The first 100 tables processed successfully.
Also, running process_table.py can pre-process the tables, e.g. to add delimiters.
2. Retrieve against the nq-table encodings
Took 6 questions from nq_table_test as a test set (nq_table_test_small.jsonl); the last one is fairly long.
Run the code:
python dense_retrieval.py
Hit the same problem as above (cell[...]); fixed the same way.
Result (nq_table_test_small_result.jsonl):
{
"question": "where does the brazos river start and stop",
"answers": [
"Llano Estacado",
"Gulf of Mexico"
],
"ctxs": [
{
"id": "nqt:Amazon River_4E0AC7C7AE20D8EE",
"title": "amazon river",
"text": "amazon river name | country | coordinates | image | hidenotes . allpahuayo - mishana national reserve | peru | 3 \u00b0 56\u2032s 73 \u00b0 33\u2032w\ufeff / \ufeff3.933 | [ 51 ] . amacayacu national park | colombia | 3 \u00b0 29\u2032s 72 \u00b0 12\u2032w\ufeff / \ufeff3.483 | [ 52 ] . amazo\u0302nia national park | brazil | 4 \u00b0 26\u2032s 56 \u00b0 50\u2032w\ufeff / \ufeff4.433 | [ 53 ] . anavilhanas national park | brazil | 2 \u00b0 23\u2032s 60 \u00b0 55\u2032w\ufeff / \ufeff2.383 | [ 54 ] .",
"score": "65.01775",
"has_answer": false
},
{
"id": "nqt:Assiniboine River_75DCC362724ADC33",
"title": "assiniboine river",
"text": "assiniboine river location | peak flow , 1995 ( m3 / s ) | mean flow , april ' 95 ( m3 / s ) | mean flow , may ' 95 ( m3 / s ) | max flow , date ( m3 / s ) . russell | 360 may 4 | 34.2 | 46.3 | 504 april 29 , 1922 . brandon | 566 april 26 | 81.1 | 104.0 | 651 may 7 , 1923 . headingley | 300 april 20 | 115.0 | 142.0 | 614 april 27 , 1916 .",
"score": "64.74672",
"has_answer": false
},
{
"id": "nqt:Indus River_78C7B2189DD8401",
"title": "indus river",
"text": "indus river [ show ] v t e hydrology of pakistan | [ show ] v t e hydrology of pakistan.1 . lakes | ansoo baghsar banjosa borith chitta katha dudipatsar hadero . rivers | indus astore bara basol braldu bunhar chenab dasht . coastal | indian ocean arabian sea gulf of oman . categories | lakes rivers .",
"score": "60.91795",
"has_answer": false
},
...
Since only 200 tables were encoded, the retrieved tables may not match exactly, but the content is still fairly relevant overall.
3. Encode and retrieve over the full nq-table dataset
Modify file in conf/ctx_sources/table_sources.yaml to the full nq_table:
nq_table_raw:
_target_: dpr.data.retriever_data.JsonlNQTablesCtxSrc
# file: "/mnt/zhiruow/hitab/table-retrieval/datasets/nq_table/tables.jsonl"
file: "/nlp_files/nqt-retrieval/datasets/nq_table/tables.jsonl"
id_prefix: 'nqt:'
Modify out_file in generate_embeddings.py:
cfg.model_file="/nlp_files/nqt-retrieval/checkpoint/retriever/single-adv-hn/nq/bert-base-encoder.cp"
cfg.ctx_src="nq_table_raw"
cfg.out_file="/nlp_files/nqt-retrieval/embeddings/nq_table"
Run:
nohup python -u generate_embeddings.py
nq_table: 168,989 tables in total (nearly 170k); encoding took about 3.5 h.
Retrieval:
dense_retrieval.py
Result:
4. How to load wikisql_table into the DPR embedding, flattened or with delimiters?
Merge all tables from WikiSQL's train/dev/test into one wikisql_tables set.
- Remove the train/dev/test questions whose execution result is empty
split | question | remove_none_answer |
---|---|---|
train | 56355 | 52032 |
dev | 8421 | 7764 |
test | 15878 | 14599 |
- Merge the original tables files (WikiSQL/data/train.tables.jsonl, dev.tables.jsonl, test.tables.jsonl) into wikisql_tables_process_bm25.jsonl.
The BM25 dataset is built as follows (table_process_process_bm25.py):
loop over the three phases, concatenate page_title (if any) / header / rows, and add a column field. To tell the splits apart later, prefix each id with train_/dev_/test_.
Then merge everything into the full wikisql_tables set (wikisql_tables_process_bm25.jsonl).
split | origin_table_num |
---|---|
train | 18585 |
dev | 2716 |
test | 5230 |
total | 26531 |
Import all of it into Elasticsearch as index=wikisql_tables (upload_table_data_test.py); result:
{'count': 26531, '_shards': {'total': 1, 'successful': 1, 'skipped': 0, 'failed': 0}}
Import succeeded.
3. Starting from the question sets with empty execution results removed, additionally drop the questions and tables that fall outside the BM25 top 100 (the exact cutoff is debatable) in all three splits.
Modify wikisql_elastic_python_test_remove_out100_tables.py: set top_k to 200, keep the default BM25 parameters, add phase and run (hisresult also gains a phase field). Results:
train: 8698 questions filtered out
test: 2543 questions filtered out
split | remove_none_answer | remove_bm25_out |
---|---|---|
train | 52032 | 43334 |
dev | 7764 | 6358 |
test | 14599 | 12056 |
This serves as the final open_wikisql_question set.
Remove the tables not covered by any surviving question (wikisql_remove_tables_out200.py).
22319 tables remain; 4212 removed in total.
- Build the DPR table data, in two variants: plain flattening and with delimiters.
table_process_total_process_dpr.py
Modified it to generate both variants of the table dataset:
has_delimiters=True
26531 tables --> 22319 tables
Filter the tables the same way as above;
load the filtered tables into Elasticsearch:
"data_preprocess/wikisql_tables_remove_out200_dpr_with_delimiter.jsonl"
- Generate DPR-format WikiSQL training data
Strategy: neg: 61 random tables; soft_hard_neg: 2 tables sampled at random from the BM25 top 20; pos: exactly 1. (data_generate_traindata_dpr_out200.py)
Problem during retrieval:
elastic_transport.ConnectionError: Connection error caused by: ConnectionError(Connection error caused by: ProtocolError(('Connection aborted.', ConnectionResetError(104, 'Connection reset by peer'))))
Probably caused by querying too frequently?
Add a sleep:
hard_neg_tables=get_tables(question,origin_table_id)
if cnt%10==0:
print("cnt: ",cnt)
time.sleep(0.1)
elastic_transport.ConnectionError: Connection error caused by: ConnectionError(Connection error caused by: NewConnectionError(<urllib3.connection.HTTPConnection object at 0x7f67d216c750>: Failed to establish a new connection: [Errno 111] Connection refused))
Restart ES:
su esuser
./bin/elasticsearch
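A fixed sleep helps, but a retry with backoff is more robust against these resets. A minimal sketch, with a stand-in flaky function since the real call is an Elasticsearch search (names here are illustrative):

```python
import time

def with_retries(fn, attempts=3, base_delay=0.5):
    # retry a flaky call (e.g. an Elasticsearch query) with
    # exponential backoff instead of a fixed sleep every N requests
    for attempt in range(attempts):
        try:
            return fn()
        except ConnectionError:
            if attempt == attempts - 1:
                raise
            time.sleep(base_delay * 2 ** attempt)

calls = []
def flaky_search():
    # fails twice, then succeeds, to exercise the retry path
    calls.append(1)
    if len(calls) < 3:
        raise ConnectionError("Connection reset by peer")
    return "hits"

result = with_retries(flaky_search, base_delay=0.01)
```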
Generated results: training/testing with and without delimiter, 4 files in total; moved them into the DPR wikisql_data folder on host4:
DPR/wikisql_data/wikisql_dpr_training_data_with_delimiter.json, /nlp_files/DPR/wikisql_data/wikisql_dpr_testing_data_with_delimiter.json
- Train raw DPR on the generated WikiSQL-format training data
Problem: train loss falls while test loss rises, and the learning rate is still ramping up; eval is terrible.
My senior labmate says this is overfitting.
Then set batch_size=1 and tuned the number of negatives; train set 43334, test set 12048, average loss = 35.6, accuracy 5/12048, still terrible.
Where does the neg count actually act? It seems to have no effect.
Took every 10th question from the ~40000 train set for a 2000-question training subset, every 10th from test for a 200-question eval subset, and further lowered the learning rate.
After that the test loss fell normally, and eval accuracy (presumably 1-of-64) reached 54%.
Kept lowering the learning rate, raised dropout from 0.1 to 0.2 and max_length from 256 to 512. The extra GPU memory meant even batch_size=1 no longer fit, so other_negatives in conf/train/biencoder_local.yaml was set to 20. One epoch takes ~40 min training plus ~5 min testing; training stops automatically after 20 epochs. Losses:
epoch | train | test | test_acc |
---|---|---|---|
0 | 2.617569 | 2.219180 | 113/200 ~ 0.565000 |
1 | 1.258608 | 1.907552 | 126/200 ~ 0.630000 |
2 | 1.210599 | 1.711939 | **137/200 ~ 0.685000** |
3 | 1.280452 | 1.791353 | 132/200 ~ 0.660000 |
4 | 1.126368 | 1.906223 | 131/200 ~ 0.655000 |
5 | 1.113602 | 1.953571 | 127/200 ~ 0.635000 |
6 | 0.917539 | 2.063476 | 126/200 ~ 0.630000 |
7 | 0.948463 | 2.347768 | 124/200 ~ 0.620000 |
8 | 0.937369 | 2.728052 | 115/200 ~ 0.575000 |
9 | 0.601034 | 3.316344 | 107/200 ~ 0.535000 |
10 | 0.658453 | 4.087913 | 99/200 ~ 0.495000 |
... | ... | ... | ... |
19 | 0.489357 | 7.573456 | 82/200 ~ 0.410000 |
learning_rate: 5.0e-07; the best model is outputs/2022-07-01/13-09-01/test_wikisql_20220701/dpr_biencoder.2
Beyond that it is textbook overfitting. Two follow-up options: 1. tune dropout again (0.2->0.1, 0.2->0.3); 2. enlarge the training set (2000->4000, 2000->6000).
With 2000 samples and dropout 0.2->0.1, learning_rate 5.0e-07: peaked at 62%, then test performance declined; overall worse than before.
With 2000 samples and dropout=0.3, learning_rate 5.0e-07: train loss = 1.2+, test loss = 2.187767, 60% accuracy at the end of epoch 1; training paused after that.
With 8000 samples, dropout=0.2, learning_rate 4.0e-07, eval_per_epoch: 4 (one eval every 2000 samples), other_negatives: 20: the peak came halfway through an epoch: NLL Validation: loss = 1.679720. correct prediction ratio 123/200 ~ 0.615000; accuracy dropped off sharply afterwards.
With 40000 training samples + 500 test samples, learning_rate 3.0e-07, gradient_accumulation_steps: 8, dropout: 0.2: best so far within epoch 1 is 311/500 ~ 0.622000, declining after that. The learning rate feels a bit low.
2022.7.6
After many attempts, the model overfits extremely easily. By resuming training from checkpoints, the best test acc reached 71%. Lowering other_negatives from 20 to 10 helped, suggesting the model is bad at classifying other_neg samples.
Analysis points to the dataset itself, so regenerated examples with hard_neg=1, other_neg=20, took train=3000, test=500, and retrained from the previous checkpoint.
Another possible cause: cfg.max_length for the questions is too long, encouraging overfitting.
5. Embed wikisql_table, retrieve, and evaluate the retrieval results.
embedding:
Modify default_sources.yaml:
dpr_wikisql:
_target_: dpr.data.retriever_data.JsonlWikiSQLCtxSrc
file: "/nlp_files/DPR/wikisql_data/wikisql_tables_remove_out200_dpr_with_delimiter.jsonl"
Modify dpr/data/retriever_data.py: wrote a data-loading class that reuses BiEncoderPassage, putting column_meta into its title slot.
# by Sunlly
class JsonlWikiSQLCtxSrc(RetrieverData):
def __init__(
self,
file: str,
# id_col: int = 0,
# text_col: int = 1,
# meta_col: int = 2,
id_col: str = "id",
text_col: str = "content",
meta_col: str = "column_meta",
id_prefix: str = None,
normalize: bool = False,
):
super().__init__(file)
self.text_col = text_col
self.meta_col = meta_col
self.id_col = id_col
self.id_prefix = id_prefix
self.normalize = normalize
def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]):
super().load_data()
logger.info("Reading file %s", self.file)
with jsonlines.open(self.file) as r:
for row in r:
sample_id = row[self.id_col]
passage = row[self.text_col].strip('"')
ctxs[sample_id] = BiEncoderPassage(passage, row[self.meta_col])
Modify generate_embedding.py so the title slot (actually column_meta) is not encoded into the embedding:
def gen_ctx_vectors:
# batch_token_tensors = [
# tensorizer.text_to_tensor(ctx[1].text, title=ctx[1].title if insert_title else None) for ctx in batch
# ]
# by Sunlly: don't load column_meta
batch_token_tensors = [
tensorizer.text_to_tensor(ctx[1].text, title="" if insert_title else None) for ctx in batch
]
Modify the config and run generate_dense_embeddings.py:
# end_idx = start_idx + shard_size
# by Sunlly
end_idx = start_idx + 200
def main(cfg: DictConfig):
cfg.model_file="/nlp_files/DPR/outputs/2022-07-06/02-17-21/test_wikisql_20220706/dpr_biencoder.3"
cfg.ctx_src="dpr_wikisql"
cfg.out_file="/nlp_files/DPR/embeddings/wikisql_tables"
Generate embeddings for the first 200 tables:
Run dense_retriever.py, temporarily with nq_test as the QA set:
cfg.model_file="/nlp_files/DPR/outputs/2022-07-06/02-17-21/test_wikisql_20220706/dpr_biencoder.3"
cfg.qa_dataset="nq_test"
cfg.ctx_datatsets=["dpr_wikisql"] ## must be a list; defined in default_sources.yaml
cfg.encoded_ctx_files=["/nlp_files/DPR/embeddings/wikisql_tables_0"] ## must be a list
cfg.out_file="/nlp_files/DPR/wikisql_retriever_validation"
[
{
"question": "who got the first nobel prize in physics",
"answers": [
"Wilhelm Conrad R\u00f6ntgen"
],
"ctxs": [
{
"id": "train_1-10236830-4",
"title": [
[
"Nomination",
"string",
null
],
[
"Actors Name",
"string",
null
],
[
"Film Name",
"string",
null
],
[
"Director",
"string",
null
],
[
"Country",
"string",
null
]
],
"text": "StozharyNomination, Actors Name, Film Name, Director, Country. Best Actor in a Leading Role, Yuriy Dubrovin, Okraina, Pyotr Lutsik, Ukraine. Best Actor in a Leading Role, Zurab Begalishvili, Zdes Rassvet, Zaza Urushadze, Georgia. Best Actress in a Leading Role, Galina Bokashevskaya, Totalitarian Romance, Vyacheslav Sorokin, Russia. Best Actor in a Supporting Role, Vsevolod Shilovskiy, Barhanov and his Bodyguard, Valeriy Lanskoy, Russia. Best Actor in a Supporting Role, Dragan Nikoli\u0107, Barrel of Gunpowder, Goran Paskaljevic, Serbia. Best Actress in a Supporting Role, Zora Manojlovic, Rane, Srdjan Dragojevic, Serbia. Best Debut, Agnieszka W\u0142odarczyk, Sara, Maciej \u015alesicki, Poland. ",
"score": "78.355064",
"has_answer": false
},
{...},
...
}
}
]
Modify the QA source:
wikisql_test:
_target_: dpr.data.retriever_data.JsonlWikiSQLQASrc
file: "/nlp_files/DPR/wikisql_data/test_remove_out200_table_in_total_tables.jsonl"
Modify retriever_data.py:
# by Sunlly
class JsonlWikiSQLQASrc(QASrc):
def __init__(
self,
file: str,
selector: DictConfig = None,
question_attr: str = "question",
answers_attr: str = "sql",
id_attr: str = "table_id",
special_query_token: str = None,
query_special_suffix: str = None,
):
super().__init__(file, selector, special_query_token, query_special_suffix)
self.question_attr = question_attr
self.answers_attr = answers_attr
self.id_attr = id_attr
def load_data(self):
super().load_data()
data = []
with jsonlines.open(self.file, mode="r") as jsonl_reader:
for jline in jsonl_reader:
question = jline[self.question_attr]
answers = jline[self.answers_attr] if self.answers_attr in jline else []
id = None
if self.id_attr in jline:
id = jline[self.id_attr]
data.append(QASample(self._process_question(question), id, answers))
self.data = data
Also tweaked the code so the JSON result is written without line breaks:
def save_results:
with open(out_file, "w") as writer:
## by Sunlly: no pretty-printing
writer.write(json.dumps(merged_data))
# writer.write(json.dumps(merged_data, indent=4) + "\n")
logger.info("Saved results * scores to %s", out_file)
There is no official retriever-only evaluation, so wrote one: /nlp_files/DPR/wikisql_pdr_evaluate.py
The results are very disappointing, far below the baseline.
(Like a weak student who, fearing a fail, crams desperately before the exam, and then sees the grades and has indeed failed.)
After discussing with my senior labmate, decided to modify the loss: for the 20 other_neg, scale the original e^{sim(q,p^-)} by a factor of 3, roughly equivalent to 60 other_neg:
class BiEncoderNllLoss(object):
def calc:
scores = self.get_scores(q_vectors, ctx_vectors)
# by Sunlly: expand the effect of other_neg
n = scores.shape[1]
# for a 3x scale, neg_score_times = ln(3): adding it to a logit
# triples that column's e^{sim} in the softmax denominator
neg_time_tensor = torch.zeros(1, n)
for i in range(n):
    if i >= 2:  # skip the positive and the hard negative columns
        neg_time_tensor[0][i] = neg_score_times
neg_time_tensor_gpu = neg_time_tensor.to(device='cuda')
scores = scores + neg_time_tensor_gpu
print(scores)
##
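Note that adding a constant c to a logit multiplies its e^{sim} by e^c, which is why adding ln(3) to the 20 other_neg columns makes them weigh like ~60 negatives in the softmax denominator. A list-based sketch of the same idea (the column layout pos, hard_neg, then other_neg is an assumption, and the function name is made up):

```python
import math

def weight_other_negatives(scores, skip=2, neg_score_times=math.log(3)):
    # adding ln(3) to a logit triples its exp() in the softmax
    # denominator; the first `skip` columns (pos + hard_neg) stay as-is
    return [s if i < skip else s + neg_score_times
            for i, s in enumerate(scores)]

weighted = weight_other_negatives([4.0, 2.0, 1.0, 0.5])
```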
Then set grad_accumulation to 32, the learning rate to 1e-5, and dropout to 1.5, and retrained, moving as close to the paper's hyperparameters as possible. Training is slow; with loss accumulation the reported loss is the original loss/32.
Add saving of the best checkpoint:
## by Sunlly save best checkpoint
def _save_checkpoint(self, scheduler, epoch: int, offset: int, best_checkpoint=False) -> str:
cfg = self.cfg
model_to_save = get_model_obj(self.biencoder)
cp = os.path.join(cfg.output_dir, cfg.checkpoint_file_name + "." + str(epoch))
## by Sunlly save best checkpoint
if best_checkpoint:
cp = os.path.join(cfg.output_dir, cfg.checkpoint_file_name + "." + str(epoch)+"best")
##
def validate_and_save(self, epoch: int, iteration: int, scheduler):
if validation_loss<0.5:
best_cp_name = self._save_checkpoint(scheduler, epoch, iteration,best_checkpoint=True)
logger.info("Save New Best validation checkpoint %s", best_cp_name)
Ran over the weekend for two days with no real progress.
Suspect the huge variance in table length keeps the model from learning well.
Filtered the train questions, sampling with gap=5 so that roughly one table maps to one question, then set other_neg=60 and length=150, truncating anything longer, so all samples' table lengths are roughly equal.
Also reverted the loss change, from ln(3) back to ln(1).
Kept training. Basically no progress this half month; anxious and thoroughly demoralized, close to an emotional low.
7. Compare BM25 on the wikisql_tables set without EG (pure BM25) and with EG filtering
Pure BM25 retrieval on the test set (k1=1.2, b=0.75):

top_k | count | hit_accuracy |
---|---|---|
1 | 4595 | 0.38113802256138024 |
5 | 7300 | 0.6055076310550763 |
10 | 8279 | 0.6867120106171201 |
20 | 9225 | 0.7651791639017916 |
50 | 10404 | 0.8629727936297279 |
100 | 11235 | 0.9319011280690113 |
200 | 11788 | 0.977770404777704 |

hit@200 should theoretically be 100%; it is not, because after removing the out-of-top-200 questions the tables were filtered one more time, which affected the result. The gap should not matter much, and there are no plans to update the tables again.
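The hit_accuracy numbers above come from a hit@k check; a minimal version (function name and ids are illustrative):

```python
def hit_at_k(retrieved_ids, gold_id, ks=(1, 5, 10)):
    # a question is a hit at k if its gold table id appears in the
    # top-k retrieved ids; hit_accuracy is the mean over all questions
    return {k: gold_id in retrieved_ids[:k] for k in ks}

hits = hit_at_k(["t9", "t3", "t1"], gold_id="t3")
```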
Then pick a HydraNet model that performs well on the test set, and switch its prediction code to CPU (the GPU was in use by tyh):
class HydraTorch(BaseModel):
def __init__(self, config):
self.config = config
self.model = HydraNet(config)
if torch.cuda.device_count() > 1:
self.model = nn.DataParallel(self.model)
# self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.device = torch.device("cpu")
Run on data_preprocess/wikitest_out200.jsonl:
Candidate 1: /nlp_files/HydraNet-WikiSQL/output/20220630_033334/model_1.pt (dev 82.9, wikitest_out100.jsonl 83.8)
===HydraNet=== sel_acc: 0.976692103516921 agg_acc: 0.9265925680159257 wcn_acc: 0.9774386197743862 wcc_acc: 0.9418546781685467 wco_acc: 0.9716323822163239 wcv_acc: 0.9472461844724619
===HydraNet+EG=== sel_acc: 0.976692103516921 agg_acc: 0.9265925680159257 wcn_acc: 0.9810053085600531 wcc_acc: 0.9703881884538819 wco_acc: 0.9759455872594559 wcv_acc: 0.9732913072329131
no EG: { "ex_accuracy": 0.8809721300597213, "lf_accuracy": 0.8340245520902455 }
+EG: { "ex_accuracy": 0.9273390842733908, "lf_accuracy": 0.8752488387524884 }
Candidate 2: /nlp_files/HydraNet-WikiSQL/output/20220628_022212/model_2.pt
===HydraNet=== sel_acc: 0.9747013934970139 agg_acc: 0.9197080291970803 wcn_acc: 0.9752820172528202 wcc_acc: 0.9391174518911746 wco_acc: 0.9692269409422694 wcv_acc: 0.9477438619774387
===HydraNet+EG=== sel_acc: 0.9747013934970139 agg_acc: 0.9197080291970803 wcn_acc: 0.9800099535500996 wcc_acc: 0.9698075646980756 wco_acc: 0.9742037159920371 wcv_acc: 0.9724618447246185
no EG: { "ex_accuracy": 0.8779031187790312, "lf_accuracy": 0.8283012607830126 }
+EG: { "ex_accuracy": 0.9215328467153284, "lf_accuracy": 0.8673689449236894 }
Candidate 2 trails candidate 1.
Lowered the learning rate and kept training HydraNet. Candidate 3: learning_rate 6e-6, no EG:
/nlp_files/HydraNet-WikiSQL/output/20220702_083211/model_2.pt
[wikidev.jsonl, epoch 2] overall: 83.5, agg: 91.2, sel: 97.8, wn: 98.7, wc: 95.8, op: 99.1, val: 97.4
[wikitest_out200.jsonl, epoch 2] overall: 83.8, agg: 92.5, sel: 97.6, wn: 98.0, wc: 94.7, op: 99.4, val: 96.8
===HydraNet=== sel_acc: 0.9757796947577969 agg_acc: 0.9250995355009953 wcn_acc: 0.9796781685467817 wcc_acc: 0.946914399469144 wco_acc: 0.9732913072329131 wcv_acc: 0.9526376907763769
===HydraNet+EG=== sel_acc: 0.9757796947577969 agg_acc: 0.9250995355009953 wcn_acc: 0.9816688785666888 wcc_acc: 0.9718812209688122 wco_acc: 0.9762773722627737 wcv_acc: 0.9742866622428666
no EG: { "ex_accuracy": 0.8868613138686131, "lf_accuracy": 0.8386695421366954 }
+EG: { "ex_accuracy": 0.9273390842733908, "lf_accuracy": 0.8753317850033179 }
Candidate 3 beats candidate 1 across the board and is the best model so far.
Next, use this model as the execution-guided (EG) stage on top of BM25 retrieval and evaluate.
cp -r output/20220702_083211 ../hydranet/output
Modify the code: on top of the 200 BM25 hits, convert each question with HydraNet to SQL and execute it, then compute the retrieval metrics. (/nlp_files/hydranet/wikisql_elastic_bm25_and_hydranet_predict_final_tables.py)
res_top_k=[1,5,10,20,50,100], with and without EG. With EG, the hits are reordered by execution result into candidate_hits.
def rerank_hits(exe_res, hits, k):
puts the hits with exe_res=True first as candidate_hits;
if there are fewer than k, fills up from the exe_res=False hits in descending score order until there are k.
The retrieved table ids are executed on the matching train/dev/test DB according to their split. The retrieval results are written to test_retrieve_final_table_bm25.jsonl (used later for end-to-end lf and ex tests).
Results at 3210 examples:

top_k | no_EG count | no_EG accuracy | with_EG count | with_EG accuracy |
---|---|---|---|---|
1 | 1306 | 0.40685358255451715 | 1824 | 0.5682242990654206 |
5 | 2008 | 0.6255451713395639 | 2574 | 0.8018691588785046 |
10 | 2290 | 0.7133956386292835 | 2776 | 0.864797507788162 |
20 | 2538 | 0.7906542056074767 | 2922 | 0.9102803738317757 |
50 | 2834 | 0.8828660436137071 | 3060 | 0.9532710280373832 |
100 | 3031 | 0.9442367601246106 | 3131 | 0.9753894080996884 |
EG lifts hit@1 by 16 points, hit@5 by 18, and hit@10 by 15: a sizable improvement.
A bug interrupted the run at example 3246; after fixing it the run resumed, but the earlier results were lost, so that stretch needs to be re-run later.
First 3246 examples (0-3245), accuracies still over the full 12056:

top_k | no_EG count | no_EG accuracy | with_EG count | with_EG accuracy |
---|---|---|---|---|
1 | 1317 | 0.10924021234240212 | 1843 | 0.1528699402786994 |
5 | 2027 | 0.1681320504313205 | 2605 | 0.21607498341074982 |
10 | 2309 | 0.19152289316522894 | 2808 | 0.23291307232913072 |
20 | 2559 | 0.21225945587259457 | 2955 | 0.2451061712010617 |
50 | 2867 | 0.23780690112806901 | 3094 | 0.256635700066357 |
100 | 3067 | 0.25439615129396154 | 3167 | 0.26269077637690774 |

After 3246 (3246-12056):

top_k | no_EG count | no_EG accuracy | with_EG count | with_EG accuracy |
---|---|---|---|---|
1 | 3278 | 0.2718978102189781 | 5040 | 0.418049104180491 |
5 | 5273 | 0.4373755806237558 | 7038 | 0.5837757133377571 |
10 | 5970 | 0.49518911745189115 | 7584 | 0.6290643662906437 |
20 | 6666 | 0.5529197080291971 | 7993 | 0.6629893828798938 |
50 | 7537 | 0.6251658925016589 | 8384 | 0.6954213669542136 |
100 | 8168 | 0.6775049767750497 | 8515 | 0.7062873258128732 |
Combined results (num: 12056):

top_k | no_EG count | no_EG accuracy | with_EG count | with_EG accuracy |
---|---|---|---|---|
1 | 4595 | 0.381138022 | 6883 | 0.570919044459 |
5 | 7300 | 0.6055076310 | 9643 | 0.7998506967 |
10 | 8279 | 0.6867120106 | 10392 | 0.8619774386 |
20 | 9225 | 0.765179163 | 10948 | 0.90809555408 |
50 | 10404 | 0.8629727936297279 | 11478 | 0.952057067 |
100 | 11235 | 0.931901128 | 11682 | 0.968978102 |
8. Since DPR training went poorly and this is strongly tied to the dataset, re-clean the WikiSQL dataset
Build a content field from the header.
Use MinHash to drop tables with similarity above 0.91 (wikisql_table_sim.py); 13254 tables remain (wikisql_tables_process_head_sim_out_0.9.jsonl).
Using the question sets with non-executable/empty-result questions removed (train/dev/test), filter the tables again (wikisql_remove_tables_none_answer.py); 12392 tables remain.
Use the new table set (wikisql_tables_process_head_sim_out_0.9_without_none_answer.jsonl) to filter the corresponding question sets (wikisql_remove_question_not_in_filted_tables.py). Result: train 29624, dev 3236, test 5924.
38784 questions in total.
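The similarity filtering above can be illustrated with exact Jaccard similarity (the real wikisql_table_sim.py uses MinHash to approximate this comparison at scale; function names and the toy tables are illustrative):

```python
def jaccard(a, b):
    a, b = set(a), set(b)
    return len(a & b) / len(a | b)

def dedup_tables(tables, threshold=0.91):
    # keep a table only if it stays below the similarity threshold
    # against every table kept so far
    kept = []
    for table_id, tokens in tables:
        if all(jaccard(tokens, kept_tokens) < threshold
               for _, kept_tokens in kept):
            kept.append((table_id, tokens))
    return kept

tables = [
    ("t1", ["name", "year", "wins"]),
    ("t2", ["name", "year", "wins"]),    # duplicate header, dropped
    ("t3", ["player", "team", "goals"]),
]
kept = dedup_tables(tables)
```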
For now, run the first-version model, i.e. without re-splitting the dataset or decontextualization.
First generate the BM25 data (/nlp_files/hydranet/table_process_total_process_bm25.py) and load it into index=wikisql_tables_out_sim (upload_table_data_test.py).
Disk was full and a new ES index could not be created; deleted unnecessary files and the git repo: rm -rf .git
Generate the initial DPR dataset with delimiters (table_process_total_process_dpr.py), producing the training/testing data as que + hard_neg1 + neg60.
Train with batch_size=2, accumulate_loss=64, other_neg=60, max_length=150.
Mediocre results; at most 54% on the test set.
去语境化: 1.用具体的名次替换代词/名词短语(he definite NP “the copper statue” with “The Statue of Liberty”, or the abbreviated name “Meg” with “Megan “Meg” Griffin”.) 40% 2.略缩词或名词的扩展 3.删除只能在上下文中才能理解的话语标记(therefore) 3% 4.在名词短语中添加修饰语(in XXX) 5.增加修饰整个句子的短语(at xxx) 6.添加有助于显著提高可读性的背景信息(“The Eagles” with “The American rock band The Eagles.”) 10%
试了一下,手工去语境化太麻烦了,一个下午能处理100个表格就不错了,但一共有10000多张表
自动去语境化(do_decontext_by_auto.py) 将question 和table 的head_conten(page-title+section_title+header)转换成 minhash 向量,对于相似度小于 60% 的question ,增加 “in page_title” 的描述。
同时删除所有的没有 page_title的表格和对应的问题。
Total: 36243 questions (data_preprocess/train_without_none_answer_0.9_sim_decontext_auto.jsonl) (dev/test)
Split into train/dev/test at a ratio of 8:1:1.5: 27613 / 3451 / 5179.
11715 tables.
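The 8:1:1.5 split can be sketched as a shuffle followed by proportional slicing (a minimal sketch with an assumed fixed seed; the actual split script is not shown in these notes):

```python
import random

def split_dataset(examples, ratios=(8, 1, 1.5), seed=42):
    """Shuffle and slice into train/dev/test by the given ratio."""
    rng = random.Random(seed)
    ex = list(examples)
    rng.shuffle(ex)
    total = sum(ratios)
    n_train = int(len(ex) * ratios[0] / total)
    n_dev = int(len(ex) * ratios[1] / total)
    return ex[:n_train], ex[n_train:n_train + n_dev], ex[n_train + n_dev:]
```

With 36243 questions this yields splits close to the 27613/3451/5179 counts above (exact boundaries depend on rounding).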
Reprocess with a pipeline similar to the above and build a new DPR training dataset.
When generating the BM25 data, drop page_title and keep only the schema information and content, to preserve the accuracy of the (--) baseline. index=wikisql-tables_re_divide
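Building the index documents without page_title can be sketched as follows (a minimal sketch; the document fields and the commented-out elasticsearch-py calls are assumptions, not the exact code in upload_table_data_test.py):

```python
def build_bm25_doc(table):
    """Document body for the BM25 index: schema (header) plus cell
    content, deliberately omitting page_title so the baseline
    cannot exploit title overlap with the question."""
    header = " | ".join(table["header"])
    content = " ".join(str(c) for row in table.get("rows", []) for c in row)
    return {"id": table["id"], "text": header + " " + content}

# Indexing sketch (requires a running Elasticsearch instance):
# from elasticsearch import Elasticsearch
# es = Elasticsearch("http://localhost:9200")
# for table in tables:
#     doc = build_bm25_doc(table)
#     es.index(index="wikisql-tables_re_divide", id=doc["id"], document=doc)
```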
The results look very promising: evaluating during training with the first 10000 train questions and 300 test questions, accuracy reaches up to 94.44%. Retrieval accuracy:
That feels quite high, so continue training with the full training set (29624 questions); evaluate with cfg.model_file="/nlp_files/DPR/outputs/2022-07-14/08-30-34/test_wikisql_20220714/dpr_biencoder.6":
Next, evaluate with length=200, batch_size=30, trained for 25 epochs.
------ index: 5178 ------ (top-1/5/10/20/50/100 all hit; gold table at index 0)
top_1_hit:   3714  accuracy 0.7171
top_5_hit:   4543  accuracy 0.8772
top_10_hit:  4709  accuracy 0.9092
top_20_hit:  4847  accuracy 0.9359
top_50_hit:  4983  accuracy 0.9622
top_100_hit: 5067  accuracy 0.9784
Also run BM25 and check its evaluation results on the test set.
----- question 5178, 2022-07-15 01:42:39 -----
question: Which city is Naples Airport located in in Blue Air destinations?
Data shapes: input_ids (1166, 96), input_mask (1166, 96), segment_ids (1166, 96)
===HydraNet=== prediction time: 6.34s; error_execute_count: 3
before_EG / after_EG hits at top_k 5/10/20/50/100/200: 4 / 4 at every cutoff
no_EG (num=5179):
  top_k=1:   count 2223  accuracy 0.4292
  top_k=5:   count 3139  accuracy 0.6061
  top_k=10:  count 3456  accuracy 0.6673
  top_k=20:  count 3741  accuracy 0.7223
  top_k=50:  count 4133  accuracy 0.7980
  top_k=100: count 4432  accuracy 0.8558
  top_k=200: count 4682  accuracy 0.9040
with_WG (num=5179):
  top_k=1:   count 2607  accuracy 0.5034
  top_k=5:   count 3512  accuracy 0.6781
  top_k=10:  count 3796  accuracy 0.7330
  top_k=20:  count 4017  accuracy 0.7756
  top_k=50:  count 4324  accuracy 0.8349
  top_k=100: count 4513  accuracy 0.8714
  top_k=200: count 4682  accuracy 0.9040
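The top-k tables logged throughout these notes boil down to a hit-rate computation over ranked retrieval lists, which can be sketched as (a minimal sketch; the evaluation scripts here presumably do the same count/num division, and note the off-by-one divisor pitfall flagged later):

```python
def topk_accuracy(ranked_ids, gold_id, ks=(1, 5, 10, 20, 50, 100)):
    """For one question: which top-k cutoffs contain the gold table."""
    try:
        rank = ranked_ids.index(gold_id)   # 0-based rank of the gold table
    except ValueError:
        rank = None                        # gold table not retrieved at all
    return {k: rank is not None and rank < k for k in ks}

def aggregate(hits_per_question, ks=(1, 5, 10, 20, 50, 100)):
    """Fraction of questions whose gold table appears in the top k."""
    n = len(hits_per_question)
    return {k: sum(h[k] for h in hits_per_question) / n for k in ks}
```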
9. Final results, wikisql_re_divide_0.56
Re-cleaned the questions once more (sim <= 0.56); evaluated the test set with the trained DPR model:
After training:
PS: that figure is invalid; the divisor was off by one!
Retrained HydraNet model (0.56)
After HydraNet training:
Model saved in path: output/20220715_065141/model_2.pt (cases: output/20220715_065141_cases)
wikidev_re_divide_0.56.jsonl  (prediction time 118.7s): overall 84.3, agg 93.5, sel 98.1, wn 97.8, wc 94.6, op 99.3, val 96.2
wikitest_re_divide_0.56.jsonl (prediction time 179.6s): overall 83.8, agg 92.9, sel 97.8, wn 97.9, wc 94.7, op 99.5, val 96.0
Trained HydraNet evaluated on test_re_divide_0.56 (5179 questions):
{ "ex_accuracy": 0.8882023556671172, "lf_accuracy": 0.8391581386368024 }
BM25
Run with title added (incomplete: the run stopped at question 1540)
no_EG (num=1540):
  top_k=1:   count 951   accuracy 0.6175
  top_k=5:   count 1207  accuracy 0.7838
  top_k=10:  count 1287  accuracy 0.8357
  top_k=20:  count 1359  accuracy 0.8825
  top_k=50:  count 1424  accuracy 0.9247
  top_k=100: count 1468  accuracy 0.9532
  top_k=200: count 1503  accuracy 0.9760
with_WG (num=1540):
  top_k=1:   count 1118  accuracy 0.7260
  top_k=5:   count 1350  accuracy 0.8766
  top_k=10:  count 1399  accuracy 0.9084
  top_k=20:  count 1445  accuracy 0.9383
  top_k=50:  count 1475  accuracy 0.9578
  top_k=100: count 1490  accuracy 0.9675
  top_k=200: count 1503  accuracy 0.9760
Higher than DPR, so not really usable as a baseline.
BM25 without title:
HydraNet + DPR (EG, without EG):
no_EG (num=5179):
  top_k=1:   count 2939  accuracy 0.5674
  top_k=5:   count 3999  accuracy 0.7720
  top_k=10:  count 4356  accuracy 0.8409
  top_k=20:  count 4628  accuracy 0.8934
  top_k=50:  count 4905  accuracy 0.9469
  top_k=100: count 5026  accuracy 0.9703
with_WG (num=5179):
  top_k=1:   count 3986  accuracy 0.7695
  top_k=5:   count 4637  accuracy 0.8952
  top_k=10:  count 4819  accuracy 0.9303
  top_k=20:  count 4895  accuracy 0.9450
  top_k=50:  count 4988  accuracy 0.9629
  top_k=100: count 5026  accuracy 0.9703