Skip to content

Commit

Permalink
update data
Browse files Browse the repository at this point in the history
  • Loading branch information
zqwerty committed May 28, 2019
1 parent bf935ff commit 63d6ffe
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 7 deletions.
127 changes: 120 additions & 7 deletions data/multiwoz/annotation/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,12 @@ def phrase_in_utt(phrase, utt):
phrases.append(digit2word[phrase_low])
elif phrase_low in word2digit:
phrases.append(word2digit[phrase_low])
else:
if ' '+phrase_low in utt_low or utt_low.startswith(phrase_low):
return True
else:
return False

for w in phrases:
if utt_low.startswith(w) or utt_low.endswith(w):
return True
Expand All @@ -200,16 +206,24 @@ def phrase_idx_utt(phrase, utt):
phrases.append(digit2word[phrase_low])
elif phrase_low in word2digit:
phrases.append(word2digit[phrase_low])
else:
if ' '+phrase_low in utt_low or utt_low.startswith(phrase_low):
return get_idx(phrase_low, utt_low)
else:
return None

for w in phrases:
if utt_low.startswith(w) or utt_low.endswith(w):
return get_idx(w, utt_low)
elif ' '+w+' ' in utt_low:
return get_idx(w, utt_low)
return get_idx(' '+w+' ', utt_low)
# elif w+'-star' in utt_low:
# return get_idx(w, utt_low)
return None


def get_idx(phrase, utt):
char_index_begin = utt.lower().index(phrase.lower())
char_index_begin = utt.index(phrase)
char_index_end = char_index_begin + len(phrase)
word_index_begin = len(utt[:char_index_begin].split())
word_index_end = len(utt[:char_index_end].split()) - 1
Expand All @@ -223,7 +237,7 @@ def annotate_user_da(data):

domains = ['taxi', 'police', 'hospital', 'hotel', 'attraction', 'train', 'restaurant']

# nlp = spacy.load('en_core_web_sm')
nlp = spacy.load('en_core_web_sm')

for no, session in data.items():
user_das = []
Expand Down Expand Up @@ -255,14 +269,18 @@ def annotate_user_da(data):
if slot in subtable and subtable[slot] != 'not mentioned':
value_state = subtable[slot]
# state for that slot change
value_state = value_state.lower()
# value_state = value_state.lower()
if value_state != '':
value_state = ' '.join([token.text for token in nlp(value_state)]).strip()

value_goal = ''
if slot in user_goal[domain]['info']:
value_goal = user_goal[domain]['info'][slot]
elif 'book' in user_goal[domain] and slot in user_goal[domain]['book']:
value_goal = user_goal[domain]['book'][slot]
value_goal = value_goal.lower()
# value_goal = value_goal.lower()
if value_goal != '':
value_goal = ' '.join([token.text for token in nlp(value_goal)]).strip()

# slot-value appear in goal
slot_in_da = REF_USR_DA[domain.capitalize()][slot]
Expand All @@ -275,6 +293,10 @@ def annotate_user_da(data):
elif phrase_in_utt(slot, user_utterance):
# slot in user utterance
da[domain.capitalize() + '-Inform'].append([slot_in_da, value_state])
elif slot == 'stars' and (
phrase_in_utt(value_state + '-star', user_utterance) or
phrase_in_utt(value_state + '-stars', user_utterance)):
da[domain.capitalize() + '-Inform'].append([slot_in_da, value_state])
elif slot == 'people' and phrase_in_utt('one person', user_utterance):
# keyword 'person' for people
da[domain.capitalize() + '-Inform'].append([slot_in_da, "1"])
Expand Down Expand Up @@ -317,8 +339,11 @@ def annotate_user_da(data):
# digital value
if phrase_in_utt(value_goal, user_utterance):
if slot == 'stars':
if phrase_in_utt('star', user_utterance) or phrase_in_utt('stars',
user_utterance):
if phrase_in_utt('star ', user_utterance) or \
phrase_in_utt('stars ',user_utterance) or \
user_utterance.lower().endswith('star') or \
user_utterance.lower().endswith('stars') or \
'-star' in user_utterance.lower():
da[domain.capitalize() + '-Inform'].append([slot_in_da, value_goal])
elif slot == 'stay':
if phrase_in_utt('stay', user_utterance) or \
Expand Down Expand Up @@ -574,6 +599,16 @@ def annotate_span(data):
is_annotated = True
word_index_begin, word_index_end = phrase_idx_utt(v, utterance)
span_info.append((da, s, v, word_index_begin, word_index_end))
elif s == 'Stars':
pattern = ''
if phrase_in_utt(v+'-star', utterance):
pattern = v+'-star'
elif phrase_in_utt(v+'-stars', utterance):
pattern = v+'-stars'
if pattern:
is_annotated = True
word_index_begin, word_index_end = phrase_idx_utt(pattern, utterance)
span_info.append((da, s, v, word_index_begin, word_index_end))
elif phrase_in_utt('same', utterance) and phrase_in_utt(s, utterance):
# coreference-'same'
if phrase_in_utt('same ' + s, utterance):
Expand Down Expand Up @@ -633,6 +668,78 @@ def annotate_span(data):
session['log'][i]['span_info'] = span_info


