Skip to content

Commit

Permalink
fix re infer bug
Browse files: browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Jan 12, 2022
1 parent 99de035 commit c703a58
Showing 1 changed file with 34 additions and 31 deletions.
65 changes: 34 additions & 31 deletions ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,19 +833,20 @@ def __call__(self, data):
segment_offset_id = []
gt_label_list = []

if self.contains_re:
# for re
entities = []
if not self.infer_mode:
relations = []
id2label = {}
entity_id_to_index_map = {}
empty_entity = set()
entities = []

# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
relations = []
id2label = {}
entity_id_to_index_map = {}
empty_entity = set()

data['ocr_info'] = copy.deepcopy(ocr_info)

for info in ocr_info:
if self.contains_re and not self.infer_mode:
if train_re:
# for re
if len(info["text"]) == 0:
empty_entity.add(info["id"])
Expand All @@ -872,24 +873,22 @@ def __call__(self, data):
gt_label = self._parse_label(label, encode_res)

# construct entities for re
if self.contains_re:
if not self.infer_mode:
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
label = label.upper()
entities.append({
"start": len(input_ids_list),
"end":
len(input_ids_list) + len(encode_res["input_ids"]),
"label": label.upper(),
})
else:
if train_re:
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
label = label.upper()
entities.append({
"start": len(input_ids_list),
"end":
len(input_ids_list) + len(encode_res["input_ids"]),
"label": 'O',
"label": label.upper(),
})
else:
entities.append({
"start": len(input_ids_list),
"end": len(input_ids_list) + len(encode_res["input_ids"]),
"label": 'O',
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
Expand All @@ -908,19 +907,23 @@ def __call__(self, data):
padding_side=self.tokenizer.padding_side,
pad_token_type_id=self.tokenizer.pad_token_type_id,
pad_token_id=self.tokenizer.pad_token_id)
data['entities'] = entities

if self.contains_re:
data['entities'] = entities
if self.infer_mode:
data['ocr_info'] = ocr_info
else:
data['relations'] = relations
data['id2label'] = id2label
data['empty_entity'] = empty_entity
data['entity_id_to_index_map'] = entity_id_to_index_map
if train_re:
data['relations'] = relations
data['id2label'] = id2label
data['empty_entity'] = empty_entity
data['entity_id_to_index_map'] = entity_id_to_index_map
return data

def _load_ocr_info(self, data):
def trans_poly_to_bbox(poly):
    """Collapse a polygon into its axis-aligned bounding box.

    Each element of ``poly`` is indexed as a point: index 0 is read as the
    x coordinate and index 1 as the y coordinate.

    Returns:
        A four-element list ``[x_min, y_min, x_max, y_max]`` holding the
        results of ``np.min`` / ``np.max`` over the collected coordinates.
    """
    # Gather each axis once, then take the extrema per axis.
    xs = [point[0] for point in poly]
    ys = [point[1] for point in poly]
    x_lo, x_hi = np.min(xs), np.max(xs)
    y_lo, y_hi = np.min(ys), np.max(ys)
    return [x_lo, y_lo, x_hi, y_hi]

if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = []
Expand Down

0 comments on commit c703a58

Please sign in to comment.