Skip to content

Commit

Permalink
validate: couple of fixes to e2e evaluation loop
Browse files (browse the repository at this point in the history)
  • Loading branch information
Mehrad0711 committed Sep 24, 2021
1 parent 2dbce09 commit 395b4bd
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 81 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,7 +25,7 @@ models/.DS_Store
src/
workdir/
*save*/
eval_dir/*
eval_dir*/*
genieNLP-tests*

lightning_logs/
Expand Down
2 changes: 1 addition & 1 deletion genienlp/metrics.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -532,7 +532,7 @@ def computeBITOD(greedy, answer, tgt_lang):
subtask_metrics_dict[subtasks[t]] = (sub_metrics, len(golds), subtask_weights[t])

# TODO how should we aggregate?
bitod_score = 0.0
bitod_score, JGA, response_bleu, api_em = 0.0, 0.0, 0.0, 0.0
weighted_num_examples = 0
for subtask, (sub_metrics, num_ex, weight) in subtask_metrics_dict.items():
if subtask == 'dst':
Expand Down
164 changes: 85 additions & 79 deletions genienlp/validate.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -31,6 +31,7 @@
import logging
import re
import sys
import time
from collections import defaultdict

import torch
Expand Down Expand Up @@ -131,10 +132,13 @@ def generate_with_seq2seq_model_for_dialogue(
contexts = []

cur_dial_id = ''
new_state_text = 'null'

device = model.device

required_slots = read_require_slots()
required_slots = {API_MAP[k]: v for k, v in required_slots.items()}
api_names = list(required_slots.keys())

for k, turn in enumerate(progress_bar(data_iterator, desc='Generating', disable=disable_progbar)):
batch_size = len(turn.example_id)
assert batch_size == 1
Expand All @@ -151,6 +155,9 @@ def generate_with_seq2seq_model_for_dialogue(
cur_dial_id = dial_id
first_turn = True
dialogue_state = {}
new_state_text = 'null'
new_knowledge_text = 'null'
active_api = None
bitod_preds[dial_id] = {"turns": defaultdict(dict), "API": defaultdict(dict)}
else:
first_turn = False
Expand All @@ -159,7 +166,7 @@ def generate_with_seq2seq_model_for_dialogue(
batch_tokens = numericalizer.convert_ids_to_tokens(turn.context.value.data, skip_special_tokens=False)
batch_context = []
# remove only beginning and trailing special tokens
# otherwise the numericalizer.sep_token added between context and question will be lost
# otherwise the sep_token added between context and question will be lost
for text in batch_tokens:
i = 0
while text[i] in special_tokens:
Expand Down Expand Up @@ -187,6 +194,7 @@ def generate_with_seq2seq_model_for_dialogue(

if first_turn:
# first turn is always dst
assert train_target == 'dst'
numericalized_turn = NumericalizedExamples(
example_id=[turn.example_id[0]],
context=SequentialField(
Expand All @@ -203,93 +211,25 @@ def generate_with_seq2seq_model_for_dialogue(
),
)
else:
required_slots = read_require_slots()
required_slots = {API_MAP[k]: v for k, v in required_slots.items()}
api_names = list(required_slots.keys())

# find train_target
if train_target == 'dst':

#### save latest response
bitod_preds[dial_id]["turns"][str(turn_id - 1)]["response"] = predictions[-1]
####

input_text = replace_match(contexts[-1], state_re, new_state_text)

## if you want to use predicted response instead of gold uncomment the following
# last_sys_pred = predictions[-1][0].strip()
# input_text = replace_match(input_text, last_system_re, last_sys_pred)

elif train_target == 'api':

lev = predictions[-1][0].strip()
state_update = span2dict(lev, api_names)
for api_name in state_update:
active_api = api_name
if api_name not in dialogue_state:
dialogue_state[api_name] = state_update[api_name]
else:
dialogue_state[api_name].update(state_update[api_name])

#### save latest state
state_to_record = copy.deepcopy(dialogue_state)
state_to_record = {r_en_API_MAP.get(k, k): v for k, v in state_to_record.items()}
bitod_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record
####

new_state_text = state2span(dialogue_state, required_slots)

# replace gold state with predicted state
# replace state
input_text = replace_match(contexts[-1], state_re, new_state_text)

elif train_target == 'response':
# replace state
input_text = replace_match(contexts[-1], state_re, new_state_text)

bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = ''

do_api_call = predictions[-1][0].strip()
if do_api_call == 'no':
# knowledge is null so just use current input
input_text = contexts[-1]
elif do_api_call == 'yes':
# do api call
api_name = active_api
if api_name in dialogue_state:
constraints = state2api(dialogue_state[api_name])

try:
msg = api.call_api(
r_en_API_MAP.get(api_name, api_name),
constraints=[constraints],
)
except Exception as e:
print(f'Error: {e}')
print(f'Failed API call with api_name: {api_name} and constraints: {constraints}')
msg = [0, 0]

domain = api_name.split(" ")[0]

knowledge = defaultdict(dict)
if int(msg[1]) <= 0:
new_knowledge_text = f'( {domain} ) Message = No item available.'
else:
# why does it only choose the first; does the same happen for training data?
knowledge[domain].update(msg[0])
new_knowledge_text = knowledge2span(knowledge)

#### save latest api results and constraints
bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text
bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints)
####

input_text = replace_match(contexts[-1], knowledge_re, new_knowledge_text)
input_text = replace_match(input_text, state_re, new_state_text)

else:
logger.error(
f'API call should be either yes or no but got {do_api_call}; seems model is still training, we assume a no'
)
# knowledge is null so just use current input
input_text = contexts[-1]
# replace knowledge
input_text = replace_match(input_text, knowledge_re, new_knowledge_text)

else:
raise ValueError(f'Invalid train_target: {train_target}')
Expand Down Expand Up @@ -338,11 +278,77 @@ def generate_with_seq2seq_model_for_dialogue(

predictions += batch_prediction

#### save last response
bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1]
####
if train_target == 'dst':
# update dialogue_state
lev = predictions[-1][0].strip()
state_update = span2dict(lev, api_names)
for api_name in state_update:
active_api = api_name
if api_name not in dialogue_state:
dialogue_state[api_name] = state_update[api_name]
else:
dialogue_state[api_name].update(state_update[api_name])

#### save latest state
state_to_record = copy.deepcopy(dialogue_state)
state_to_record = {r_en_API_MAP.get(k, k): v for k, v in state_to_record.items()}
bitod_preds[dial_id]["turns"][str(turn_id)]["state"] = state_to_record
####

elif train_target == 'api':
new_knowledge_text = 'null'
constraints = {}

api_name = active_api if active_api else 'None'

do_api_call = predictions[-1][0].strip()

if do_api_call == 'yes':
# make api call if required
api_name = active_api
# do api call
if api_name in dialogue_state:
constraints = state2api(dialogue_state[api_name])

try:
msg = api.call_api(
r_en_API_MAP.get(api_name, api_name),
constraints=[constraints],
)
except Exception as e:
print(f'Error: {e}')
print(f'Failed API call with api_name: {api_name} and constraints: {constraints}')
msg = [0, 0]

domain = api_name.split(" ")[0]

knowledge = defaultdict(dict)
if int(msg[1]) <= 0:
new_knowledge_text = f'( {domain} ) Message = No item available.'
else:
# why does it only choose the first; does the same happen for training data?
knowledge[domain].update(msg[0])
new_knowledge_text = knowledge2span(knowledge)

elif do_api_call == 'no':
# do nothing
pass
else:
logger.error(
f'API call should be either yes or no but got {do_api_call}; seems model is still training so we assume it\'s a no'
)

#### save latest api results and constraints
bitod_preds[dial_id]["turns"][str(turn_id)]["api"] = new_knowledge_text
bitod_preds[dial_id]["API"][r_en_API_MAP.get(api_name, api_name)] = copy.deepcopy(constraints)
####

if train_target == 'response':
#### save latest response
bitod_preds[dial_id]["turns"][str(turn_id)]["response"] = predictions[-1]
####

with open('bitod_preds.json', 'w') as fout:
with open(f"{int(time.time())}_bitod_preds.json", 'w') as fout:
ujson.dump(bitod_preds, fout, indent=2, ensure_ascii=False)

if original_order is not None:
Expand Down

0 comments on commit 395b4bd

Please sign in to comment.