Skip to content

Commit

Permalink
Use print() function in both Python 2 and Python 3
Browse files Browse the repository at this point in the history
  • Loading branch information
cclauss committed Jun 21, 2019
1 parent 88015f0 commit c4fa1fc
Show file tree
Hide file tree
Showing 27 changed files with 93 additions and 88 deletions.
2 changes: 1 addition & 1 deletion convlab/env/movie.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ def initialize_episode(self):

def action_decode(self, action):
""" DQN: Input state, output action """
if type(action) == np.ndarray:
if isinstance(action, np.ndarray):
action = action[0]
act_slot_response = deepcopy(self.feasible_actions[action])
return {'act_slot_response': act_slot_response, 'act_slot_value_response': None}
Expand Down
1 change: 1 addition & 0 deletions convlab/lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from importlib import reload
from pprint import pformat

import cv2
import numpy as np
import pandas as pd
import pydash as ps
Expand Down
4 changes: 2 additions & 2 deletions convlab/modules/dst/multiwoz/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _is_empty(slot, domain_state):
return False

def _is_match(value1, value2):
if type(value1) is not str or type(value2) is not str:
if not isinstance(value1, str) or not isinstance(value2, str):
return value1 == value2
value1 = value1.lower()
value2 = value2.lower()
Expand All @@ -179,7 +179,7 @@ def _is_match(value1, value2):
return False

def _fuzzy_match(value1, value2):
if type(value1) is not str or type(value2) is not str:
if not isinstance(value1, str) or not isinstance(value2, str):
return value1 == value2
value1 = value1.lower()
value2 = value2.lower()
Expand Down
2 changes: 1 addition & 1 deletion convlab/modules/dst/multiwoz/rule_dst.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self):

def update(self, user_act=None):
# print('------------------{}'.format(user_act))
if type(user_act) is not dict:
if not isinstance(user_act, dict):
raise Exception('Expect user_act to be <class \'dict\'> type but get {}.'.format(type(user_act)))
previous_state = self.state
new_belief_state = copy.deepcopy(previous_state['belief_state'])
Expand Down
4 changes: 2 additions & 2 deletions convlab/modules/e2e/multiwoz/Sequicity/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,13 +402,13 @@ def constraint_same(self, truth_cons, gen_cons):
def _get_entity_dict(self, entity_data):
entity_dict = {}
for k in entity_data:
if type(entity_data[k][0]) is str:
if isinstance(entity_data[k][0], str):
for entity in entity_data[k]:
entity = self._lemmatize(self._tokenize(entity))
entity_dict[entity] = k
if k in ['event','poi_type']:
entity_dict[entity.split()[0]] = k
elif type(entity_data[k][0]) is dict:
elif isinstance(entity_data[k][0], dict):
for entity_entry in entity_data[k]:
for entity_type, entity in entity_entry.items():
entity_type = 'poi_type' if entity_type == 'type' else entity_type
Expand Down
2 changes: 1 addition & 1 deletion convlab/modules/e2e/multiwoz/Sequicity/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def main(arg_mode=None, arg_model=None):
for pair in args.cfg:
k, v = tuple(pair.split('='))
dtype = type(getattr(cfg, k))
if dtype == type(None):
if isinstance(None, dtype):
raise ValueError()
if dtype is bool:
v = False if v == 'False' else True
Expand Down
10 changes: 5 additions & 5 deletions convlab/modules/e2e/multiwoz/Sequicity/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def db_degree_handler(self, z_samples, *args, **kwargs):
for cons_idx_list in z_samples:
constraints = set()
for cons in cons_idx_list:
if type(cons) is not str:
if not isinstance(cons, str):
cons = self.vocab.decode(cons)
if cons == 'EOS_Z1':
break
Expand Down Expand Up @@ -715,14 +715,14 @@ def _get_encoded_data(self, tokenized_data):
def _get_entity_dict(self, entity_data):
entity_dict = {}
for k in entity_data:
if type(entity_data[k][0]) is str:
if isinstance(entity_data[k][0], str):
for entity in entity_data[k]:
entity = self._lemmatize(self._tokenize(entity))
entity_dict[entity] = k
if k in ['event', 'poi_type']:
entity_dict[entity.split()[0]] = k
self.abbr_dict[entity.split()[0]] = entity
elif type(entity_data[k][0]) is dict:
elif isinstance(entity_data[k][0], dict):
for entity_entry in entity_data[k]:
for entity_type, entity in entity_entry.items():
entity_type = 'poi_type' if entity_type == 'type' else entity_type
Expand Down Expand Up @@ -753,7 +753,7 @@ def db_degree_handler(self, z_samples, idx=None, *args, **kwargs):
for i,cons_idx_list in enumerate(z_samples):
constraints = set()
for cons in cons_idx_list:
if type(cons) is not str:
if not isinstance(cons, str):
cons = self.vocab.decode(cons)
if cons == 'EOS_Z1':
break
Expand Down Expand Up @@ -902,7 +902,7 @@ def _get_encoded_data(self, tokenized_data):
def _get_clean_db(self, raw_db_data):
for entry in raw_db_data:
for k, v in list(entry.items()):
if type(v) != str or v == '?':
if not isinstance(v, str) or v == '?':
entry.pop(k)

