Skip to content

Commit

Permalink
Merge pull request ConvLab#17 from ConvLab/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
sungjinl committed Jun 17, 2019
2 parents 74cfda2 + f52c4d6 commit 41f4e0c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 9 deletions.
9 changes: 9 additions & 0 deletions convlab/evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,13 @@ def inform_F1(self, ref2goal=True, aggregate=True):
raise NotImplementedError

def task_success(self, ref2goal=True):
"""
judge if all the domains are successfully completed
"""
raise NotImplementedError

def domain_success(self, domain, ref2goal=True):
"""
judge if the domain (subtask) is successfully completed
"""
raise NotImplementedError
67 changes: 58 additions & 9 deletions convlab/evaluator/multiwoz.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-

import re

import numpy as np
from copy import deepcopy

from convlab.evaluator.evaluator import Evaluator
from convlab.modules.util.multiwoz.dbquery import dbs
Expand Down Expand Up @@ -46,7 +46,8 @@ def _init_dict_booked(self):
dic[domain] = None
return dic

def _expand(self, goal):
def _expand(self, _goal):
goal = deepcopy(_goal)
for domain in belief_domains:
if domain not in goal:
goal[domain] = {'info':{}, 'book':{}, 'reqt':[]}
Expand Down Expand Up @@ -84,6 +85,7 @@ def add_sys_da(self, da_turn):
slot_pair = da_turn[dom_int]
for slot, value in slot_pair:
da = (dom_int +'-'+slot).lower()
value = str(value)
self.sys_da_array.append(da+'-'+value)

if da == 'booking-book-ref' and self.cur_domain in ['hotel', 'restaurant', 'train']:
Expand Down Expand Up @@ -111,12 +113,14 @@ def add_usr_da(self, da_turn):
da = (dom_int +'-'+slot).lower()
self.usr_da_array.append(da+'-'+value)

def _book_rate_goal(self, goal, booked_entity):
def _book_rate_goal(self, goal, booked_entity, domains=None):
"""
judge if the selected entity meets the constraint
"""
if domains is None:
domains = belief_domains
score = []
for domain in belief_domains:
for domain in domains:
if goal[domain]['book']:
tot = len(goal[domain]['info'].keys())
if tot == 0:
Expand Down Expand Up @@ -155,19 +159,21 @@ def _book_rate_goal(self, goal, booked_entity):
score.append(match / tot)
return score

def _inform_F1_goal(self, goal, sys_history):
def _inform_F1_goal(self, goal, sys_history, domains=None):
"""
judge if all the requested information is answered
"""
if domains is None:
domains = belief_domains
inform_slot = {}
for domain in belief_domains:
for domain in domains:
inform_slot[domain] = set()
for da in sys_history:
domain, intent, slot, value = da.split('-', 3)
if intent in ['inform', 'recommend', 'offerbook', 'offerbooked'] and domain in belief_domains and slot in mapping[domain]:
if intent in ['inform', 'recommend', 'offerbook', 'offerbooked'] and domain in domains and slot in mapping[domain]:
inform_slot[domain].add(mapping[domain][slot])
TP, FP, FN = 0, 0, 0
for domain in belief_domains:
for domain in domains:
for k in goal[domain]['reqt']:
if k in inform_slot[domain]:
TP += 1
Expand Down Expand Up @@ -206,7 +212,7 @@ def inform_F1(self, ref2goal=True, aggregate=True):
goal = self._init_dict()
for da in self.usr_da_array:
d, i, s, v = da.split('-', 3)
if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and s in mapping[d]:
if i == 'inform' and s in mapping[d]:
goal[d]['info'][mapping[d][s]] = v
elif i == 'request':
goal[d]['reqt'].append(s)
Expand All @@ -226,6 +232,9 @@ def inform_F1(self, ref2goal=True, aggregate=True):
return [TP, FP, FN]

def task_success(self, ref2goal=True):
"""
judge if all the domains are successfully completed
"""
book_sess = self.book_rate(ref2goal)
inform_sess = self.inform_F1(ref2goal)
# book rate == 1 & inform recall == 1
Expand All @@ -235,3 +244,43 @@ def task_success(self, ref2goal=True):
return 1
else:
return 0

def domain_success(self, domain, ref2goal=True):
"""
judge if the domain (subtask) is successfully completed
"""
if domain not in self.goal:
return None

if ref2goal:
goal = {}
goal[domain] = deepcopy(self.goal[domain])
else:
goal = {}
goal[domain] = {'info':{}, 'book':{}, 'reqt':[]}
if 'book' in self.goal[domain]:
goal[domain]['book'] = self.goal[domain]['book']
for da in self.usr_da_array:
d, i, s, v = da.split('-', 3)
if d != domain:
continue
if i == 'inform' and s in mapping[d]:
goal[d]['info'][mapping[d][s]] = v
elif i == 'request':
goal[d]['reqt'].append(s)

book_rate = self._book_rate_goal(goal, self.booked, [domain])
book_rate = np.mean(book_rate) if book_rate else None

inform = self._inform_F1_goal(goal, self.sys_da_array, [domain])
try:
inform_rec = inform[0] / (inform[0] + inform[2])
except ZeroDivisionError:
inform_rec = None

if (book_rate == 1 and inform_rec == 1) \
or (book_rate == 1 and inform_rec is None) \
or (book_rate is None and inform_rec == 1):
return 1
else:
return 0

0 comments on commit 41f4e0c

Please sign in to comment.