add cpu embedding matrix to optimizer (#51)
adolphk-yk authored and ljshou committed May 27, 2019
1 parent 5ec5fd2 commit 7ca4ac7
Showing 6 changed files with 37 additions and 5 deletions.
7 changes: 6 additions & 1 deletion LearningMachine.py
@@ -259,7 +259,12 @@ def train(self, optimizer, loss_fn):
 all_costs.append(loss.item())
 optimizer.zero_grad()
 loss.backward()
-torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.conf.clip_grad_norm_max_norm)
+if self.conf.clip_grad_norm_max_norm != -1:
+    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.conf.clip_grad_norm_max_norm)
+    if isinstance(self.model, nn.DataParallel):
+        torch.nn.utils.clip_grad_norm_(self.model.module.layers['embedding'].get_parameters(), self.conf.clip_grad_norm_max_norm)
+    else:
+        torch.nn.utils.clip_grad_norm_(self.model.layers['embedding'].get_parameters(), self.conf.clip_grad_norm_max_norm)
 optimizer.step()
 
 del loss, logits, logits_softmax, logits_flat
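
For reference, a minimal, self-contained sketch of the training-step ordering this hunk implements (zero_grad, backward, clip only when clip_grad_norm_max_norm is not -1, then step). The toy model, loss, and optimizer below are illustrative, not the project's code:

import torch
import torch.nn as nn
import torch.optim as optim

# Illustrative only: -1 plays the same "clipping disabled" role as the new
# clip_grad_norm_max_norm default.
clip_grad_norm_max_norm = 0.5  # set to -1 to skip clipping entirely
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(8, 4)).pow(2).mean()

optimizer.zero_grad()
loss.backward()
if clip_grad_norm_max_norm != -1:
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm_max_norm)
optimizer.step()
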
1 change: 1 addition & 0 deletions Model.py
@@ -201,6 +201,7 @@ def __init__(self, conf, problem, vocab_info, use_gpu):
 for input_cluster in emb_conf:
     emb_conf[input_cluster]['dim'] = layer_arch['conf'][input_cluster]['dim']
     emb_conf[input_cluster]['fix_weight'] = layer_arch['conf'][input_cluster].get('fix_weight', False)
+    emb_conf[input_cluster]['weight_on_gpu'] = layer_arch['conf'][input_cluster].get('weight_on_gpu', True)
 
 all_layer_configs[EMBED_LAYER_ID] = get_conf(EMBED_LAYER_ID, layer_arch['layer'],
         None, all_layer_configs, inputs, self.use_gpu, conf_dict={'conf': emb_conf},
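
A hedged sketch of how the new flag might be set per input cluster, written as a Python dict mirroring the JSON architecture file. Only 'layer', 'conf', 'dim', 'fix_weight', and 'weight_on_gpu' come from this diff; the remaining keys and values are illustrative, not the exact NeuronBlocks schema:

# Illustrative embedding layer entry; weight_on_gpu defaults to True when omitted.
embedding_layer_arch = {
    "layer": "Embedding",
    "conf": {
        "word": {
            "dim": 300,
            "fix_weight": False,
            "weight_on_gpu": False,  # keep this (potentially huge) matrix on the CPU
        }
    }
}
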
3 changes: 2 additions & 1 deletion ModelConf.py
@@ -286,7 +286,8 @@ def load_from_file(self, conf_path):
 if self.phase == 'train':
     self.optimizer_name = self.get_item(['training_params', 'optimizer', 'name'])
     self.optimizer_params = self.get_item(['training_params', 'optimizer', 'params'])
-    self.clip_grad_norm_max_norm = self.get_item(['training_params', 'clip_grad_norm_max_norm'], default=5)
+
+    self.clip_grad_norm_max_norm = self.get_item(['training_params', 'clip_grad_norm_max_norm'], default=-1)
 
     if hasattr(self.params, 'learning_rate') and self.params.learning_rate:
         self.optimizer_params['lr'] = self.params.learning_rate
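
A hedged sketch of the corresponding training_params block (a Python dict standing in for the JSON conf; the optimizer name and params shown are illustrative). With the new default of -1, clipping is skipped unless a positive max norm is configured:

# Illustrative training_params fragment; only clip_grad_norm_max_norm and the
# optimizer name/params paths are taken from the code above.
training_params = {
    "optimizer": {
        "name": "Adam",              # assumption: whichever optimizer name the project resolves
        "params": {"lr": 0.001},
    },
    "clip_grad_norm_max_norm": 5,    # omit or set to -1 (the new default) to disable clipping
}
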
17 changes: 16 additions & 1 deletion block_zoo/Embedding.py
@@ -129,6 +129,13 @@ def __init__(self, layer_conf):
 if 'init_weights' in layer_conf.conf[input_cluster] and layer_conf.conf[input_cluster]['init_weights'] is not None:
     self.embeddings[input_cluster].weight = nn.Parameter(torch.from_numpy(layer_conf.conf[input_cluster]['init_weights']))
 
+# judge the embedding matrix weight's device
+if layer_conf.conf[input_cluster]['weight_on_gpu']:
+    self.embeddings[input_cluster].to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+    logging.info("The embeddings[%s]'s weight is on GPU now, you can modify the weight_on_gpu parameter to change embeddings weight device" % input_cluster)
+else:
+    logging.info(
+        "The embeddings[%s]'s weight is on cpu now, you can modify the weight_on_gpu parameter to change embeddings weight device" % input_cluster)
 # judge if fix the embedding weight
 if layer_conf.conf[input_cluster]['fix_weight']:
     self.embeddings[input_cluster].weight.requires_grad = False
@@ -160,7 +167,10 @@ def forward(self, inputs, use_gpu=False):
 # emb = self.embeddings[input_cluster](input, lengths[input]).float()
 # else:
 #     emb = self.embeddings[input_cluster](input).float()
-emb = self.embeddings[input_cluster](input.cpu()).float()
+if self.embeddings[input_cluster].weight.device.type == 'cpu':
+    emb = self.embeddings[input_cluster](input.cpu()).float()
+else:
+    emb = self.embeddings[input_cluster](input).float()
 if use_gpu is True:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     emb = emb.to(device)
@@ -171,6 +181,11 @@
         else:
             return features[0]
 
+    def get_parameters(self):
+        for sub_emb in self.embeddings:
+            for param in self.embeddings[sub_emb].parameters():
+                yield param
+
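
A minimal, self-contained sketch of what get_parameters() does: it walks a dict of per-cluster embedding modules and yields every weight, so the optimizer setup and gradient clipping above can reach embeddings even when their weights stay on the CPU. The toy cluster names and sizes below are made up:

import torch.nn as nn

# Toy stand-in for Embedding.embeddings: one sub-embedding per input cluster.
embeddings = {
    'word': nn.Embedding(10000, 300),   # large matrix that could stay on the CPU
    'char': nn.Embedding(200, 30),
}

def get_parameters():
    # Same generator pattern as the method added above.
    for sub_emb in embeddings:
        for param in embeddings[sub_emb].parameters():
            yield param

print(sum(p.numel() for p in get_parameters()))  # 10000*300 + 200*30 = 3,006,000
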




6 changes: 5 additions & 1 deletion train.py
@@ -12,6 +12,7 @@
 import copy
 
 import torch
+import torch.nn as nn
 from ModelConf import ModelConf
 from problem import Problem
 from utils.common_utils import dump_to_pkl, load_from_pkl, prepare_dir
@@ -231,7 +232,10 @@ def main(params):
     loss_fn.cuda()
 
 ### optimizer
-optimizer = eval(conf.optimizer_name)(lm.model.parameters(), **conf.optimizer_params)
+if isinstance(lm.model, nn.DataParallel):
+    optimizer = eval(conf.optimizer_name)(list(lm.model.parameters()) + list(lm.model.module.layers['embedding'].get_parameters()), **conf.optimizer_params)
+else:
+    optimizer = eval(conf.optimizer_name)(list(lm.model.parameters()) + list(lm.model.layers['embedding'].get_parameters()), **conf.optimizer_params)
 
 ## train
 lm.train(optimizer, loss_fn)
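
The point of this change is that an embedding kept on the CPU still has to be handed to the optimizer explicitly. A minimal sketch of that idea in plain PyTorch; the model, sizes, and optimizer choice here are illustrative, not the project's:

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(300, 2).to(device)       # stands in for the rest of the network
cpu_embedding = nn.Embedding(10000, 300)   # deliberately left on the CPU to save GPU memory

# One parameter list covering both devices; the optimizer keeps each parameter's
# state on that parameter's own device, so mixing CPU and GPU tensors works.
optimizer = optim.Adam(list(model.parameters()) + list(cpu_embedding.parameters()), lr=1e-3)
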
8 changes: 7 additions & 1 deletion utils/common_utils.py
@@ -4,6 +4,7 @@
 import logging
 import pickle as pkl
 import torch
+import torch.nn as nn
 import os
 import shutil
 import time
@@ -58,7 +59,12 @@ def get_trainable_param_num(model):
     Returns:
     """
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+    if isinstance(model, nn.DataParallel):
+        model_param = list(model.parameters()) + list(model.module.layers['embedding'].get_parameters())
+    else:
+        model_param = list(model.parameters()) + list(model.layers['embedding'].get_parameters())
+
+    return sum(p.numel() for p in model_param if p.requires_grad)
 
 
 def get_param_num(model):
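
A small self-contained illustration of the counting logic in get_trainable_param_num: weights frozen via fix_weight have requires_grad=False and drop out of the trainable count, while embedding weights are included once they are reachable through the combined parameter list. The modules below are toys:

import torch.nn as nn

linear = nn.Linear(8, 2)            # 8*2 weights + 2 biases = 18 trainable parameters
frozen_emb = nn.Embedding(100, 8)   # corresponds to fix_weight: its weight is frozen
frozen_emb.weight.requires_grad = False

params = list(linear.parameters()) + list(frozen_emb.parameters())
print(sum(p.numel() for p in params if p.requires_grad))  # -> 18
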
