diff --git a/interact.py b/interact.py index e30ade0..2c546cf 100644 --- a/interact.py +++ b/interact.py @@ -45,8 +45,8 @@ def set_args(): parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False, help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数") # parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的') - parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断') - parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度") + parser.add_argument('--max_len', type=int, default=250, help='每个utterance的最大长度,超过指定长度则进行截断') + parser.add_argument('--max_history_len', type=int, default=1, help="dialogue history的最大长度") parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测') return parser.parse_args() @@ -110,6 +110,25 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf') return logits +region_words = ['保险', '人寿保险', '平安', '医疗保险', '汽车保险', '医疗险', '家财险', '意外险', '保险公司', '重疾险', '保险费用', + '人寿', '寿险', '保险费', '险', '人身保险', '理赔', '平安保险', '保费', '车险', '保险金', '被保险人', '投保人', + '保险费率', '财产保险', '保单', '续保', '投保', '核保', '医保', + # 特定保险begin + '综合意外', '鹏城保', '百万家财', '水滴保', '福禄鑫尊', '国寿瑞鑫', '鑫裕金', '鑫尊宝', '国寿福', '同佑e生', '金佑人生', + '金福合家欢', '重庆渝惠保', '国寿康宁', '泰康贴心保', '新冠隔离津贴', '火车隔离津贴', '外卖准时宝', '泰康', '悟空保', '南充充惠保', + '春城惠民保', '太平福禄御禧', '太平洋金佑', '国华金如意', '医疗保障', '医疗保健', + # 特定保险end + # 特定术语begin + '告知义务', '说明义务', '现金价值', '犹豫期', '宽限期', '条款', + # 特定术语end + '重大疾病', + # '年金', + # '退休计划', + # '养老金', + '免赔', '保障期', '退保', '保额', '受益人', '可以保', '能保么', + ] + + def main(): args = set_args() logger = create_logger(args) @@ -139,12 +158,30 @@ def main(): try: text = input("user:") # text = "你好" + if args.save_samples_path: samples_file.write("user:{}\n".format(text)) text_ids = tokenizer.encode(text, add_special_tokens=False) history.append(text_ids) input_ids = [tokenizer.cls_token_id] # 每个input以[CLS]为开头 + #先做领域识别, 在关键字里面的为保险领域的 + in_region = False + for region_word in region_words: + if not text.__contains__(region_word): + pass + else: + in_region = True + break + if not in_region: + response_text = "这方面的问题,安安不太明白,要不您换个问法再试试,或许安安就能明白啦!" + response = tokenizer.encode(response_text, add_special_tokens=False) + history.append(response) + print("chatbot:" + "".join(response_text)) + if args.save_samples_path: + samples_file.write("chatbot:{}\n".format("".join(response_text))) + continue + for history_id, history_utr in enumerate(history[-args.max_history_len:]): input_ids.extend(history_utr) input_ids.append(tokenizer.sep_token_id) diff --git a/train.py b/train.py index e6d3453..83b7616 100644 --- a/train.py +++ b/train.py @@ -134,8 +134,8 @@ def load_dataset(logger, args): # test # input_list_train = input_list_train[:24] # input_list_val = input_list_val[:24] - print(f'train {len(input_list_train)} {input_list_train}') - print(f'valid {len(input_list_val)} {input_list_val}') + #print(f'train {len(input_list_train)} {input_list_train}') + #print(f'valid {len(input_list_val)} {input_list_val}') train_dataset = MyDataset(input_list_train, args.max_len) val_dataset = MyDataset(input_list_val, args.max_len) @@ -380,7 +380,10 @@ def main(): # 创建日志对象 logger = create_logger(args) # 当用户使用GPU,并且GPU可用时 - args.cuda = torch.cuda.is_available() and not args.no_cuda + #args.cuda = torch.cuda.is_available() and not args.no_cuda + args.cuda = True + print(f'====>cuda:{torch.cuda.is_available()}') + print(f'====>args no_cuda:{args.no_cuda} cuda {args.cuda}') device = 'cuda:0' if args.cuda else 'cpu' args.device = device logger.info('using device:{}'.format(device))