add pretrained params to backbone
WenmuZhou committed Jan 6, 2022
Commit: cd7b2ea · Parent: 9ecfc34
Showing 4 changed files with 25 additions and 26 deletions.
configs/vqa/re/layoutxlm.yml: 3 changes (1 addition, 2 deletions)
@@ -8,7 +8,6 @@ Global:
   # evaluation is run every 10 iterations after the 0th iteration
   eval_batch_step: [ 0, 19 ]
   cal_metric_during_train: False
-  pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file
   save_inference_dir:
   use_visualdl: False
   infer_img: doc/vqa/input/zh_val_21.jpg
@@ -20,7 +19,7 @@ Architecture:
   Transform:
   Backbone:
     name: LayoutXLMForRe
-    pretrained_model: *pretrained_model
+    pretrained: True
     checkpoints:

 Loss:
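In effect, the Backbone options in this config now map directly onto the backbone constructor's keyword arguments. A minimal sketch of the equivalent direct call, assuming the keys under Backbone (other than name) are forwarded as kwargs; this snippet is illustrative and not part of the commit:

# Illustrative only: mirrors the Backbone section of configs/vqa/re/layoutxlm.yml.
from ppocr.modeling.backbones.vqa_layoutlm import LayoutXLMForRe

# pretrained=True loads the layoutxlm-base-uncased weights, whose name is now
# resolved inside the backbone; checkpoints=None means no local checkpoint overrides them.
re_backbone = LayoutXLMForRe(pretrained=True, checkpoints=None)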
configs/vqa/ser/layoutlm.yml: 3 changes (1 addition, 2 deletions)
@@ -8,7 +8,6 @@ Global:
   # evaluation is run every 10 iterations after the 0th iteration
   eval_batch_step: [ 0, 19 ]
   cal_metric_during_train: False
-  pretrained_model: &pretrained_model layoutlm-base-uncased # This field can only be changed by modifying the configuration file
   save_inference_dir:
   use_visualdl: False
   infer_img: doc/vqa/input/zh_val_0.jpg
@@ -20,7 +19,7 @@ Architecture:
   Transform:
   Backbone:
     name: LayoutLMForSer
-    pretrained_model: *pretrained_model
+    pretrained: True
     checkpoints:
     num_classes: &num_classes 7

configs/vqa/ser/layoutxlm.yml: 3 changes (1 addition, 2 deletions)
@@ -8,7 +8,6 @@ Global:
   # evaluation is run every 10 iterations after the 0th iteration
   eval_batch_step: [ 0, 19 ]
   cal_metric_during_train: False
-  pretrained_model: &pretrained_model layoutxlm-base-uncased # This field can only be changed by modifying the configuration file
   save_inference_dir:
   use_visualdl: False
   infer_img: doc/vqa/input/zh_val_42.jpg
@@ -20,7 +19,7 @@ Architecture:
   Transform:
   Backbone:
     name: LayoutXLMForSer
-    pretrained_model: *pretrained_model
+    pretrained: True
     checkpoints:
     num_classes: &num_classes 7

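The two SER configs follow the same pattern as the RE config above: the named pretrained model moves out of Global, and the Backbone keeps only a boolean pretrained flag plus num_classes. A rough Python equivalent, mirroring these files and not part of the commit:

# Illustrative only: mirrors the Backbone sections of the two SER configs.
from ppocr.modeling.backbones.vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer

# The base weight name (layoutlm-base-uncased / layoutxlm-base-uncased) is now
# resolved inside the backbone rather than read from the config.
ser_layoutlm = LayoutLMForSer(num_classes=7, pretrained=True, checkpoints=None)
ser_layoutxlm = LayoutXLMForSer(num_classes=7, pretrained=True, checkpoints=None)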
ppocr/modeling/backbones/vqa_layoutlm.py: 42 changes (22 additions, 20 deletions)
@@ -24,21 +24,32 @@

 __all__ = ["LayoutXLMForSer", 'LayoutLMForSer']

+pretrained_model_dict = {
+    LayoutXLMModel: 'layoutxlm-base-uncased',
+    LayoutLMModel: 'layoutlm-base-uncased'
+}
+

 class NLPBaseModel(nn.Layer):
     def __init__(self,
                  base_model_class,
                  model_class,
                  type='ser',
-                 pretrained_model=None,
+                 pretrained=True,
                  checkpoints=None,
                  **kwargs):
         super(NLPBaseModel, self).__init__()
-        assert pretrained_model is not None or checkpoints is not None, "one of pretrained_model and checkpoints must be not None"
         if checkpoints is not None:
             self.model = model_class.from_pretrained(checkpoints)
         else:
-            base_model = base_model_class.from_pretrained(pretrained_model)
+            pretrained_model_name = pretrained_model_dict[base_model_class]
+            if pretrained:
+                base_model = base_model_class.from_pretrained(
+                    pretrained_model_name)
+            else:
+                base_model = base_model_class(
+                    **base_model_class.pretrained_init_configuration[
+                        pretrained_model_name])
         if type == 'ser':
             self.model = model_class(
                 base_model, num_classes=kwargs['num_classes'], dropout=None)
@@ -48,16 +59,13 @@ def __init__(self,


 class LayoutXLMForSer(NLPBaseModel):
-    def __init__(self,
-                 num_classes,
-                 pretrained_model='layoutxlm-base-uncased',
-                 checkpoints=None,
+    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                  **kwargs):
         super(LayoutXLMForSer, self).__init__(
             LayoutXLMModel,
             LayoutXLMForTokenClassification,
             'ser',
-            pretrained_model,
+            pretrained,
             checkpoints,
             num_classes=num_classes)

@@ -75,16 +83,13 @@ def forward(self, x):


 class LayoutLMForSer(NLPBaseModel):
-    def __init__(self,
-                 num_classes,
-                 pretrained_model='layoutxlm-base-uncased',
-                 checkpoints=None,
+    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                  **kwargs):
         super(LayoutLMForSer, self).__init__(
             LayoutLMModel,
             LayoutLMForTokenClassification,
             'ser',
-            pretrained_model,
+            pretrained,
             checkpoints,
             num_classes=num_classes)

@@ -100,13 +105,10 @@ def forward(self, x):


 class LayoutXLMForRe(NLPBaseModel):
-    def __init__(self,
-                 pretrained_model='layoutxlm-base-uncased',
-                 checkpoints=None,
-                 **kwargs):
-        super(LayoutXLMForRe, self).__init__(
-            LayoutXLMModel, LayoutXLMForRelationExtraction, 're',
-            pretrained_model, checkpoints)
+    def __init__(self, pretrained=True, checkpoints=None, **kwargs):
+        super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
+                                             LayoutXLMForRelationExtraction,
+                                             're', pretrained, checkpoints)

     def forward(self, x):
         x = self.model(
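Taken together, NLPBaseModel now supports three loading paths instead of requiring an explicit pretrained_model name. A short usage sketch, assuming the paddlenlp weights can be downloaded; the local checkpoint path is made up for illustration and not part of this commit:

# Illustrative only: the three loading paths after this change.
from ppocr.modeling.backbones.vqa_layoutlm import LayoutXLMForSer

# 1) pretrained=True: load layoutxlm-base-uncased, looked up in pretrained_model_dict.
ser_pretrained = LayoutXLMForSer(num_classes=7, pretrained=True)

# 2) pretrained=False: build the same architecture from pretrained_init_configuration
#    with randomly initialised weights (no download).
ser_scratch = LayoutXLMForSer(num_classes=7, pretrained=False)

# 3) checkpoints takes precedence over pretrained: load a fine-tuned model from a
#    local directory ("./output/ser_layoutxlm/best_accuracy" is a hypothetical path).
ser_finetuned = LayoutXLMForSer(
    num_classes=7, checkpoints="./output/ser_layoutxlm/best_accuracy")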
