Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def set_args():
parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
# parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的')
parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断')
parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度")
parser.add_argument('--max_len', type=int, default=250, help='每个utterance的最大长度,超过指定长度则进行截断')
parser.add_argument('--max_history_len', type=int, default=1, help="dialogue history的最大长度")
parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测')
return parser.parse_args()

Expand Down Expand Up @@ -110,6 +110,25 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
return logits


region_words = ['保险', '人寿保险', '平安', '医疗保险', '汽车保险', '医疗险', '家财险', '意外险', '保险公司', '重疾险', '保险费用',
'人寿', '寿险', '保险费', '险', '人身保险', '理赔', '平安保险', '保费', '车险', '保险金', '被保险人', '投保人',
'保险费率', '财产保险', '保单', '续保', '投保', '核保', '医保',
# 特定保险begin
'综合意外', '鹏城保', '百万家财', '水滴保', '福禄鑫尊', '国寿瑞鑫', '鑫裕金', '鑫尊宝', '国寿福', '同佑e生', '金佑人生',
'金福合家欢', '重庆渝惠保', '国寿康宁', '泰康贴心保', '新冠隔离津贴', '火车隔离津贴', '外卖准时宝', '泰康', '悟空保', '南充充惠保',
'春城惠民保', '太平福禄御禧', '太平洋金佑', '国华金如意', '医疗保障', '医疗保健',
# 特定保险end
# 特定术语begin
'告知义务', '说明义务', '现金价值', '犹豫期', '宽限期', '条款',
# 特定术语end
'重大疾病',
# '年金',
# '退休计划',
# '养老金',
'免赔', '保障期', '退保', '保额', '受益人', '可以保', '能保么',
]


def main():
args = set_args()
logger = create_logger(args)
Expand Down Expand Up @@ -139,12 +158,30 @@ def main():
try:
text = input("user:")
# text = "你好"

if args.save_samples_path:
samples_file.write("user:{}\n".format(text))
text_ids = tokenizer.encode(text, add_special_tokens=False)
history.append(text_ids)
input_ids = [tokenizer.cls_token_id] # 每个input以[CLS]为开头

#先做领域识别, 在关键字里面的为保险领域的
in_region = False
for region_word in region_words:
if not text.__contains__(region_word):
pass
else:
in_region = True
break
if not in_region:
response_text = "这方面的问题,安安不太明白,要不您换个问法再试试,或许安安就能明白啦!"
response = tokenizer.encode(response_text, add_special_tokens=False)
history.append(response)
print("chatbot:" + "".join(response_text))
if args.save_samples_path:
samples_file.write("chatbot:{}\n".format("".join(response_text)))
continue

for history_id, history_utr in enumerate(history[-args.max_history_len:]):
input_ids.extend(history_utr)
input_ids.append(tokenizer.sep_token_id)
Expand Down
9 changes: 6 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def load_dataset(logger, args):
# test
# input_list_train = input_list_train[:24]
# input_list_val = input_list_val[:24]
print(f'train {len(input_list_train)} {input_list_train}')
print(f'valid {len(input_list_val)} {input_list_val}')
#print(f'train {len(input_list_train)} {input_list_train}')
#print(f'valid {len(input_list_val)} {input_list_val}')
train_dataset = MyDataset(input_list_train, args.max_len)
val_dataset = MyDataset(input_list_val, args.max_len)

Expand Down Expand Up @@ -380,7 +380,10 @@ def main():
# 创建日志对象
logger = create_logger(args)
# 当用户使用GPU,并且GPU可用时
args.cuda = torch.cuda.is_available() and not args.no_cuda
#args.cuda = torch.cuda.is_available() and not args.no_cuda
args.cuda = True
print(f'====>cuda:{torch.cuda.is_available()}')
print(f'====>args no_cuda:{args.no_cuda} cuda {args.cuda}')
device = 'cuda:0' if args.cuda else 'cpu'
args.device = device
logger.info('using device:{}'.format(device))
Expand Down