Skip to content

Commit

Permalink
Add more models to benchmark_score (apache#12780)
Browse files Browse the repository at this point in the history
* add models to cnn benchmark

* improve benchmark score

* add benchmark_gluon

* improve lint

* improve lint

* add licsence for script

* improve script lint

* mv benchmark_gluon to new location

* support multi-gpus

* Add a new parameter 'global batchsize' for the batch size multiplication for multi-gpu case

* add batch size argument help

* improve help and change default batchsize

* simplify benchmark_gluon
  • Loading branch information
xinyu-intel authored and Jose Luis Contreras committed Nov 13, 2018
1 parent 352f642 commit e6d08c9
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 9 deletions.
164 changes: 164 additions & 0 deletions benchmark/python/gluon/benchmark_gluon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http:https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
import mxnet.gluon.model_zoo.vision as models
import time
import logging
import argparse
import subprocess
import os
import errno

logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description='Gluon modelzoo-based CNN performance benchmark')

parser.add_argument('--model', type=str, default='all',
choices=['all', 'alexnet', 'densenet121', 'densenet161',
'densenet169', 'densenet201', 'inceptionv3', 'mobilenet0.25',
'mobilenet0.5', 'mobilenet0.75', 'mobilenet1.0', 'mobilenetv2_0.25',
'mobilenetv2_0.5', 'mobilenetv2_0.75', 'mobilenetv2_1.0', 'resnet101_v1',
'resnet101_v2', 'resnet152_v1', 'resnet152_v2', 'resnet18_v1',
'resnet18_v2', 'resnet34_v1', 'resnet34_v2', 'resnet50_v1',
'resnet50_v2', 'squeezenet1.0', 'squeezenet1.1', 'vgg11',
'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19', 'vgg19_bn'])
parser.add_argument('--batch-size', type=int, default=0,
help='Batch size to use for benchmarking. Example: 32, 64, 128.'
'By default, runs benchmark for batch sizes - 1, 32, 64, 128, 256')
parser.add_argument('--num-batches', type=int, default=10)
parser.add_argument('--gpus', type=str, default='',
help='GPU IDs to use for this benchmark task. Example: --gpus=0,1,2,3 to use 4 GPUs.'
'By default, use CPU only.')
parser.add_argument('--type', type=str, default='inference', choices=['all', 'training', 'inference'])

opt = parser.parse_args()

num_batches = opt.num_batches
dry_run = 10 # use 10 iterations to warm up
batch_inf = [1, 32, 64, 128, 256]
batch_train = [1, 32, 64, 128, 256]
image_shapes = [(3, 224, 224), (3, 299, 299)]

def score(network, batch_size, ctx):
assert (batch_size >= len(ctx)), "ERROR: batch size should not be smaller than num of GPUs."
net = models.get_model(network)
if 'inceptionv3' == network:
data_shape = [('data', (batch_size,) + image_shapes[1])]
else:
data_shape = [('data', (batch_size,) + image_shapes[0])]

data = mx.sym.var('data')
out = net(data)
softmax = mx.sym.SoftmaxOutput(out, name='softmax')
mod = mx.mod.Module(softmax, context=ctx)
mod.bind(for_training = False,
inputs_need_grad = False,
data_shapes = data_shape)
mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx[0]) for _, shape in mod.data_shapes]
batch = mx.io.DataBatch(data, [])
for i in range(dry_run + num_batches):
if i == dry_run:
tic = time.time()
mod.forward(batch, is_train=False)
for output in mod.get_outputs():
output.wait_to_read()
fwd = time.time() - tic
return fwd


def train(network, batch_size, ctx):
assert (batch_size >= len(ctx)), "ERROR: batch size should not be smaller than num of GPUs."
net = models.get_model(network)
if 'inceptionv3' == network:
data_shape = [('data', (batch_size,) + image_shapes[1])]
else:
data_shape = [('data', (batch_size,) + image_shapes[0])]

data = mx.sym.var('data')
out = net(data)
softmax = mx.sym.SoftmaxOutput(out, name='softmax')
mod = mx.mod.Module(softmax, context=ctx)
mod.bind(for_training = True,
inputs_need_grad = False,
data_shapes = data_shape)
mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
if len(ctx) > 1:
mod.init_optimizer(kvstore='device', optimizer='sgd')
else:
mod.init_optimizer(kvstore='local', optimizer='sgd')
data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx[0]) for _, shape in mod.data_shapes]
batch = mx.io.DataBatch(data, [])
for i in range(dry_run + num_batches):
if i == dry_run:
tic = time.time()
mod.forward(batch, is_train=True)
for output in mod.get_outputs():
output.wait_to_read()
mod.backward()
mod.update()
bwd = time.time() - tic
return bwd

if __name__ == '__main__':
runtype = opt.type
bs = opt.batch_size

if opt.model == 'all':
networks = ['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
'inceptionv3', 'mobilenet0.25', 'mobilenet0.5', 'mobilenet0.75',
'mobilenet1.0', 'mobilenetv2_0.25', 'mobilenetv2_0.5', 'mobilenetv2_0.75',
'mobilenetv2_1.0', 'resnet101_v1', 'resnet101_v2', 'resnet152_v1', 'resnet152_v2',
'resnet18_v1', 'resnet18_v2', 'resnet34_v1', 'resnet34_v2', 'resnet50_v1',
'resnet50_v2', 'squeezenet1.0', 'squeezenet1.1', 'vgg11', 'vgg11_bn', 'vgg13',
'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn']
logging.info('It may take some time to run all models, '
'set --network to run a specific one')
else:
networks = [opt.model]