def post_process_span(data):
for no, session in data.items():
for i in range(0, len(session['log']), 1):
das = session['log'][i]['dialog_act']
utterance = session['log'][i]['text']
span_info = session['log'][i]['span_info']
start_end_pos = dict()
for act, slot, value, start, end in span_info:
if (start,end) in start_end_pos:
start_end_pos[(start,end)].append([act, slot, value, start, end])
else:
start_end_pos[(start,end)] = [[act, slot, value, start, end]]
for start_end in start_end_pos:
if len(start_end_pos[start_end]) > 1:
value = [x[2] for x in start_end_pos[start_end]]
if len(set(value))>1:
# print(utterance)
# print(start_end_pos[start_end])
for ele in start_end_pos[start_end]:
v = ele[2]
if utterance.startswith(v+' '):
new_span = get_idx(v+' ',utterance)
elif ' '+v+' ' in utterance:
new_span = get_idx(' ' + v + ' ', utterance)
else:
new_span = None
if new_span:
ele[3], ele[4] = new_span
# print(start_end_pos[start_end])
else:
# one value
for ele in start_end_pos[start_end]:
slot = ele[1]
v = ele[2]
if slot == 'People':
pattern = ''
if phrase_in_utt('people', utterance):
pattern = 'people'
elif phrase_in_utt('person', utterance):
pattern = 'person'
if pattern:
slot_span = phrase_idx_utt(pattern, utterance)
v_set = [v]
if v in digit2word:
v_set.append(digit2word[v])
elif v in word2digit:
v_set.append(word2digit[v])
if utterance.split()[slot_span[0]-1] in v_set:
ele[3], ele[4] = slot_span[0]-1, slot_span[1]-1

elif slot == 'Stay':
pattern = ''
if phrase_in_utt('night', utterance):
pattern = 'night'
elif phrase_in_utt('nights', utterance):
pattern = 'nights'
if pattern:
slot_span = phrase_idx_utt(pattern, utterance)
v_set = [v]
if v in digit2word:
v_set.append(digit2word[v])
elif v in word2digit:
v_set.append(word2digit[v])
if utterance.split()[slot_span[0]-1] in v_set:
ele[3], ele[4] = slot_span[0]-1, slot_span[1]-1
# print(start_end_pos[start_end])
new_span_info = [x for y in start_end_pos.values() for x in y]
session['log'][i]['span_info'] = new_span_info




if __name__ == '__main__':
un_zip('MULTIWOZ2.zip')
dir_name = 'MULTIWOZ2 2/'
Expand Down Expand Up @@ -665,11 +772,17 @@ def annotate_span(data):
else:
turn['dialog_act'] = da
print('dataset size: %d' % len(all_data))
# all_data = dict(list(all_data.items())[-100:])
tokenize(all_data, process_text=True, process_da=True, process_ref=True)
annotate_user_da(all_data)
annotate_sys_da(all_data, database)
tokenize(all_data, process_text=False, process_da=True, process_ref=False)
annotate_span(all_data)
# archive = zipfile.ZipFile('annotated_user_da_with_span_full.json.zip', 'r')
# all_data = json.load(archive.open('annotated_user_da_with_span_full.json'))
# annotate_span(all_data)
post_process_span(all_data)
# # json.dump(all_data, open('test.json', 'w'), indent=4)
json.dump(all_data, open('annotated_user_da_with_span_full.json', 'w'), indent=4)
with zipfile.ZipFile('annotated_user_da_with_span_full.json.zip', 'w', zipfile.ZIP_DEFLATED) as zf:
zf.write('annotated_user_da_with_span_full.json')
Expand Down
Binary file not shown.

0 comments on commit 63d6ffe

Please sign in to comment.