def _construct(self, train_json_path, dev_json_path, test_json_path, db_json_path):
Expand Down
2 changes: 1 addition & 1 deletion convlab/modules/e2e/multiwoz/Sequicity/tsd_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def toss_(p):


def nan(v):
if type(v) is float:
if isinstance(v, float):
return v == float('nan')
return np.isnan(np.sum(v.data.cpu().numpy()))

Expand Down
4 changes: 2 additions & 2 deletions convlab/modules/nlg/multiwoz/sc_lstm/nlg_sc_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ def generate_delex(self, meta):
domain, intent = k.split('-')
if intent == "Request":
for pair in v:
if type(pair[1]) != str:
if not isinstance(pair[1], str):
pair[1] = str(pair[1])
pair.insert(1, '?')
else:
counter = {}
for pair in v:
if type(pair[1]) != str:
if not isinstance(pair[1], str):
pair[1] = str(pair[1])
if pair[0] == 'Internet' or pair[0] == 'Parking':
pair.insert(1, 'yes')
Expand Down
5 changes: 2 additions & 3 deletions convlab/modules/nlu/multiwoz/svm/Features.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,7 @@ def get_ngrams(sentence, max_length, skip_ngrams=False, add_tags = True):
for n in range(1, max_length+1):
subsets = set(itertools.combinations(range(len(words)), n))
for subset in subsets:
subset = list(subset)
subset.sort()
subset = sorted(subset)
dists = [(subset[i]-subset[i-1]) for i in range(1, len(subset))]
out.append((" ".join([words[j] for j in subset]), dists))

Expand Down Expand Up @@ -267,7 +266,7 @@ def get_cnngrams(cnet, max_ngrams, max_length):
class cnNgram(object):

def __init__(self, words, logp, delta=0):
if type(words) != type([]) :
if not isinstance(words, type([])) :
words = words.split()
self.words = words
self.logp = logp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def main() :
elif args.focus.lower() == "false":
tracker = Tracker()
else:
raise RuntimeError,'Dont recognize focus=%s (must be True or False)' % (args.focus)
raise RuntimeError('Dont recognize focus=%s (must be True or False)' % (args.focus))
for call in dataset :
this_session = {"session-id":call.log["session-id"], "turns":[]}
tracker.reset()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
# Modified by Microsoft Corporation.
# Licensed under the MIT license.

Expand Down Expand Up @@ -81,7 +82,7 @@ def check(self):
self.add_error(("top level","wall-time should be included"))
else:
wall_time = self.tracker_output["wall-time"]
if type(wall_time) != type(0.0):
if not isinstance(wall_time, type(0.0)):
self.add_error(("top level","wall-time must be a float"))
elif wall_time <= 0.0 :
self.add_error(("top level","wall-time must be positive"))
Expand Down Expand Up @@ -192,11 +193,11 @@ def add_error(self, context, error_str):

def print_errors(self):
if len(self.errors) == 0 :
print "Found no errors, trackfile is valid"
print("Found no errors, trackfile is valid")
else :
print "Found",len(self.errors),"errors:"
print("Found",len(self.errors),"errors:")
for context, error in self.errors:
print " ".join(map(str, context)), "-", error
print(" ".join(map(str, context)), "-", error)



Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class dataset_walker(object):
def __init__(self,dataset,labels=False,dataroot=None):
if "[" in dataset :
self.datasets = json.loads(dataset)
elif type(dataset) == type([]) :
elif isinstance(dataset, type([])) :
self.datasets= dataset
else:
self.datasets = [dataset]
Expand Down
17 changes: 9 additions & 8 deletions convlab/modules/nlu/multiwoz/svm/corpora/scripts/prettyPrint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
# Modified by Microsoft Corporation.
# Licensed under the MIT license.