devs = [mx.gpu(int(i)) for i in opt.gpus.split(',')] if opt.gpus.strip() else [mx.cpu()]
num_gpus = len(devs)

for network in networks:
logging.info('network: %s', network)
logging.info('device: %s', devs)
if runtype == 'inference' or runtype == 'all':
if bs != 0:
fwd_time = score(network, bs, devs)
fps = (bs * num_batches)/fwd_time
logging.info(network + ' inference perf for BS %d is %f img/s', bs, fps)
else:
logging.info('run batchsize [1, 2, 4, 8, 16, 32] by default, '
'set --batch-size to run a specific one')
for batch_size in batch_inf:
fwd_time = score(network, batch_size, devs)
fps = (batch_size * num_batches) / fwd_time
logging.info(network + ' inference perf for BS %d is %f img/s', batch_size, fps)
if runtype == 'training' or runtype == 'all':
if bs != 0:
bwd_time = train(network, bs, devs)
fps = (bs * num_batches) / bwd_time
logging.info(network + ' training perf for BS %d is %f img/s', bs, fps)
else:
logging.info('run batchsize [1, 2, 4, 8, 16, 32] by default, '
'set --batch-size to run a specific one')
for batch_size in batch_train:
bwd_time = train(network, batch_size, devs)
fps = (batch_size * num_batches) / bwd_time
logging.info(network + ' training perf for BS %d is %f img/s', batch_size, fps)
58 changes: 49 additions & 9 deletions example/image-classification/benchmark_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,49 @@
from common import find_mxnet
from common.util import get_gpus
import mxnet as mx
import mxnet.gluon.model_zoo.vision as models
from importlib import import_module
import logging
import argparse
import time
import numpy as np
logging.basicConfig(level=logging.DEBUG)

parser = argparse.ArgumentParser(description='SymbolAPI-based CNN inference performance benchmark')
parser.add_argument('--network', type=str, default='all',
choices=['all', 'alexnet', 'vgg-16', 'resnetv1-50', 'resnet-50',
'resnet-152', 'inception-bn', 'inception-v3',
'inception-v4', 'inception-resnet-v2', 'mobilenet',
'densenet121', 'squeezenet1.1'])
parser.add_argument('--batch-size', type=int, default=0,
help='Batch size to use for benchmarking. Example: 32, 64, 128.'
'By default, runs benchmark for batch sizes - 1, 32, 64, 128, 256')

opt = parser.parse_args()

def get_symbol(network, batch_size, dtype):
image_shape = (3,299,299) if network == 'inception-v3' else (3,224,224)
image_shape = (3,299,299) if network in ['inception-v3', 'inception-v4'] else (3,224,224)
num_layers = 0
if 'resnet' in network:
if network == 'inception-resnet-v2':
network = network
elif 'resnet' in network:
num_layers = int(network.split('-')[1])
network = network.split('-')[0]
if 'vgg' in network:
num_layers = int(network.split('-')[1])
network = 'vgg'
net = import_module('symbols.'+network)
sym = net.get_symbol(num_classes=1000,
image_shape=','.join([str(i) for i in image_shape]),
num_layers=num_layers,
dtype=dtype)
if network in ['densenet121', 'squeezenet1.1']:
sym = models.get_model(network)
sym.hybridize()
data = mx.sym.var('data')
sym = sym(data)
sym = mx.sym.SoftmaxOutput(sym, name='softmax')
else:
net = import_module('symbols.'+network)
sym = net.get_symbol(num_classes=1000,
image_shape=','.join([str(i) for i in image_shape]),
num_layers=num_layers,
dtype=dtype)
return (sym, [('data', (batch_size,)+image_shape)])

def score(network, dev, batch_size, num_batches, dtype):
Expand Down Expand Up @@ -69,14 +92,31 @@ def score(network, dev, batch_size, num_batches, dtype):
return num_batches*batch_size/(time.time() - tic)

if __name__ == '__main__':
networks = ['alexnet', 'vgg-16', 'inception-bn', 'inception-v3', 'resnetv1-50', 'resnet-50', 'resnet-152']
if opt.network == 'all':
networks = ['alexnet', 'vgg-16', 'resnetv1-50', 'resnet-50',
'resnet-152', 'inception-bn', 'inception-v3',
'inception-v4', 'inception-resnet-v2',
'mobilenet', 'densenet121', 'squeezenet1.1']
logging.info('It may take some time to run all models, '
'set --network to run a specific one')
else:
networks = [opt.network]
devs = [mx.gpu(0)] if len(get_gpus()) > 0 else []
# Enable USE_MKLDNN for better CPU performance
devs.append(mx.cpu())

batch_sizes = [1, 2, 4, 8, 16, 32]
if opt.batch_size == 0:
batch_sizes = [1, 32, 64, 128, 256]
logging.info('run batchsize [1, 32, 64, 128, 256] by default, '
'set --batch-size to run a specific one')
else:
batch_sizes = [opt.batch_size]

for net in networks:
logging.info('network: %s', net)
if net in ['densenet121', 'squeezenet1.1']:
logging.info('network: %s is converted from gluon modelzoo', net)
logging.info('you can run benchmark/python/gluon/benchmark_gluon.py for more models')
for d in devs:
logging.info('device: %s', d)
logged_fp16_warning = False
Expand Down

0 comments on commit e6d08c9

Please sign in to comment.