Skip to content

Commit

Permalink
adding gpu device option (MouseLand#389)
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Apr 6, 2022
1 parent 36fa3ff commit d99b0f2
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
3 changes: 2 additions & 1 deletion cellpose/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def main():
# settings for CPU vs GPU
hardware_args = parser.add_argument_group("hardware arguments")
hardware_args.add_argument('--use_gpu', action='store_true', help='use gpu if torch with cuda installed')
hardware_args.add_argument('--gpu_device', required=False, default=0, type=int, help='which gpu device to use')
hardware_args.add_argument('--check_mkl', action='store_true', help='check if mkl working')

# settings for locating and formatting images
Expand Down Expand Up @@ -159,7 +160,7 @@ def main():
if not (args.train or args.train_size):
saving_something = args.save_png or args.save_tif or args.save_flows or args.save_ncolor or args.save_txt

device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu)
device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu, device=args.gpu_device)

if args.pretrained_model is None or args.pretrained_model == 'None' or args.pretrained_model == 'False' or args.pretrained_model == '0':
pretrained_model = False
Expand Down
8 changes: 3 additions & 5 deletions cellpose/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from torch.utils import mkldnn as mkldnn_utils
from . import resnet_torch
TORCH_ENABLED = True
torch_GPU = torch.device('cuda')
torch_CPU = torch.device('cpu')

core_logger = logging.getLogger(__name__)
tqdm_out = utils.TqdmToLogger(core_logger, level=logging.INFO)
Expand Down Expand Up @@ -64,13 +62,13 @@ def _use_gpu_torch(gpu_number=0):
core_logger.info('TORCH CUDA version not installed/working.')
return False

def assign_device(use_torch=True, gpu=False):
def assign_device(use_torch=True, gpu=False, device=0):
if gpu and use_gpu(use_torch=True):
device = torch_GPU
device = torch.device(f'cuda:{device}')
gpu=True
core_logger.info('>>>> using GPU')
else:
device = torch_CPU
device = torch.device('cpu')
core_logger.info('>>>> using CPU')
gpu=False
return device, gpu
Expand Down
2 changes: 2 additions & 0 deletions docs/command.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ You can run the help string and see all the options:

hardware arguments:
--use_gpu use gpu if torch with cuda installed
--gpu_device GPU_DEVICE
which gpu device to use. Default: 0
--check_mkl check if mkl working

input image arguments:
Expand Down

0 comments on commit d99b0f2

Please sign in to comment.