
[SCRIPT] Improve BERT fine-tuning script with AdamW optimizer, bucketing and gradient accumulation #482

Merged
merged 12 commits into dmlc:master from the optimizer branch
Dec 30, 2018

Conversation

eric-haibin-lin
Member

eric-haibin-lin commented Dec 25, 2018

Description

Checklist

Essentials

  • PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@eric-haibin-lin added this to In progress in GluonNLP 0.6.0 on Dec 25, 2018
@mli
Member

mli commented Dec 26, 2018

Job PR-482/3 is complete.
Docs are uploaded to https://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-482/3/index.html

@codecov

codecov bot commented Dec 26, 2018

Codecov Report

Merging #482 into master will decrease coverage by 0.31%.
The diff coverage is 0%.

@@            Coverage Diff             @@
##           master     #482      +/-   ##
==========================================
- Coverage   71.51%   71.19%   -0.32%     
==========================================
  Files         119      118       -1     
  Lines        9998     9979      -19     
==========================================
- Hits         7150     7105      -45     
- Misses       2848     2874      +26
Flag         Coverage         Δ
#PR470       ?
#PR482       71.19% <0%>      (?)
#master      ?
#notserial   46.7% <0%>       (-0.41%) ⬇️
#py2         70.95% <0%>      (-0.32%) ⬇️
#py3         71.04% <0%>      (-0.32%) ⬇️
#serial      56.76% <0%>      (+0.02%) ⬆️

@codecov

codecov bot commented Dec 26, 2018

Codecov Report

Merging #482 into master will increase coverage by 4.13%.
The diff coverage is 13.88%.

@@            Coverage Diff             @@
##           master     #482      +/-   ##
==========================================
+ Coverage   66.98%   71.11%   +4.13%     
==========================================
  Files         122      121       -1     
  Lines       10667    10040     -627     
==========================================
- Hits         7145     7140       -5     
+ Misses       3522     2900     -622
Flag         Coverage           Δ
#PR481       ?
#PR482       71.11% <13.88%>    (-0.07%) ⬇️
#PR489       ?
#master      ?
#notserial   46.85% <13.88%>    (+2.87%) ⬆️
#py2         70.87% <13.88%>    (+4.13%) ⬆️
#py3         70.96% <13.88%>    (+3.56%) ⬆️
#serial      56.55% <13.88%>    (+3.36%) ⬆️

@haven-jeon
Member

@eric-haibin-lin could you check this out-of-memory error? (tested on my K80 GPU)
I'm not sure why the bucket sampler requires more GPU memory than the simple DataLoader.
It worked well on the K80 GPU with the simple DataLoader.
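For comparison, here is a minimal sketch of what the bucketing path looks like. `data_train` and `batchify_fn` are assumed to come from the script, and the sample layout `(input_ids, valid_length, segment_ids, label)` is an assumption:

```python
import gluonnlp as nlp
from mxnet import gluon

# Bucketing groups samples of similar length so batches need little padding,
# but batch shapes then vary across buckets, which can fragment GPU memory
# compared to the fixed shapes a simple DataLoader produces.
lengths = data_train.transform(
    lambda input_ids, valid_length, segment_ids, label: valid_length, lazy=False)
sampler = nlp.data.FixedBucketSampler(lengths, batch_size=32,
                                      num_buckets=10, shuffle=True)
loader = gluon.data.DataLoader(data_train, batch_sampler=sampler,
                               batchify_fn=batchify_fn)
```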

(mxnet_work) ubuntu@ip-172-31-19-247:~/work/mxnet_bert/eric/gluon-nlp/scripts/bert$ GLUE_DIR=~/share/glue_data/ python3 finetune_classifier.py --batch_size 32 --optimizer adamw --epochs 2 --gpu --seed 2 --lr 6e-5
/home/ubuntu/mxnet_work/lib/python3.6/site-packages/sklearn/externals/joblib/externals/cloudpickle/cloudpickle.py:47: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
  import imp
INFO:root:Namespace(accumulate=None, batch_size=32, dev_batch_size=8, epochs=2, gpu=True, log_interval=10, lr=6e-05, max_len=128, optimizer='adamw', seed=2, warmup_ratio=0.1)
[05:25:53] src/storage/storage.cc:135: Using GPUPooledRoundedStorageManager.
finetune_classifier.py:156: UserWarning: AdamW optimizer is not found. Please consider upgrading to mxnet>=1.5.0. Now the original Adam optimizer is used instead.
  warnings.warn("AdamW optimizer is not found. Please consider upgrading to "
INFO:root:[Epoch 0 Batch 10/119] loss=0.6278, lr=0.0000273, acc=0.694
INFO:root:[Epoch 0 Batch 20/119] loss=0.6702, lr=0.0000545, acc=0.671
INFO:root:[Epoch 0 Batch 30/119] loss=0.6198, lr=0.0000577, acc=0.679
INFO:root:[Epoch 0 Batch 40/119] loss=0.5713, lr=0.0000548, acc=0.688
INFO:root:[Epoch 0 Batch 50/119] loss=0.6247, lr=0.0000519, acc=0.686
INFO:root:[Epoch 0 Batch 60/119] loss=0.5648, lr=0.0000490, acc=0.693
INFO:root:[Epoch 0 Batch 70/119] loss=0.6059, lr=0.0000461, acc=0.693
INFO:root:[Epoch 0 Batch 80/119] loss=0.5700, lr=0.0000432, acc=0.698
INFO:root:[Epoch 0 Batch 90/119] loss=0.5926, lr=0.0000403, acc=0.697
INFO:root:[Epoch 0 Batch 100/119] loss=0.5487, lr=0.0000374, acc=0.701
INFO:root:[Epoch 0 Batch 110/119] loss=0.5446, lr=0.0000345, acc=0.705
INFO:root:Validation accuracy: 0.743
INFO:root:Time cost=120.4s
INFO:root:[Epoch 1 Batch 10/119] loss=0.5085, lr=0.0000290, acc=0.777
INFO:root:[Epoch 1 Batch 20/119] loss=0.5025, lr=0.0000261, acc=0.776
INFO:root:[Epoch 1 Batch 30/119] loss=0.4895, lr=0.0000232, acc=0.776
INFO:root:[Epoch 1 Batch 40/119] loss=0.4756, lr=0.0000203, acc=0.779
INFO:root:[Epoch 1 Batch 50/119] loss=0.4898, lr=0.0000174, acc=0.774
INFO:root:[Epoch 1 Batch 60/119] loss=0.4822, lr=0.0000145, acc=0.769
INFO:root:[Epoch 1 Batch 70/119] loss=0.4656, lr=0.0000116, acc=0.771
INFO:root:[Epoch 1 Batch 80/119] loss=0.5440, lr=0.0000087, acc=0.766
INFO:root:[Epoch 1 Batch 90/119] loss=0.4484, lr=0.0000058, acc=0.768
INFO:root:[Epoch 1 Batch 100/119] loss=0.4508, lr=0.0000029, acc=0.772
Traceback (most recent call last):
  File "finetune_classifier.py", line 227, in <module>
    train()
  File "finetune_classifier.py", line 210, in train
    gluon.utils.clip_global_norm(grads, 1)
  File "/home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/gluon/utils.py", line 148, in clip_global_norm
    if not np.isfinite(total_norm.asscalar()):
  File "/home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 2005, in asscalar
    return self.asnumpy()[0]
  File "/home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 1987, in asnumpy
    ctypes.c_size_t(data.size)))
  File "/home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/base.py", line 252, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [05:29:38] src/storage/./pooled_storage_manager.h:299: cudaMalloc failed: out of memory

Stack trace returned 10 entries:
[bt] (0) /home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x3ebbea) [0x7f4d40ae3bea]
[bt] (1) /home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x3ec211) [0x7f4d40ae4211]
[bt] (2) /home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x346d5fb) [0x7f4d43b655fb]
[bt] (3) /home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x3473e4e) [0x7f4d43b6be4e]
[bt] (4) /home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2db9616) [0x7f4d434b1616]
[bt] (5) /home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2db98c7) [0x7f4d434b18c7]
[bt] (6) /home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/libmxnet.so(mxnet::imperative::PushFCompute(std::function<void (nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob, std::allocator<mxnet::TBlob> > const&)> const&, nnvm::Op const*, nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*, std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource, std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int, std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&)::{lambda(mxnet::RunContext)#1}::operator()(mxnet::RunContext) const+0x297) [0x7f4d434b1ce7]
[bt] (7) /home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2d1057d) [0x7f4d4340857d]
[bt] (8) /home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2d10567) [0x7f4d43408567]
[bt] (9) /home/ubuntu/mxnet_work/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2d10567) [0x7f4d43408567]

@haven-jeon
Member

Wow, it's really fast to train. Great!

@eric-haibin-lin
Member Author

@haven-jeon Thanks for pointing that out! I can reproduce the same OOM error on the K80. If I don't use the MXNET_GPU_MEM_POOL_TYPE="Round" setting, memory usage drops by a lot. I added a warning message in the script for this option. Could you check if this helps?
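For anyone hitting the same OOM, a minimal sketch of the workaround (assuming the env var only takes effect if changed before MXNet allocates GPU memory):

```python
import os

# The rounded pool ("Round") rounds allocations up to powers of two, which
# can exhaust a 12 GB K80; removing the setting falls back to the default
# pool with lower peak usage. Adjust the environment before importing mxnet.
os.environ.pop('MXNET_GPU_MEM_POOL_TYPE', None)

import mxnet as mx  # noqa: E402  (import after the env var is adjusted)
```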

@haven-jeon
Member

haven-jeon commented Dec 27, 2018

> @haven-jeon Thanks for pointing that out! I can reproduce the same OOM error on the K80. If I don't use the MXNET_GPU_MEM_POOL_TYPE="Round" setting, memory usage drops by a lot. I added a warning message in the script for this option. Could you check if this helps?

@eric-haibin-lin
Works well! Thanks.

parser.add_argument('--max_len', type=int, default=128,
                    help='Maximum length of the sentence pairs, default is 128')
parser.add_argument('--seed', type=int, default=2, help='Random seed, default is 2')
parser.add_argument('--accumulate', type=int, default=None, help='The number of batches for '
Member


Consider adding a message showing the actual effective batch size when this option is set.

Member Author


Added
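For reference, a sketch of what such a message could look like (the exact wording added to the script may differ; `args` comes from the parser above):

```python
import logging

if args.accumulate:
    logging.info('Gradient accumulation enabled; effective batch size = %d',
                 args.accumulate * args.batch_size)
```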

                                        args.max_len, pad=False)
dev_trans = ClassificationTransform(tokenizer, MRPCDataset.get_labels(), args.max_len)

data_train = MRPCDataset('train').transform(train_trans, lazy=False)
Member


Consider modularizing the data pipeline by moving parts of it into a function; this helps avoid leaking variables into the global context. One possible shape is sketched below.
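A sketch of that refactor, reusing the script's `ClassificationTransform` and `MRPCDataset` (the dev-split construction is assumed symmetric to the train split):

```python
def build_data(tokenizer, max_len):
    """Build train/dev MRPC datasets without leaking variables into globals."""
    labels = MRPCDataset.get_labels()
    train_trans = ClassificationTransform(tokenizer, labels, max_len, pad=False)
    dev_trans = ClassificationTransform(tokenizer, labels, max_len)
    data_train = MRPCDataset('train').transform(train_trans, lazy=False)
    data_dev = MRPCDataset('dev').transform(dev_trans, lazy=False)
    return data_train, data_dev
```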

# Set grad_req if gradient accumulation is required
if accumulate:
    for p in model.collect_params().values():
        p.grad_req = 'add'
Member


This seems wrong... it sets grad_req to 'add' for all parameters, even those that had grad_req = 'null'.

Member


try

params = [p for p in model.collect_params().values() if p.grad_req != 'null']
if accumulate:
    for p in params:
        p.grad_req = 'add'

Member Author


You're right. Constant params have grad_req='null' and their req cannot be updated, but if the user explicitly set some grad_req to 'null', those would be overridden. Updated now.

# update
if not accumulate or (batch_id + 1) % accumulate == 0:
    grads = [p.grad(ctx) for p in params]
    gluon.utils.clip_global_norm(grads, 1)
Member


This needs to be done properly. Consider using the clip-norm function from #470.
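For illustration, a generic sketch of per-update global-norm clipping (not necessarily the #470 implementation):

```python
from mxnet import nd

def clip_grad_global_norm(grads, max_norm):
    """With grad_req='add', gradients are sums over the accumulated
    micro-batches, so clip once per update, right before trainer.step()."""
    total_norm = nd.sqrt(sum((g ** 2).sum() for g in grads)).asscalar()
    if total_norm > max_norm:  # rescale in place only when the norm is too large
        scale = max_norm / total_norm
        for g in grads:
            g *= scale
    return total_norm
```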


if __name__ == '__main__':
    pool_type = os.environ.get('MXNET_GPU_MEM_POOL_TYPE', '')
Member

@szha Dec 28, 2018


Given that static_alloc is on in the model, the mem pool type would not affect the model's internal spaces, only the input and output arrays. In this case, you may want to tune MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF=x so that 2^x is on the same scale as the mode of the input/output array sizes.
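Following that suggestion, a sketch of how the cutoff might be derived (the sizes below are illustrative assumptions, not measured values):

```python
import math
import os

# With batch_size=32, max_len=128 and 4-byte elements, one input array is
# about 32 * 128 * 4 = 16384 bytes = 2**14, so a cutoff of 14 puts 2**cutoff
# on the same scale as these arrays, per the suggestion above.
# (Illustrative values; set before MXNet starts using the GPU.)
batch_size, max_len, bytes_per_elem = 32, 128, 4
cutoff = int(math.ceil(math.log2(batch_size * max_len * bytes_per_elem)))
os.environ['MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF'] = str(cutoff)
```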

Member Author


It might be too much to expect users to understand how to set this properly... Would it be better to remove this option for now to keep the example simple?

@mli
Member

mli commented Dec 28, 2018

Job PR-482/6 is complete.
Docs are uploaded to https://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-482/6/index.html

@mli
Member

mli commented Dec 28, 2018

Job PR-482/8 is complete.
Docs are uploaded to https://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-482/8/index.html

@mli
Member

mli commented Dec 30, 2018

Job PR-482/10 is complete.
Docs are uploaded to https://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-482/10/index.html

@eric-haibin-lin merged commit 7681e03 into dmlc:master on Dec 30, 2018
GluonNLP 0.6.0 automation moved this from In progress to Done on Dec 30, 2018
@eric-haibin-lin added this to Done in BERT on Jan 3, 2019
@eric-haibin-lin deleted the optimizer branch on May 9, 2019
paperplanet pushed a commit to paperplanet/gluon-nlp that referenced this pull request on Jun 9, 2019:
[SCRIPT] Improve BERT fine-tuning script with AdamW optimizer, bucketing and gradient accumulation (dmlc#482)

* use adam_w and bucketing

* add missing code

* add helper msg for OOM error

* move optimizer def to gluonnlp. also fix padding token

* update documentation

* fix lint

* address CR comments

* revert unintended changes

* remove mem pool env var

* Revert "remove mem pool env var"

This reverts commit 60b8fdd.

* remove mem pool env var