Unet #95

Open · wants to merge 4 commits into base: dev
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -54,6 +54,7 @@ Dana-Farber Cancer Institute.
examples/link_stain_normalization
examples/link_nucleus_detection
examples/link_preprocessing_pipeline
examples/link_train_hovernet

.. toctree::
:maxdepth: 2
28 changes: 18 additions & 10 deletions docs/source/models.rst
@@ -3,20 +3,28 @@ Models

``PathML`` comes with several model architectures ready to use out of the box.

.. table::
    :widths: 20, 20, 60

    ===================================== ============ =============
    Model                                 Reference    Description
    ===================================== ============ =============
    U-net (in progress)                   [Unet]_      A model for segmentation in biomedical images
    :class:`~pathml.ml.hovernet.HoVerNet` [HoVerNet]_  A model for nucleus segmentation and classification in H&E images
    ===================================== ============ =============

.. list-table:: Models included in PathML
   :widths: 15, 70, 15
   :header-rows: 1

   * - Model
     - Description
     - Reference
   * - :class:`~pathml.ml.unet.UNet`
     - A standard general-purpose model designed for segmentation in biomedical images.
       The architecture consists of 4 encoder blocks followed by 4 decoder blocks.
       Skip connections propagate information from each encoder layer to the decoder layer
       of the same dimension.
     - [Unet]_
   * - :class:`~pathml.ml.hovernet.HoVerNet`
     - A model for simultaneous nucleus segmentation and classification in H&E images.
       The architecture consists of a single encoder with three decoder branches: binary
       classification of nuclear pixels (NP), horizontal and vertical nucleus maps (HV),
       and, in the classification setting, classification of nuclear pixel types (NC).
     - [HoVerNet]_

You can also use models from `torchvision.models <https://pytorch.org/docs/stable/torchvision/models.html>`_, or create your own!

In many cases, model parameters (weights) for pretrained networks may be available for use through the Model Repository.
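As a rough usage sketch (a hypothetical snippet, not part of this PR: the ``UNet`` signature is taken from ``pathml/ml/unet.py`` below, and ``resnet18`` is just an arbitrary torchvision example, using the ``torchvision>=0.8`` API pinned in ``environment.yml``):

.. code-block:: python

    import torch
    from torchvision import models
    from pathml.ml.unet import UNet

    unet = UNet(in_channels = 3, out_channels = 1, keep_dim = True)  # model added in this PR
    resnet = models.resnet18(pretrained = True)  # any torchvision model can be swapped in

    x = torch.randn(1, 3, 572, 572)
    mask = unet(x)  # shape (1, 1, 572, 572), since keep_dim=True interpolates back to input size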

References
----------

1 change: 1 addition & 0 deletions environment.yml
@@ -13,6 +13,7 @@ dependencies:
- python-spams>=2.6.1
- pip>=20.0.2
- pytorch>=1.7.0
- torchvision>=0.8.0
- openjdk==8.0.152
- h5py>=3.1.0
- pip:
124 changes: 124 additions & 0 deletions pathml/ml/unet.py
@@ -0,0 +1,124 @@
import torch
from torch import nn
from torchvision.transforms import CenterCrop
from torch.nn.functional import interpolate


class _UNetConvBlock(nn.Module):
"""
Convolution block for U-Net

From the paper:
The contracting path follows the typical architecture of a convolutional network.
It consists of the repeated application of two 3x3 convolutions (unpadded convolutions), each followed
by a rectified linear unit (ReLU)...
"""
def __init__(self, in_c, out_c):
super().__init__()
self.conv1 = nn.Conv2d(in_channels = in_c, out_channels = out_c, kernel_size = 3)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels = out_c, out_channels = out_c, kernel_size = 3)

def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
return x


class _UNetUpConvBlock(nn.Module):
"""
Up-Convolution block for U-Net

From the paper:
Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution
(“up-convolution”) that halves the number of feature channels, a concatenation with the correspondingly
cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU.
"""
def __init__(self, in_c, out_c):
super().__init__()
self.up = nn.ConvTranspose2d(in_channels = in_c, out_channels = out_c, kernel_size = 2, stride = 2)
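        # after upsampling (out_c channels) and concatenating the skip connection (also out_c channels),
        # the conv block receives in_c = 2 * out_c input channels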
self.conv = _UNetConvBlock(in_c = in_c, out_c = out_c)

def forward(self, x, x_skip):
"""
x is the input
x_skip is the skip connection from the encoder block
"""
x = self.up(x)
# crop tensor from skip connection to match H and W of x
x_skip = CenterCrop((x.shape[2], x.shape[3]))(x_skip)
x = torch.cat([x, x_skip], dim = 1)
x = self.conv(x)
return x


class UNet(nn.Module):
"""
U-Net is a convolutional network for biomedical image segmentation.
The architecture consists of a contracting path to capture context and a symmetric expanding
path that enables precise localization.

    As described in the original paper, no padding is used by default, so spatial dimensions shrink at each layer:
    an input of size 572px leads to an output of size 388px (see Fig. 1 in the paper).
    The ``keep_dim`` parameter can be used to enforce that the output has the same shape as the input.

Code is based on:
https://amaarora.github.io/2020/09/13/unet.html
    https://github.com/LeeJunHyun/Image_Segmentation

Args:
in_channels (int): Number of channels in input. E.g. 3 for RGB image
out_channels (int): Number of channels in output. E.g. 1 for a binary classification setting.
keep_dim (bool): Whether to enforce output to match the dimensions of input. If ``True``, a final interpolation
step will be applied. Defaults to ``False``.

References:
Ronneberger, O., Fischer, P. and Brox, T., 2015, October.
U-net: Convolutional networks for biomedical image segmentation.
In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241).
Springer, Cham.
"""

def __init__(self, in_channels=3, out_channels=1, keep_dim=False):
super().__init__()
self.keep_dim = keep_dim
self.pool = nn.MaxPool2d(2)

self.conv1 = _UNetConvBlock(in_c = in_channels, out_c = 64)
self.conv2 = _UNetConvBlock(in_c = 64, out_c = 128)
self.conv3 = _UNetConvBlock(in_c = 128, out_c = 256)
self.conv4 = _UNetConvBlock(in_c = 256, out_c = 512)
self.conv5 = _UNetConvBlock(in_c = 512, out_c = 1024)

self.upconv1 = _UNetUpConvBlock(in_c = 1024, out_c = 512)
self.upconv2 = _UNetUpConvBlock(in_c = 512, out_c = 256)
self.upconv3 = _UNetUpConvBlock(in_c = 256, out_c = 128)
self.upconv4 = _UNetUpConvBlock(in_c = 128, out_c = 64)

self.head = nn.Conv2d(in_channels = 64, out_channels = out_channels, kernel_size = 1)

def forward(self, x):
# encoder
x1 = self.conv1(x)
x2 = self.pool(x1)
x2 = self.conv2(x2)
x3 = self.pool(x2)
x3 = self.conv3(x3)
x4 = self.pool(x3)
x4 = self.conv4(x4)
x5 = self.pool(x4)
x5 = self.conv5(x5)

# decoder
up1 = self.upconv1(x = x5, x_skip = x4)
up2 = self.upconv2(x = up1, x_skip = x3)
up3 = self.upconv3(x = up2, x_skip = x2)
up4 = self.upconv4(x = up3, x_skip = x1)
out = self.head(up4)

if self.keep_dim:
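            # interpolate (default mode is 'nearest') back to the input's spatial dimensions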
out = interpolate(out, size = (x.shape[2], x.shape[3]))

return out
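A quick shape check, as a hypothetical sketch (not part of the module; the 572px -> 388px behavior follows Fig. 1 of the paper, per the docstring above):

    import torch
    from pathml.ml.unet import UNet

    x = torch.randn(1, 3, 572, 572)
    assert UNet()(x).shape == (1, 1, 388, 388)  # unpadded convolutions shrink the output
    assert UNet(keep_dim = True)(x).shape == (1, 1, 572, 572)  # interpolated back to input size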
35 changes: 35 additions & 0 deletions tests/ml_tests/test_unet.py
@@ -0,0 +1,35 @@
import pytest
import torch

from pathml.ml.unet import UNet


@pytest.mark.parametrize("keepdim", [True, False])
@pytest.mark.parametrize("out_c", [1, 3])
def test_unet_shapes(out_c, keepdim):
batch_size = 1
channels_in = 3
im_size_in = 572

x = torch.randn(batch_size, channels_in, im_size_in, im_size_in)

net = UNet(out_channels = out_c, keep_dim = keepdim)
out = net(x)

if keepdim:
assert out.shape == (batch_size, out_c, im_size_in, im_size_in)
    else:
        # compute the expected output size when keep_dim is False
nlayers = 4
im_size_out = im_size_in
        # at each encoder layer, two unpadded conv layers lose 4px total, then downsampling halves the size
        for _ in range(nlayers):
            im_size_out = (im_size_out - 4) // 2
# two more conv layers at the bottom layer
im_size_out = im_size_out - 4
# at each layer in decoder, upsample by 2 followed by two conv layers without padding for a loss of 4
for _ in range(nlayers):
im_size_out = (im_size_out * 2) - 4
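        # e.g. for im_size_in = 572: encoder gives 284, 140, 68, 32; bottom gives 28; decoder gives 52, 100, 196, 388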

assert out.shape == (batch_size, out_c, im_size_out, im_size_out)