Skip to content

Commit

Permalink
adding imagecodecs dependency and reactivating 3D drawing and adding …
Browse files Browse the repository at this point in the history
…other gui features (MouseLand#493)
  • Loading branch information
carsen-stringer committed Apr 30, 2022
1 parent 873641e commit 8d5f378
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 41 deletions.
104 changes: 70 additions & 34 deletions cellpose/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def __init__(self, image=None):

self.setStyleSheet("QMainWindow {background: 'black';}")
self.stylePressed = ("QPushButton {Text-align: left; "
"background-color: rgb(100,50,100); "
"background-color: rgb(150,50,150); "
"border-color: white;"
"color:white;}")
self.styleUnpressed = ("QPushButton {Text-align: left; "
Expand Down Expand Up @@ -411,7 +411,7 @@ def make_buttons(self):
self.scale_on = True
self.ScaleOn = QCheckBox('scale disk on')
self.ScaleOn.setFont(self.medfont)
self.ScaleOn.setStyleSheet('color: red;')
self.ScaleOn.setStyleSheet('color: rgb(150,50,150);')
self.ScaleOn.setChecked(True)
self.ScaleOn.setToolTip('see current diameter as red disk at bottom')
self.ScaleOn.toggled.connect(self.toggle_scale)
Expand Down Expand Up @@ -468,7 +468,7 @@ def make_buttons(self):
self.flow_threshold = QLineEdit()
self.flow_threshold.setText('0.4')
self.flow_threshold.returnPressed.connect(self.compute_cprob)
self.flow_threshold.setFixedWidth(60)
self.flow_threshold.setFixedWidth(70)
self.l0.addWidget(self.flow_threshold, b,5,1,4)

b+=1
Expand All @@ -480,9 +480,22 @@ def make_buttons(self):
self.cellprob_threshold = QLineEdit()
self.cellprob_threshold.setText('0.0')
self.cellprob_threshold.returnPressed.connect(self.compute_cprob)
self.cellprob_threshold.setFixedWidth(60)
self.cellprob_threshold.setFixedWidth(70)
self.l0.addWidget(self.cellprob_threshold, b,5,1,4)

b+=1
label = QLabel('stitch_threshold:')
label.setToolTip('for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)')
label.setStyleSheet(label_style)
label.setFont(self.medfont)
self.l0.addWidget(label, b, 0,1,5)
self.stitch_threshold = QLineEdit()
self.stitch_threshold.setText('0.0')
#self.cellprob_threshold.returnPressed.connect(self.compute_cprob)
self.stitch_threshold.setFixedWidth(70)
self.l0.addWidget(self.stitch_threshold, b,5,1,4)


b+=1
self.GB = QGroupBox('model zoo')
self.GB.setStyleSheet("QGroupBox { border: 1px solid white; color:white; padding: 10px 0px;}")
Expand Down Expand Up @@ -559,6 +572,13 @@ def make_buttons(self):
self.progress.setStyleSheet('color: gray;')
self.l0.addWidget(self.progress, b,0,1,9)

b+=1
self.roi_count = QLabel('0 ROIs')
self.roi_count.setStyleSheet('color: white;')
self.roi_count.setFont(self.boldfont)
self.roi_count.setAlignment(QtCore.Qt.AlignRight)
self.l0.addWidget(self.roi_count, b,0,1,9)

b+=1
line = QHLine()
line.setStyleSheet('color: white;')
Expand Down Expand Up @@ -674,10 +694,12 @@ def keyPressEvent(self, event):
if self.NZ>1:
if event.key() == QtCore.Qt.Key_Left:
self.currentZ = max(0,self.currentZ-1)
self.zpos.setText(str(self.currentZ))
self.scroll.setValue(self.currentZ)
updated = True
elif event.key() == QtCore.Qt.Key_Right:
self.currentZ = min(self.NZ-1, self.currentZ+1)
self.zpos.setText(str(self.currentZ))
self.scroll.setValue(self.currentZ)
updated = True
else:
if event.key() == QtCore.Qt.Key_X:
self.MCheckBox.toggle()
Expand Down Expand Up @@ -894,6 +916,7 @@ def move_in_Z(self):
self.currentZ = min(self.NZ, max(0, int(self.scroll.value())))
self.zpos.setText(str(self.currentZ))
self.update_plot()
self.draw_layer()
self.update_layer()


Expand Down Expand Up @@ -1037,7 +1060,6 @@ def reset(self):
self.flows = [[],[],[],[],[[]]]
self.stack = np.zeros((1,self.Ly,self.Lx,3))
# masks matrix
self.layers = 0*np.ones((1,self.Ly,self.Lx,4), np.uint8)
self.layerz = 0*np.ones((self.Ly,self.Lx,4), np.uint8)
# image matrix with a scale disk
self.radii = 0*np.ones((self.Ly,self.Lx,4), np.uint8)
Expand Down Expand Up @@ -1066,8 +1088,6 @@ def autosave_on(self):
def clear_all(self):
self.prev_selected = 0
self.selected = 0
#self.layers_undo, self.cellpix_undo, self.outpix_undo = [],[],[]
self.layers = 0*np.ones((self.NZ,self.Ly,self.Lx,4), np.uint8)
self.layerz = 0*np.ones((self.Ly,self.Lx,4), np.uint8)
self.cellpix = np.zeros((self.NZ,self.Ly,self.Lx), np.uint32)
self.outpix = np.zeros((self.NZ,self.Ly,self.Lx), np.uint32)
Expand All @@ -1082,8 +1102,6 @@ def select_cell(self, idx):
if self.selected > 0:
z = self.currentZ
self.layerz[self.cellpix[z]==idx] = np.array([255,255,255,self.opacity])
#if self.outlinesOn:
# self.layers[self.outpix==idx] = np.array(self.outcolor)
self.update_layer()

def unselect_cell(self):
Expand Down Expand Up @@ -1191,13 +1209,13 @@ def undo_remove_cell(self):
self.redo.setEnabled(False)


def remove_stroke(self, delete_points=True):
def remove_stroke(self, delete_points=True, stroke_ind=-1):
#self.current_stroke = get_unique_points(self.current_stroke)
stroke = np.array(self.strokes[-1])
inZ = stroke[0,0]==self.currentZ
stroke = np.array(self.strokes[stroke_ind])
cZ = self.currentZ
inZ = stroke[0,0]==cZ
if inZ:
outpix = self.outpix[self.currentZ, stroke[:,1],stroke[:,2]]>0
outpix = self.outpix[cZ, stroke[:,1],stroke[:,2]]>0
self.layerz[stroke[~outpix,1],stroke[~outpix,2]] = np.array([0,0,0,0])
cellpix = self.cellpix[cZ, stroke[:,1], stroke[:,2]]
ccol = self.cellcolors.copy()
Expand All @@ -1213,8 +1231,9 @@ def remove_stroke(self, delete_points=True):
self.layerz[stroke[outpix,1],stroke[outpix,2]] = np.array(self.outcolor)
if delete_points:
self.current_point_set = self.current_point_set[:-1*(stroke[:,-1]==1).sum()]
del self.strokes[-1]
self.update_layer()

del self.strokes[stroke_ind]

def plot_clicked(self, event):
if event.button()==QtCore.Qt.LeftButton and (event.modifiers() != QtCore.Qt.ShiftModifier and
Expand Down Expand Up @@ -1312,11 +1331,15 @@ def update_plot(self):

def update_layer(self):
if self.masksOn or self.outlinesOn:
self.draw_layer()
#self.draw_layer()
self.layer.setImage(self.layerz, autoLevels=False)
self.update_roi_count()
self.win.show()
self.show()

def update_roi_count(self):
self.roi_count.setText(f'{self.ncells} ROIs')

def update_ortho(self):
if self.NZ>1 and self.orthobtn.isChecked():
dzcurrent = self.dz
Expand Down Expand Up @@ -1486,14 +1509,17 @@ def draw_mask(self, z, ar, ac, vr, vc, color, idx=None):
def compute_scale(self):
self.diameter = float(self.Diameter.text())
self.pr = int(float(self.Diameter.text()))
radii = np.zeros((self.Ly+self.pr,self.Lx), np.uint8)
self.radii = np.zeros((self.Ly+self.pr,self.Lx,4), np.uint8)
yy,xx = disk([self.Ly+self.pr/2-1, self.pr/2+1],
self.pr/2, self.Ly+self.pr, self.Lx)
self.radii[yy,xx,0] = 255
self.radii[yy,xx,-1] = 255#self.opacity * (radii>0)
self.radii_padding = int(self.pr*1.25)
self.radii = np.zeros((self.Ly+self.radii_padding,self.Lx,4), np.uint8)
yy,xx = disk([self.Ly+self.radii_padding/2-1, self.pr/2+1],
self.pr/2, self.Ly+self.radii_padding, self.Lx)
# rgb(150,50,150)
self.radii[yy,xx,0] = 150
self.radii[yy,xx,1] = 50
self.radii[yy,xx,2] = 150
self.radii[yy,xx,3] = 255
self.update_plot()
self.p0.setYRange(0,self.Ly+self.pr)
self.p0.setYRange(0,self.Ly+self.radii_padding)
self.p0.setXRange(0,self.Lx)
self.win.show()
self.show()
Expand All @@ -1506,10 +1532,18 @@ def draw_masks(self):

def draw_layer(self):
if self.masksOn:
self.layerz = np.zeros((self.Ly,self.Lx,4), np.uint8)
self.layerz[...,:3] = self.cellcolors[self.cellpix[self.currentZ],:]
self.layerz[...,3] = self.opacity * (self.cellpix[self.currentZ]>0).astype(np.uint8)
if self.selected>0:
self.layerz[self.cellpix[self.currentZ]==self.selected] = np.array([255,255,255,self.opacity])
cZ = self.currentZ
stroke_z = np.array([s[0][0] for s in self.strokes])
inZ = np.nonzero(stroke_z == cZ)[0]
if len(inZ) > 0:
for i in inZ:
stroke = np.array(self.strokes[i])
self.layerz[stroke[:,1], stroke[:,2]] = np.array([255,0,255,100])
else:
self.layerz[...,3] = 0

Expand Down Expand Up @@ -1563,11 +1597,6 @@ def initialize_model(self, model_name=None):

def add_model(self):
io._add_model(self)
#a_list = ["abc", "def", "ghi"]
#textfile = open("a_file.txt", "w")
#for element in a_list:
# textfile.write(element + "\n")
#textfile.close()
return

def remove_model(self):
Expand Down Expand Up @@ -1685,8 +1714,11 @@ def compute_model(self, model_name=None):
self.initialize_model(model_name)
self.progress.setValue(10)
do_3D = False
stitch_threshold = False
if self.NZ > 1:
do_3D = True
stitch_threshold = float(self.stitch_threshold.text())
stitch_threshold = 0 if stitch_threshold <= 0 or stitch_threshold > 1 else stitch_threshold
do_3D = True if stitch_threshold==0 else False
data = self.stack.copy()
else:
data = self.stack[0].copy()
Expand All @@ -1699,7 +1731,8 @@ def compute_model(self, model_name=None):
diameter=self.diameter,
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold,
do_3D=do_3D,
do_3D=do_3D,
stitch_threshold=stitch_threshold,
progress=self.progress)[:2]
except Exception as e:
print('NET ERROR: %s'%e)
Expand All @@ -1712,12 +1745,12 @@ def compute_model(self, model_name=None):
# flows = flows[0]
self.flows[0] = flows[0].copy() #RGB flow
self.flows[1] = (np.clip(normalize99(flows[2].copy()), 0, 1) * 255).astype(np.uint8) #dist/prob
if not do_3D:
if not do_3D and not stitch_threshold > 0:
masks = masks[np.newaxis,...]
self.flows[0] = resize_image(self.flows[0], masks.shape[-2], masks.shape[-1],
interpolation=cv2.INTER_NEAREST)
self.flows[1] = resize_image(self.flows[1], masks.shape[-2], masks.shape[-1])
if not do_3D:
if not do_3D and not stitch_threshold > 0:
self.flows[2] = np.zeros(masks.shape[1:], dtype=np.uint8)
self.flows = [self.flows[n][np.newaxis,...] for n in range(len(self.flows))]
else:
Expand All @@ -1737,8 +1770,10 @@ def compute_model(self, model_name=None):
io._masks_to_gui(self, masks, outlines=None)
self.progress.setValue(100)

if not do_3D:
if not do_3D and not stitch_threshold > 0:
self.recompute_masks = True
else:
self.recompute_masks = False
except Exception as e:
print('ERROR: %s'%e)

Expand All @@ -1750,6 +1785,7 @@ def enable_buttons(self):
self.StyleButtons[i].setEnabled(True)
self.StyleButtons[i].setStyleSheet(self.styleUnpressed)
self.SizeButton.setEnabled(True)
self.SCheckBox.setEnabled(True)
self.SizeButton.setStyleSheet(self.styleUnpressed)
self.newmodel.setEnabled(True)
self.loadMasks.setEnabled(True)
Expand Down
7 changes: 5 additions & 2 deletions cellpose/gui/guiparts.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def __init__(self, parent, model_name, text):
self.model_name = model_name

def press(self, parent):
for i in range(len(parent.StyleButtons)):
parent.StyleButtons[i].setStyleSheet(parent.styleUnpressed)
self.setStyleSheet(parent.stylePressed)
parent.compute_model(self.model_name)

class TrainWindow(QDialog):
Expand Down Expand Up @@ -636,7 +639,7 @@ def tabletEvent(self, ev):
#print(ev.pressure())

def drawAt(self, pos, ev=None):
mask = self.greenmask
mask = self.strokemask
set = self.parent.current_point_set
stroke = self.parent.current_stroke
pos = [int(pos.y()), int(pos.x())]
Expand Down Expand Up @@ -691,7 +694,7 @@ def setDrawKernel(self, kernel_size=3):
offmask = np.zeros((bs,bs,1))
opamask = 100 * kernel[:,:,np.newaxis]
self.redmask = np.concatenate((onmask,offmask,offmask,onmask), axis=-1)
self.greenmask = np.concatenate((onmask,offmask,onmask,opamask), axis=-1)
self.strokemask = np.concatenate((onmask,offmask,onmask,opamask), axis=-1)


class RangeSlider(QSlider):
Expand Down
8 changes: 5 additions & 3 deletions docs/gui.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ in 2D should be single strokes (if *single_stroke* is checked).

If you want to draw masks in 3D, then you can turn *single_stroke*
option off and draw a stroke on each plane with the cell and then press
ENTER (cellpose 1.0 only currently).
3D labelling will fill in unlabelled z-planes so that you do not
have to as densely label.
ENTER.

.. note::
3D labelling will fill in unlabelled z-planes so that you do not
have to densely label, for example you can skip some planes.


Segmentation options
Expand Down
7 changes: 5 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from setuptools import setup

install_deps = ['numpy>=1.20.0', 'scipy', 'natsort',
'tifffile', 'tqdm', 'numba',
'tifffile', 'tqdm',
'numba>=0.53.0',
'llvmlite',
'torch>=1.6',
'opencv-python-headless',
'fastremap'
'fastremap',
'imagecodecs'
]

gui_deps = [
Expand Down

0 comments on commit 8d5f378

Please sign in to comment.