Skip to content

Commit

Permalink
修复gradient accumulation
Browse files Browse the repository at this point in the history
  • Loading branch information
Morizeyao committed Nov 20, 2019
1 parent 2ed6a8d commit 44d8bc6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
## 项目状态

- 目前项目主要架构已经稳定。如发现任何bug或是有功能意见与改进欢迎提交Issue,PR或是联系作者。
- 如使用梯度积累,loss计算可能存在bug。

## 使用方法

Expand Down Expand Up @@ -61,7 +60,7 @@ python ./generate.py --length=50 --nsamples=4 --prefix=xxx --fast_pattern --save

## FP16与Gradient Accumulation支持

- 我在train.py文件中加入了fp16与gradient accumulation支持,如果你安装了apex并且知道fp16是什么的话,可以修改变量fp16=True来启用。但是目前fp16不收敛,原因不明。
- 我在train.py文件中加入了fp16与gradient accumulation支持,如果你安装了apex并且知道fp16是什么的话,可以修改变量fp16=True来启用。但是目前fp16可能不收敛,原因不明。

## 联系作者

Expand Down
10 changes: 10 additions & 0 deletions config/model_config_test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"initializer_range": 0.02,
"layer_norm_epsilon": 1e-05,
"n_ctx": 64,
"n_embd": 128,
"n_head": 2,
"n_layer": 1,
"n_positions": 64,
"vocab_size": 13317
}
9 changes: 4 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,23 +207,22 @@ def main():
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

# optimizer step
if (step + 1) % gradient_accumulation == 0:
if (overall_step + 1) % gradient_accumulation == 0:
running_loss += loss.item()
optimizer.step()
optimizer.zero_grad()
scheduler.step()
overall_step += 1
if (overall_step + 1) % log_step == 0:
tb_writer.add_scalar('loss', loss.item(), overall_step)
if (overall_step + 1) % log_step == 0:
tb_writer.add_scalar('loss', loss.item() * gradient_accumulation, overall_step)
print('now time: {}:{}. Step {} of piece {} of epoch {}, loss {}'.format(
datetime.now().hour,
datetime.now().minute,
step + 1,
piece_num,
epoch + 1,
running_loss * gradient_accumulation / log_step))
running_loss * gradient_accumulation / (log_step / gradient_accumulation)))
running_loss = 0
overall_step += 1
piece_num += 1

print('saving model for epoch {}'.format(epoch + 1))
Expand Down

0 comments on commit 44d8bc6

Please sign in to comment.