-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
112,989 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
#!/usr/bin/env python3 | ||
# encoding: utf-8 | ||
import re | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.utils.model_zoo as model_zoo | ||
from collections import OrderedDict | ||
|
||
__all__ = ['DenseNet', 'Densenet121_AG'] | ||
|
||
|
||
model_urls = { | ||
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', | ||
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', | ||
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', | ||
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', | ||
} | ||
|
||
def Densenet121_AG(pretrained=False, **kwargs): | ||
r"""Densenet-121 model from | ||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
""" | ||
model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), | ||
**kwargs) | ||
if pretrained: | ||
# '.'s are no longer allowed in module names, but pervious _DenseLayer | ||
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. | ||
# They are also in the checkpoints in model_urls. This pattern is used | ||
# to find such keys. | ||
pattern = re.compile( | ||
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') | ||
state_dict = model_zoo.load_url(model_urls['densenet121']) | ||
for key in list(state_dict.keys()): | ||
res = pattern.match(key) | ||
if res: | ||
new_key = res.group(1) + res.group(2) | ||
state_dict[new_key] = state_dict[key] | ||
del state_dict[key] | ||
model.load_state_dict(state_dict) | ||
return model | ||
|
||
|
||
class _DenseLayer(nn.Sequential): | ||
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): | ||
super(_DenseLayer, self).__init__() | ||
self.add_module('norm1', nn.BatchNorm2d(num_input_features)), | ||
self.add_module('relu1', nn.ReLU(inplace=True)), | ||
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * | ||
growth_rate, kernel_size=1, stride=1, bias=False)), | ||
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), | ||
self.add_module('relu2', nn.ReLU(inplace=True)), | ||
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, | ||
kernel_size=3, stride=1, padding=1, bias=False)), | ||
self.drop_rate = drop_rate | ||
|
||
def forward(self, x): | ||
new_features = super(_DenseLayer, self).forward(x) | ||
if self.drop_rate > 0: | ||
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) | ||
return torch.cat([x, new_features], 1) | ||
|
||
|
||
class _DenseBlock(nn.Sequential): | ||
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): | ||
super(_DenseBlock, self).__init__() | ||
for i in range(num_layers): | ||
layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) | ||
self.add_module('denselayer%d' % (i + 1), layer) | ||
|
||
|
||
class _Transition(nn.Sequential): | ||
def __init__(self, num_input_features, num_output_features): | ||
super(_Transition, self).__init__() | ||
self.add_module('norm', nn.BatchNorm2d(num_input_features)) | ||
self.add_module('relu', nn.ReLU(inplace=True)) | ||
self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, | ||
kernel_size=1, stride=1, bias=False)) | ||
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) | ||
|
||
|
||
class DenseNet(nn.Module): | ||
r"""Densenet-BC model class, based on | ||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ | ||
Args: | ||
growth_rate (int) - how many filters to add each layer (`k` in paper) | ||
block_config (list of 4 ints) - how many layers in each pooling block | ||
num_init_features (int) - the number of filters to learn in the first convolution layer | ||
bn_size (int) - multiplicative factor for number of bottle neck layers | ||
(i.e. bn_size * k features in the bottleneck layer) | ||
drop_rate (float) - dropout rate after each dense layer | ||
num_classes (int) - number of classification classes | ||
""" | ||
|
||
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), | ||
num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): | ||
|
||
super(DenseNet, self).__init__() | ||
|
||
# First convolution | ||
self.features = nn.Sequential(OrderedDict([ | ||
('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), | ||
('norm0', nn.BatchNorm2d(num_init_features)), | ||
('relu0', nn.ReLU(inplace=True)), | ||
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), | ||
])) | ||
|
||
# Each denseblock | ||
num_features = num_init_features | ||
for i, num_layers in enumerate(block_config): | ||
block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, | ||
bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) | ||
self.features.add_module('denseblock%d' % (i + 1), block) | ||
num_features = num_features + num_layers * growth_rate | ||
if i != len(block_config) - 1: | ||
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) | ||
self.features.add_module('transition%d' % (i + 1), trans) | ||
num_features = num_features // 2 | ||
|
||
# Final batch norm | ||
self.features.add_module('norm5', nn.BatchNorm2d(num_features)) | ||
|
||
# Linear layer | ||
self.classifier = nn.Linear(num_features, num_classes) | ||
|
||
self.Sigmoid = nn.Sigmoid() | ||
|
||
# Official init from torch repo. | ||
for m in self.modules(): | ||
if isinstance(m, nn.Conv2d): | ||
nn.init.kaiming_normal_(m.weight) | ||
elif isinstance(m, nn.BatchNorm2d): | ||
nn.init.constant_(m.weight, 1) | ||
nn.init.constant_(m.bias, 0) | ||
elif isinstance(m, nn.Linear): | ||
nn.init.constant_(m.bias, 0) | ||
|
||
def forward(self, x): | ||
features = self.features(x) | ||
out = F.relu(features, inplace=True) | ||
out_after_pooling = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) | ||
out = self.classifier(out_after_pooling) | ||
out = self.Sigmoid(out) | ||
return out, features, out_after_pooling | ||
|
||
|
||
class Fusion_Branch(nn.Module): | ||
def __init__(self, input_size, output_size): | ||
super(Fusion_Branch, self).__init__() | ||
self.fc = nn.Linear(input_size, output_size) | ||
self.Sigmoid = nn.Sigmoid() | ||
|
||
def forward(self, global_pool, local_pool): | ||
#fusion = torch.cat((global_pool.unsqueeze(2), local_pool.unsqueeze(2)), 2).cuda() | ||
#fusion = fusion.max(2)[0]#.squeeze(2).cuda() | ||
#print(fusion.shape) | ||
fusion = torch.cat((global_pool,local_pool), 1).cuda() | ||
fusion_var = torch.autograd.Variable(fusion) | ||
x = self.fc(fusion_var) | ||
x = self.Sigmoid(x) | ||
|
||
return x | ||
|
||
|
||
|
||
''' | ||
class DenseNet121(nn.Module): | ||
"""Model modified. | ||
The architecture of our model is the same as standard DenseNet121 | ||
except the classifier layer which has an additional sigmoid function. | ||
""" | ||
def __init__(self, out_size): | ||
super(DenseNet121, self).__init__() | ||
self.densenet121 = torchvision.models.densenet121(pretrained=True) | ||
num_ftrs = self.densenet121.classifier.in_features | ||
self.densenet121.classifier = nn.Sequential( | ||
nn.Linear(num_ftrs, out_size), | ||
nn.Sigmoid() | ||
) | ||
def forward(self, x): | ||
x = self.densenet121(x) | ||
return x | ||
''' | ||
|
||
#model = AG_CNN_densenet121(pretrained = True) | ||
#model.cuda() | ||
#input = torch.rand([1,3,224,224]).cuda() | ||
#output = model(input) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# encoding: utf-8 | ||
|
||
""" | ||
Read images and corresponding labels. | ||
""" | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
from PIL import Image | ||
import os | ||
|
||
|
||
class ChestXrayDataSet(Dataset): | ||
def __init__(self, data_dir, image_list_file, transform=None): | ||
""" | ||
Args: | ||
data_dir: path to image directory. | ||
image_list_file: path to the file containing images | ||
with corresponding labels. | ||
transform: optional transform to be applied on a sample. | ||
""" | ||
image_names = [] | ||
labels = [] | ||
with open(image_list_file, "r") as f: | ||
for line in f: | ||
items = line.split() | ||
image_name= items[0] | ||
label = items[1:] | ||
label = [int(i) for i in label] | ||
image_name = os.path.join(data_dir, image_name) | ||
image_names.append(image_name) | ||
labels.append(label) | ||
|
||
self.image_names = image_names | ||
self.labels = labels | ||
self.transform = transform | ||
|
||
def __getitem__(self, index): | ||
""" | ||
Args: | ||
index: the index of item | ||
Returns: | ||
image and its labels | ||
""" | ||
image_name = self.image_names[index] | ||
image = Image.open(image_name).convert('RGB') | ||
label = self.labels[index] | ||
if self.transform is not None: | ||
image = self.transform(image) | ||
return image, torch.FloatTensor(label) | ||
|
||
def __len__(self): | ||
return len(self.image_names) | ||
|
Oops, something went wrong.