Skip to content

Instantly share code, notes, and snippets.

@psobolewskiPhD
Created December 12, 2022 16:41
Show Gist options
  • Save psobolewskiPhD/80d4b18ba78b49eec6a2398153347a83 to your computer and use it in GitHub Desktop.
Test script for the cellpose MPS branch (PyTorch Metal Performance Shaders backend).
# %%
# Set PYTORCH_ENABLE_MPS_FALLBACK=1 in the very first cell, before torch is
# imported (indirectly, via cellpose). NOTE(review): presumably this lets
# PyTorch run ops that the MPS backend does not implement on the CPU instead
# of raising, and it must be set before import to take effect — confirm
# against the PyTorch MPS docs.
%env PYTORCH_ENABLE_MPS_FALLBACK=1
# %%
# Imports and plotting configuration for the rest of the notebook.
import numpy as np
import time, os, sys
from urllib.parse import urlparse
import skimage.io
import matplotlib.pyplot as plt
import matplotlib as mpl
# Render figures inline in the notebook, at 300 dpi.
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300
# NOTE(review): duplicate of the urlparse import a few lines up; harmless.
from urllib.parse import urlparse
from cellpose import models, core
# %%
# Ask cellpose whether a usable GPU is available and report the answer.
use_GPU = core.use_gpu()
print('>>> GPU activated? {}'.format(use_GPU))
# Call logger_setup so cellpose writes its log output; the trailing
# semicolon stops the notebook from echoing the call's return value.
from cellpose.io import logger_setup
logger_setup();
# %%
from cellpose import utils
# Download the example images from the cellpose website. A file is only
# fetched when no file of the same name exists in the working directory.
urls = ['https://www.cellpose.org/static/images/img02.png',
        'https://www.cellpose.org/static/images/img03.png',
        'https://www.cellpose.org/static/images/img05.png',
        'https://www.cellpose.org/static/data/rgb_3D.tif']
files = []
for url in urls:
    parts = urlparse(url)
    # Save under the last path component of the URL, e.g. "img02.png".
    filename = os.path.basename(parts.path)
    if not os.path.exists(filename):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, filename))
        utils.download_url_to_file(url, filename)
    files.append(filename)
# REPLACE FILES WITH YOUR IMAGE PATHS
# files = ['img0.tif', 'img1.tif']
imgs = [skimage.io.imread(f) for f in files]
nimg = len(imgs)
# Preview every image except the last one (with the default downloads the
# last one is the rgb_3D.tif stack). Use one panel per previewed image,
# so the grid still fits when `files` is replaced with a different number
# of paths (the grid was previously hard-coded to 3 columns).
previews = imgs[:-1]
plt.figure(figsize=(8,4))
for k, img in enumerate(previews):
    plt.subplot(1, len(previews), k + 1)
    plt.imshow(img)
# %%
# SELECT THE 2D IMAGES TO SEGMENT (cellpose itself is run further below).
# Keep every loaded image except the last one; with the default downloads
# the last one is the rgb_3D.tif stack.
imgs_2D = imgs[:-1]
# Alternative input for a quick test: a scikit-image sample image.
#from skimage.data import coins, gravel
#imgs_2D = [gravel()]
# %%
# DEFINE CELLPOSE MODEL
# Choose the pretrained model: model_type='cyto' or model_type='nuclei'.
# NOTE(review): gpu=True is passed directly instead of the `use_GPU` flag
# returned by core.use_gpu(); presumably deliberate, to exercise the MPS/GPU
# code path this script tests — confirm.
model = models.Cellpose(model_type='cyto', gpu=True)
# %%
# Define CHANNELS to run segmentation on:
#   grayscale=0, R=1, G=2, B=3
#   channels = [cytoplasm, nucleus]
#   if the NUCLEUS channel does not exist, set the second channel to 0
# IF ALL YOUR IMAGES ARE THE SAME TYPE, you can give a single list with 2 elements:
#   channels = [0,0]  # IF YOU HAVE GRAYSCALE
#   channels = [2,3]  # IF YOU HAVE G=cytoplasm and B=nucleus
#   channels = [2,1]  # IF YOU HAVE G=cytoplasm and R=nucleus
# Otherwise give one [cytoplasm, nucleus] pair per image, in the same order
# as imgs_2D:
channels = [[2,3], [0,0], [0,0]]
#channels = [0,0]
# Cell size:
#   - No `diameter` argument is passed below, so Cellpose.eval's own default
#     is used. NOTE(review): in cellpose 2.x that default is 30 px and skips
#     size estimation — confirm for this branch.
#   - Pass diameter=None to have the cell size estimated per image.
#   - Alternatively, pass the average cell diameter in pixels (recommended),
#     either as one number for all images or as a list with one value per image.
# NOTE(review): flow_threshold=None presumably disables cellpose's flow-error
# filtering of masks — confirm against the installed version.
masks, flows, styles, diams = model.eval(imgs_2D, flow_threshold=None, channels=channels)
# %%
from cellpose import plot
# Plot the segmentation result for every 2D image.
# `channels` is either a single [cytoplasm, nucleus] pair shared by all
# images, or a list holding one pair per image. show_segmentation must get
# the pair for the image being drawn, not the whole per-image list.
per_image_channels = isinstance(channels[0], (list, tuple, np.ndarray))
nimg = len(imgs_2D)
for idx in range(nimg):
    maski = masks[idx]
    flowi = flows[idx][0]
    chan = channels[idx] if per_image_channels else channels
    fig = plt.figure(figsize=(12,5))
    plot.show_segmentation(fig, imgs_2D[idx], maski, flowi, channels=chan)
    plt.tight_layout()
    plt.show()
# %%