Skip to content

Commit

Permalink
v2 to v3
Browse files Browse the repository at this point in the history
  • Loading branch information
LDOUBLEV committed Apr 28, 2022
1 parent c8aa934 commit aac628b
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 124 deletions.
12 changes: 9 additions & 3 deletions deploy/slim/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
### 1. 安装PaddleSlim

```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddleSlim
python setup.py install
pip3 install paddleslim==2.2.2
```

### 2. 准备训练好的模型
Expand All @@ -43,7 +41,15 @@ python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
tar -xf ch_ppocr_mobile_v2.0_det_train.tar
python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_model_dir=./output/quant_model
```

模型蒸馏和模型量化可以同时使用,以PPOCRv3检测模型为例:
```
# 下载检测预训练模型:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
tar xf https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
python deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3_det/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model='./ch_PP-OCRv3_det_distill_train/best_accuracy' Global.save_model_dir=./output/quant_model_distill/
```
如果要训练识别模型的量化,修改配置文件和加载的模型参数即可。

Expand Down
15 changes: 12 additions & 3 deletions deploy/slim/quantization/README_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ After training, if you want to further compress the model size and accelerate th
### 1. Install PaddleSlim

```bash
git clone https://github.com/PaddlePaddle/PaddleSlim.git
cd PaddlSlim
python setup.py install
pip3 install paddleslim==2.2.2
```


Expand All @@ -52,6 +50,17 @@ python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3
```


Model distillation and model quantization can be used at the same time, taking the PPOCRv3 detection model as an example:
```
# download provided model
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
tar xf https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
python deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3_det/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model='./ch_PP-OCRv3_det_distill_train/best_accuracy' Global.save_model_dir=./output/quant_model_distill/
```

If you want to quantify the text recognition model, you can modify the configuration file and loaded model parameters.

### 4. Export inference model

Once we got the model after pruning and fine-tuning, we can export it as an inference model for the deployment of predictive tasks:
Expand Down
97 changes: 36 additions & 61 deletions doc/doc_ch/knowledge_distillation.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,9 @@ paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
<a name="22"></a>
### 2.2 检测配置文件解析

检测模型蒸馏的配置文件在PaddleOCR/configs/det/ch_PP-OCRv2/目录下,包含三个蒸馏配置文件:
- ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,采用一个大模型蒸馏两个小模型,且两个小模型互相学习的方法
- ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法
- ch_PP-OCRv2_det_distill.yml,采用Teacher大模型蒸馏小模型Student的方法
检测模型蒸馏的配置文件在PaddleOCR/configs/det/ch_PP-OCRv3/目录下,包含两个个蒸馏配置文件:
- ch_PP-OCRv3_det_cml.yml,采用cml蒸馏,采用一个大模型蒸馏两个小模型,且两个小模型互相学习的方法
- ch_PP-OCRv3_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法

<a name="221"></a>
#### 2.2.1 模型结构
Expand All @@ -321,44 +320,44 @@ Architecture:
algorithm: Distillation # 算法名称
Models: # 模型,包含子网络的配置信息
Student: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false # 是否需要固定参数
return_all_feats: false # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
name: ResNet
in_channels: 3
layers: 50
Neck:
name: DBFPN
out_channels: 96
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
Teacher: # 另外一个子网络,这里给的是普通大模型蒸小模型的蒸馏示例,
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true # Teacher模型是训练好的,不需要参与训练,freeze_params设置为True
Teacher: # 另外一个子网络,这里给的是DML蒸馏示例,
freeze_params: true
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
in_channels: 3
layers: 50
Neck:
name: DBFPN
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
```

如果是采用DML,即两个小模型互相学习的方法,上述配置文件里的Teacher网络结构需要设置为Student模型一样的配置,具体参考配置文件[ch_PP-OCRv2_det_dml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml)
如果是采用DML,即两个小模型互相学习的方法,上述配置文件里的Teacher网络结构需要设置为Student模型一样的配置,具体参考配置文件[ch_PP-OCRv3_det_dml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml)

下面介绍[ch_PP-OCRv2_det_cml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)的配置文件参数:
下面介绍[ch_PP-OCRv3_det_cml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml)的配置文件参数:

```
Architecture:
Expand All @@ -375,12 +374,14 @@ Architecture:
Transform:
Backbone:
name: ResNet
layers: 18
in_channels: 3
layers: 50
Neck:
name: DBFPN
name: LKPAN
out_channels: 256
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
Student: # CML蒸馏的Student模型配置
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
Expand All @@ -392,10 +393,11 @@ Architecture:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
disable_se: true
Neck:
name: DBFPN
name: RSEFPN
out_channels: 96
shortcut: True
Head:
name: DBHead
k: 50
Expand All @@ -410,10 +412,11 @@ Architecture:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
disable_se: true
Neck:
name: DBFPN
name: RSEFPN
out_channels: 96
shortcut: True
Head:
name: DBHead
k: 50
Expand Down Expand Up @@ -445,34 +448,7 @@ Architecture:
<a name="222"></a>
#### 2.2.2 损失函数

