diff --git a/docs/source/index.rst b/docs/source/index.rst
index 01b29ca1..6b26ecc6 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -54,6 +54,7 @@ Dana-Farber Cancer Institute.
    examples/link_stain_normalization
    examples/link_nucleus_detection
    examples/link_preprocessing_pipeline
+   examples/link_train_hovernet

 .. toctree::
    :maxdepth: 2
diff --git a/docs/source/models.rst b/docs/source/models.rst
index 80c8603d..3e1bffbd 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -3,20 +3,28 @@ Models

 ``PathML`` comes with several model architectures ready to use out of the box.

-.. table::
-    :widths: 20, 20, 60
+.. list-table:: Models included in PathML
+    :widths: 15, 70, 15
+    :header-rows: 1

-    ===================================== ============ =============
-    Model                                 Reference    Description
-    ===================================== ============ =============
-    U-net (in progress)                   [Unet]_      A model for segmentation in biomedical images
-    :class:`~pathml.ml.hovernet.HoVerNet` [HoVerNet]_  A model for nucleus segmentation and classification in H&E images
-    ===================================== ============ =============
+    * - Model
+      - Description
+      - Reference
+    * - :class:`~pathml.ml.unet.UNet`
+      - A general-purpose model for segmentation in biomedical images.
+        The architecture consists of 4 encoder blocks followed by 4 decoder blocks.
+        Skip connections propagate information from each layer of the encoder to the corresponding layer in
+        the decoder.
+      - [Unet]_
+    * - :class:`~pathml.ml.hovernet.HoVerNet`
+      - A model for simultaneous nucleus segmentation and classification in H&E images.
+        The architecture consists of a single encoder with three separate decoder branches: one to perform binary
+        classification of nuclear pixels (NP), one to compute horizontal and vertical nucleus maps (HV), and one,
+        used in the classification setting, to classify the type of each nuclear pixel (NC).
+      - [HoVerNet]_

 You can also use models from `torchvision.models <https://pytorch.org/vision/stable/models.html>`_, or create your own!

-In many cases, model parameters (weights) for pretrained networks may be available for use through the Model Repository.
-
 References
 ----------
diff --git a/environment.yml b/environment.yml
index d86bca7a..4a8384df 100644
--- a/environment.yml
+++ b/environment.yml
@@ -13,6 +13,7 @@ dependencies:
   - python-spams>=2.6.1
   - pip>=20.0.2
   - pytorch>=1.7.0
+  - torchvision>=0.8.0
   - openjdk==8.0.152
   - h5py>=3.1.0
   - pip:
diff --git a/pathml/ml/unet.py b/pathml/ml/unet.py
new file mode 100644
index 00000000..58de7229
--- /dev/null
+++ b/pathml/ml/unet.py
@@ -0,0 +1,136 @@
+import torch
+from torch import nn
+from torchvision.transforms import CenterCrop
+from torch.nn.functional import interpolate
+
+
+class _UNetConvBlock(nn.Module):
+    """
+    Convolution block for U-Net
+
+    From the paper:
+    The contracting path follows the typical architecture of a convolutional network.
+    It consists of the repeated application of two 3x3 convolutions (unpadded convolutions), each followed
+    by a rectified linear unit (ReLU)...
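+
+    Note: because the two 3x3 convolutions are unpadded, each block reduces the
+    spatial dimensions of its input by a total of 4px.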
+    """
+    def __init__(self, in_c, out_c):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels = in_c, out_channels = out_c, kernel_size = 3)
+        self.relu = nn.ReLU()
+        self.conv2 = nn.Conv2d(in_channels = out_c, out_channels = out_c, kernel_size = 3)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        return x
+
+
+class _UNetUpConvBlock(nn.Module):
+    """
+    Up-Convolution block for U-Net
+
+    From the paper:
+    Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution
+    (“up-convolution”) that halves the number of feature channels, a concatenation with the correspondingly
+    cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU.
+    """
+    def __init__(self, in_c, out_c):
+        super().__init__()
+        self.up = nn.ConvTranspose2d(in_channels = in_c, out_channels = out_c, kernel_size = 2, stride = 2)
+        self.conv = _UNetConvBlock(in_c = in_c, out_c = out_c)
+
+    def forward(self, x, x_skip):
+        """
+        x is the input
+        x_skip is the skip connection from the corresponding encoder block
+        """
+        x = self.up(x)
+        # crop tensor from skip connection to match H and W of x
+        x_skip = CenterCrop((x.shape[2], x.shape[3]))(x_skip)
+        x = torch.cat([x, x_skip], dim = 1)
+        x = self.conv(x)
+        return x
+
+
+class UNet(nn.Module):
+    """
+    U-Net is a convolutional network for biomedical image segmentation.
+    The architecture consists of a contracting path to capture context and a symmetric expanding
+    path that enables precise localization.
+
+    As described in the original paper, no padding is used by default, so the dimensions get smaller at each layer.
+    An input of size 572px will lead to an output of size 388px (see Fig. 1 in the paper).
+    The ``keep_dim`` parameter can be used to force the output to have the same shape as the input.
+
+    Code is based on:
+    https://amaarora.github.io/2020/09/13/unet.html
+    https://github.com/LeeJunHyun/Image_Segmentation
+
+    Args:
+        in_channels (int): Number of channels in input. E.g. 3 for an RGB image.
+        out_channels (int): Number of channels in output. E.g. 1 for a binary classification setting.
+        keep_dim (bool): Whether to enforce output to match the dimensions of input. If ``True``, a final interpolation
+            step will be applied. Defaults to ``False``.
+
+    References:
+        Ronneberger, O., Fischer, P. and Brox, T., 2015, October.
+        U-net: Convolutional networks for biomedical image segmentation.
+        In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241).
+        Springer, Cham.
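+
+    Example (illustrative usage; shapes follow Fig. 1 of the paper):
+        >>> import torch
+        >>> net = UNet(in_channels = 3, out_channels = 1)
+        >>> x = torch.randn(1, 3, 572, 572)
+        >>> net(x).shape
+        torch.Size([1, 1, 388, 388])
+        >>> UNet(keep_dim = True)(x).shape
+        torch.Size([1, 1, 572, 572])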
+    """
+
+    def __init__(self, in_channels=3, out_channels=1, keep_dim=False):
+        super().__init__()
+        self.keep_dim = keep_dim
+        self.pool = nn.MaxPool2d(2)
+
+        self.conv1 = _UNetConvBlock(in_c = in_channels, out_c = 64)
+        self.conv2 = _UNetConvBlock(in_c = 64, out_c = 128)
+        self.conv3 = _UNetConvBlock(in_c = 128, out_c = 256)
+        self.conv4 = _UNetConvBlock(in_c = 256, out_c = 512)
+        self.conv5 = _UNetConvBlock(in_c = 512, out_c = 1024)
+
+        self.upconv1 = _UNetUpConvBlock(in_c = 1024, out_c = 512)
+        self.upconv2 = _UNetUpConvBlock(in_c = 512, out_c = 256)
+        self.upconv3 = _UNetUpConvBlock(in_c = 256, out_c = 128)
+        self.upconv4 = _UNetUpConvBlock(in_c = 128, out_c = 64)
+
+        self.head = nn.Conv2d(in_channels = 64, out_channels = out_channels, kernel_size = 1)
+
+    def forward(self, x):
+        # encoder
+        x1 = self.conv1(x)
+        x2 = self.pool(x1)
+        x2 = self.conv2(x2)
+        x3 = self.pool(x2)
+        x3 = self.conv3(x3)
+        x4 = self.pool(x3)
+        x4 = self.conv4(x4)
+        x5 = self.pool(x4)
+        x5 = self.conv5(x5)
+
+        # decoder
+        up1 = self.upconv1(x = x5, x_skip = x4)
+        up2 = self.upconv2(x = up1, x_skip = x3)
+        up3 = self.upconv3(x = up2, x_skip = x2)
+        up4 = self.upconv4(x = up3, x_skip = x1)
+        out = self.head(up4)
+
+        if self.keep_dim:
+            out = interpolate(out, size = (x.shape[2], x.shape[3]))
+
+        return out
diff --git a/tests/ml_tests/test_unet.py b/tests/ml_tests/test_unet.py
new file mode 100644
index 00000000..c435249f
--- /dev/null
+++ b/tests/ml_tests/test_unet.py
@@ -0,0 +1,35 @@
+import pytest
+import torch
+
+from pathml.ml.unet import UNet
+
+
+@pytest.mark.parametrize("keepdim", [True, False])
+@pytest.mark.parametrize("out_c", [1, 3])
+def test_unet_shapes(out_c, keepdim):
+    batch_size = 1
+    channels_in = 3
+    im_size_in = 572
+
+    x = torch.randn(batch_size, channels_in, im_size_in, im_size_in)
+
+    net = UNet(out_channels = out_c, keep_dim = keepdim)
+    out = net(x)
+
+    if keepdim:
+        assert out.shape == (batch_size, out_c, im_size_in, im_size_in)
+    else:
+
+        # compute expected output size when keep_dim is False
+        nlayers = 4
+        im_size_out = im_size_in
+        # at each layer in the encoder, two conv layers without padding lose 4px total, followed by a downsample by 2
+        for _ in range(nlayers):
+            im_size_out = (im_size_out - 4) // 2
+        # two more conv layers at the bottom layer
+        im_size_out = im_size_out - 4
+        # at each layer in the decoder, upsample by 2 followed by two conv layers without padding for a loss of 4px
+        for _ in range(nlayers):
+            im_size_out = (im_size_out * 2) - 4
+
+        assert out.shape == (batch_size, out_c, im_size_out, im_size_out)