Expand All @@ -12,7 +13,7 @@

def jsonItemToCued(i):
if len(i) > 2:
print "Unsure what to do about " + str(i)
print("Unsure what to do about " + str(i))
if len(i) > 1:
return i[0] + "=" + i[1]
elif len(i) == 1:
Expand All @@ -34,28 +35,28 @@ def prettyPrint(fname):
log = json.load(open(os.path.join(fname, "log.json")))
label = json.load(open(os.path.join(fname, "label.json")))
for turn, labelturn in zip(log["turns"], label["turns"]) :
print "SYS > " + turn['output']['transcript']
print("SYS > " + turn['output']['transcript'])
dact = turn['output']['dialog-acts']
slulist = turn['input']['live']['slu-hyps']
print "DAct > " + jsonToCued(dact)
print("DAct > " + jsonToCued(dact))
if len(slulist) > 0:
for s in slulist:
slu = s
#prob = slulist[0]['prob']
print "SLU > %-20s [%.2f]" % (jsonToCued(slu['slu-hyp']),slu['score'])
print("SLU > %-20s [%.2f]" % (jsonToCued(slu['slu-hyp']),slu['score']))

asrlist = turn['input']['live']['asr-hyps']
print "ASR > " + asrlist[0]['asr-hyp']
print "Tran > " +str(labelturn['transcription'])
print " "
print("ASR > " + asrlist[0]['asr-hyp'])
print("Tran > " +str(labelturn['transcription']))
print(" ")




if __name__ == "__main__":

if len(sys.argv) < 2:
print "Usage: python prettyPrint.py [dialogfolder]"
print("Usage: python prettyPrint.py [dialogfolder]")
else:
fname = sys.argv[1]
prettyPrint(fname)
Expand Down
26 changes: 13 additions & 13 deletions convlab/modules/nlu/multiwoz/svm/corpora/scripts/report.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
# Modified by Microsoft Corporation.
# Licensed under the MIT license.

Expand Down Expand Up @@ -78,35 +79,34 @@ def main(argv):
tables[state_component][EVALUATION_SCHEMES[(schedule, label_scheme)]][stat] = result

for state_component in ["goal.joint","method","requested.all"]:
print state_component.center(50)
evaluation_schemes = [key for key in tables[state_component].keys() if len(tables[state_component][key])>0]
evaluation_schemes.sort()
print(state_component.center(50))
evaluation_schemes = sorted([key for key in tables[state_component].keys() if len(tables[state_component][key])>0])
stats = tables[state_component][evaluation_schemes[0]].keys()
stats.sort()
print_row(['']+evaluation_schemes, header=True)
for stat in stats:
print_row([stat] + [tables[state_component][evaluation_scheme][stat] for evaluation_scheme in evaluation_schemes])

print "\n\n"
print("\n\n")




print ' featured metrics'
print(' featured metrics')
print_row(["","Joint Goals","Requested","Method"],header=True)
print_row(["Accuracy",tables["goal.joint"]["eval_2a"]["acc"],tables["requested.all"]["eval_2a"]["acc"],tables["method"]["eval_2a"]["acc"] ])
print_row(["l2",tables["goal.joint"]["eval_2a"]["l2"],tables["requested.all"]["eval_2a"]["l2"],tables["method"]["eval_2a"]["l2"] ])
print_row(["roc.v2_ca05",tables["goal.joint"]["eval_2a"]["roc.v2_ca05"],tables["requested.all"]["eval_2a"]["roc.v2_ca05"],tables["method"]["eval_2a"]["roc.v2_ca05"] ])


print "\n\n"
print("\n\n")


print ' basic stats'
print '-----------------------------------------------------------------------------------'
print(' basic stats')
print('-----------------------------------------------------------------------------------')
for k in sorted(basic_stats.keys()):
v = basic_stats[k]
print '%20s : %s' % (k,v)
print('%20s : %s' % (k,v))

def print_row(row, header=False):
out = [str(x) for x in row]
Expand All @@ -119,11 +119,11 @@ def print_row(row, header=False):
out = ("|".join(out))[:-1]+"|"

if header:
print "-"*len(out)
print out
print "-"*len(out)
print("-"*len(out))
print(out)
print("-"*len(out))
else:
print out
print(out)


if (__name__ == '__main__'):
Expand Down
Loading

0 comments on commit c4fa1fc

Please sign in to comment.