知识蒸馏任务中,检测ch_PP-OCRv2_det_distill.yml蒸馏损失函数配置如下所示。

```yaml
Loss:
name: CombinedLoss # 损失函数名称,基于改名称,构建用于损失函数的类
loss_config_list: # 损失函数配置文件列表,为CombinedLoss的必备函数
- DistillationDilaDBLoss: # 基于蒸馏的DB损失函数,继承自标准的DBloss
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
model_name_pairs: # 对于蒸馏模型的预测结果,提取这两个子网络的输出,计算Teacher模型和Student模型输出的loss
- ["Student", "Teacher"]
key: maps # 取子网络输出dict中,该key对应的tensor
balance_loss: true # 以下几个参数为标准DBloss的配置参数
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
- DistillationDBLoss: # 基于蒸馏的DB损失函数,继承自标准的DBloss,用于计算Student和GT之间的loss
weight: 1.0
model_name_list: ["Student"] # 模型名字只有Student,表示计算Student和GT之间的loss
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
```

同理,检测ch_PP-OCRv2_det_cml.yml蒸馏损失函数配置如下所示。相比较于ch_PP-OCRv2_det_distill.yml的损失函数配置,cml蒸馏的损失函数配置做了3个改动:
检测ch_PP-OCRv3_det_cml.yml蒸馏损失函数配置如下所示。相比较于ch_PP-OCRv3_det_distill.yml的损失函数配置,cml蒸馏的损失函数配置做了3个改动:
```yaml
Loss:
name: CombinedLoss
Expand Down Expand Up @@ -545,34 +521,33 @@ Metric:
<a name="225"></a>
#### 2.2.5 检测蒸馏模型finetune

检测蒸馏有三种方式:
- 采用ch_PP-OCRv2_det_distill.yml,Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
- 采用ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,同样Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
- 采用ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法,在PaddleOCR采用的数据集上大约有1.7%的精度提升。
PP-OCRv3检测蒸馏有两种方式:
- 采用ch_PP-OCRv3_det_cml.yml,采用cml蒸馏,同样Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
- 采用ch_PP-OCRv3_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法,在PaddleOCR采用的数据集上相比单独训练Student模型有1%-2%的提升。

在具体fine-tune时,需要在网络结构的`pretrained`参数中设置要加载的预训练模型。

在精度提升方面,cml的精度>dml的精度>distill蒸馏方法的精度。当数据量不足或者Teacher模型精度与Student精度相差不大的时候,这个结论或许会改变。
在精度提升方面,cml的精度>dml的精度蒸馏方法的精度。当数据量不足或者Teacher模型精度与Student精度相差不大的时候,这个结论或许会改变。


另外,由于PaddleOCR提供的蒸馏预训练模型包含了多个模型的参数,如果您希望提取Student模型的参数,可以参考如下代码:
```
# 下载蒸馏训练模型的参数
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv3_det_distill_train.tar
```

```python
import paddle
# 加载预训练模型
all_params = paddle.load("ch_PP-OCRv2_det_distill_train/best_accuracy.pdparams")
all_params = paddle.load("ch_PP-OCRv3_det_distill_train/best_accuracy.pdparams")
# 查看权重参数的keys
print(all_params.keys())
# 学生模型的权重提取
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
# 查看学生模型权重参数的keys
print(s_params.keys())
# 保存
paddle.save(s_params, "ch_PP-OCRv2_det_distill_train/student.pdparams")
paddle.save(s_params, "ch_PP-OCRv3_det_distill_train/student.pdparams")
```

最终`Student`模型的参数将会保存在`ch_PP-OCRv2_det_distill_train/student.pdparams`中,用于模型的fine-tune。
最终`Student`模型的参数将会保存在`ch_PP-OCRv3_det_distill_train/student.pdparams`中,用于模型的fine-tune。
Loading

0 comments on commit aac628b

Please sign in to comment.