forked from BalintHompot/uncertainty
Commit
Added MCDropout_AuxOut implementation
Jeroen Vranken committed on Feb 23, 2021 (1 parent: 17e2bbd, commit: 308d04d)
Showing 5 changed files with 371 additions and 3 deletions.

@@ -0,0 +1,181 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import sys
import datetime
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchnet as tnt
from torchnet.engine import Engine

import numpy as np
import pandas as pd

from network.ecgresnet import BasicBlock, Flatten
from utils.helpers import convert_predictions_to_expert_categories, convert_variances_to_expert_categories

class ECGResNet_MCDropout_AuxOutput(nn.Module):
    """
    This class implements the ECG-ResNet in PyTorch.
    It handles the different layers and parameters of the model.
    Once initialized, a ResNet object can perform a forward pass.
    """
    def __init__(self, in_length, in_channels, n_grps, N, num_classes, dropout, first_width,
                 stride, dilation, n_dropout_samples, sampling_dropout_rate, n_logit_samples, train=False):
        """
        Initializes an ECGResNet_MCDropout_AuxOutput object.
        Args:
            in_channels: number of channels of the input
            n_grps: number of ResNet groups
            N: number of blocks per group
            num_classes: number of classes of the classification problem
            stride: tuple with stride value per block per group
        """
        super().__init__()
        num_branches = 2
        first_width = first_width * num_branches
        stem = [nn.Conv1d(in_channels, first_width // 2, kernel_size=7, padding=3,
                          stride=2, dilation=1, bias=False),
                nn.BatchNorm1d(first_width // 2), nn.ReLU(),
                nn.Conv1d(first_width // 2, first_width, kernel_size=1,
                          padding=0, stride=1, bias=False),
                nn.BatchNorm1d(first_width), nn.ReLU(), nn.Dropout(dropout),
                nn.Conv1d(first_width, first_width, kernel_size=5,
                          padding=2, stride=1, bias=False)]

        layers = []

        # Double feature depth at each group, after the first
        widths = [first_width]
        for grp in range(n_grps):
            widths.append(first_width * 2**grp)
        for grp in range(n_grps):
            layers += self._make_group(N, widths[grp], widths[grp + 1],
                                       stride, dropout, dilation, num_branches)

        layers += [nn.BatchNorm1d(widths[-1]), nn.ReLU(inplace=True)]
        fclayers1 = [nn.Linear(20096, 256), nn.ReLU(inplace=True),
                     nn.Dropout(dropout), nn.Linear(256, num_classes)]
        fclayers2 = [nn.Linear(5120, 256), nn.ReLU(inplace=True),
                     nn.Dropout(dropout), nn.Linear(256, 2 * num_classes)]

        self.stem = nn.Sequential(*stem)
        aux_point = (len(layers) - 2) // 2
        self.features1 = nn.Sequential(*layers[:aux_point])
        self.features2 = nn.Sequential(*layers[aux_point:])
        self.flatten = Flatten()
        self.fc1 = nn.Sequential(*fclayers1)
        self.fc2 = nn.Sequential(*fclayers2)
        self.num_classes = num_classes
        self.n_dropout_samples = n_dropout_samples
        self.sampling_dropout_rate = sampling_dropout_rate  # Dropout rate used during MC-Dropout sampling
        self.n_logit_samples = n_logit_samples  # Number of logit samples of the auxiliary output
        self.Gauss = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(num_classes), torch.eye(num_classes))

    def _make_group(self, N, in_channels, out_channels, stride, dropout, dilation, num_branches):
        """
        Builds a group of blocks.
        Args:
            in_channels: number of channels of the input
            out_channels: number of channels of the output
            stride: tuple with stride value per block
            N: number of blocks per group
        """
        group = list()
        for i in range(N):
            blk = BasicBlock(in_channels=(in_channels if i == 0 else out_channels),
                             out_channels=out_channels, stride=stride[i],
                             dropout=dropout, dilation=dilation,
                             num_branches=num_branches)
            group.append(blk)
        return group

    def forward(self, x):
        x = self.stem(x)
        x1 = self.features1(x)
        x1out = self.flatten(x1)
        x2 = self.features2(x1)
        x2out = self.flatten(x2)

        # Get logits
        logits = self.fc2(x2out)

        # Split into mean and log_variance
        output2_mean = logits[:, 0:self.num_classes]
        output2_log_var = logits[:, self.num_classes:]
        return self.fc1(x1out), output2_mean, output2_log_var

    # Turn on the dropout layers
    def enable_dropout(self):
        for module in self.modules():
            if module.__class__.__name__.startswith('Dropout'):
                # Turn on dropout
                module.train()

                # Set dropout rate
                module.p = self.sampling_dropout_rate

    # Takes n Monte Carlo samples
    def mc_sample_with_sample_logits(self, data):
        predictions = torch.empty((data.shape[0], self.n_dropout_samples, self.num_classes))
        predictions_no_sm = torch.empty((data.shape[0], self.n_dropout_samples, self.num_classes))
        log_variances = torch.empty((data.shape[0], self.n_dropout_samples, self.num_classes))

        for i in range(self.n_dropout_samples):
            # Forward pass
            _, output2_mean, output2_log_var = self(data)

            # Sample from the logits, returning a vector x_i
            x_i = self.sample_logits(self.n_logit_samples, output2_mean, output2_log_var, average=True)

            # Apply softmax to obtain probability vector p_i
            p_i = F.softmax(x_i, dim=1)

            # Save results
            predictions[:, i] = p_i
            predictions_no_sm[:, i] = x_i
            log_variances[:, i] = output2_log_var

        # Calculate mean and variance over the predictions, mean over log_variances, return results
        predictions_mean = predictions.mean(dim=1)
        predictions_mean_no_sm = predictions_no_sm.mean(dim=1)
        predictions_var = predictions.var(dim=1)
        log_variances_mean = log_variances.mean(dim=1)

        return predictions, predictions_mean, predictions_var, log_variances_mean, predictions_mean_no_sm

    # Takes T samples from the logits, by corrupting the network output with
    # Gaussian noise whose scale is determined by the network's auxiliary
    # outputs.
    # As in "What uncertainties do we need in Bayesian deep learning for
    # computer vision?", equation (12), first part.
    # "In practice, we train the network to predict the log variance."
    def sample_logits(self, T, input, log_var, average=True):

        # Take the exponent to get the variance
        variance = log_var.exp()

        # Go from shape [batch x num_classes] -> [batch x T x num_classes]
        sigma = variance[:, None, :].repeat(1, T, 1)
        f = input[:, None, :].repeat(1, T, 1)

        # Take T samples from the Gaussian distribution
        epsilon = self.Gauss.sample([input.shape[0], T])

        # Multiply the Gaussian noise by the variance and add it to the prediction
        x_i = f + (sigma * epsilon)

        if average:
            return x_i.mean(dim=1)
        else:
            return x_i
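
Editor's aside: the logit-sampling step above is easier to follow in isolation. Below is a minimal, self-contained sketch of what sample_logits computes; the batch size, class count, and the random mean/log-variance tensors are made up for illustration, whereas in the model they come from the auxiliary output head (fc2).

# Standalone sketch of the logit-sampling step (cf. Kendall & Gal, eq. 12).
# All inputs below are dummies, not values from this commit.
import torch
import torch.nn.functional as F

batch, num_classes, T = 4, 5, 100
mean = torch.randn(batch, num_classes)      # predicted logits f(x)
log_var = torch.randn(batch, num_classes)   # predicted log-variance

gauss = torch.distributions.MultivariateNormal(torch.zeros(num_classes),
                                               torch.eye(num_classes))

# Broadcast to [batch, T, num_classes] and corrupt the logits with scaled noise,
# mirroring sample_logits() above (note that it scales the noise by the
# exponentiated log-variance, not by the standard deviation).
sigma = log_var.exp()[:, None, :].repeat(1, T, 1)
f = mean[:, None, :].repeat(1, T, 1)
epsilon = gauss.sample([batch, T])

x_i = f + sigma * epsilon                   # T corrupted logit samples per input
p_i = F.softmax(x_i.mean(dim=1), dim=1)     # average over T, then softmax
print(p_i.shape)                            # torch.Size([4, 5])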

@@ -0,0 +1,172 @@
import sys
import os
import torch
import pandas as pd
import datetime
from argparse import ArgumentParser
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

from network.ecgresnet_mcdropout_auxout import ECGResNet_MCDropout_AuxOutput
from utils.helpers import create_results_directory
from utils.focalloss_weights import FocalLoss

class ECGResNetMCDropout_AuxOutSystem(pl.LightningModule):

    def __init__(self, in_length, in_channels, n_grps, N,
                 num_classes, dropout, first_width, stride,
                 dilation, learning_rate, n_dropout_samples, n_logit_samples, sampling_dropout_rate, loss_weights=None,
                 **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.n_dropout_samples = n_dropout_samples
        self.n_logit_samples = n_logit_samples

        self.IDs = torch.empty(0).type(torch.LongTensor)
        self.predicted_labels = torch.empty(0).type(torch.LongTensor)
        self.correct_predictions = torch.empty(0).type(torch.BoolTensor)
        self.aleatoric_uncertainty = torch.empty(0).type(torch.FloatTensor)
        self.epistemic_uncertainty = torch.empty(0).type(torch.FloatTensor)
        self.total_uncertainty = torch.empty(0).type(torch.FloatTensor)

        self.model = ECGResNet_MCDropout_AuxOutput(in_length, in_channels,
                                                   n_grps, N, num_classes,
                                                   dropout, first_width,
                                                   stride, dilation, n_dropout_samples, sampling_dropout_rate, n_logit_samples)

        if loss_weights is not None:
            weights = torch.tensor(loss_weights, dtype=torch.float)
        else:
            weights = loss_weights

        self.loss = FocalLoss(gamma=1, weights=weights)

    def forward(self, x):
        output1, output2_mean, output2_log_var = self.model(x)
        return output1, output2_mean, output2_log_var

    def training_step(self, batch, batch_idx):
        """Performs a training step.
        Args:
            batch (dict): Output of the dataloader.
            batch_idx (int): Index no. of this batch.
        Returns:
            tensor: Total loss for this step.
        """
        data, target = batch['waveform'], batch['label']

        # Make prediction
        output1, output2_mean, output2_log_var = self(data)

        # Sample from the logits, returning a vector x_i
        x_i = self.model.sample_logits(self.n_logit_samples, output2_mean, output2_log_var, average=True)

        # Apply softmax to obtain probability vector p_i
        # p_i = F.softmax(x_i, dim=1)

        train_loss1 = self.loss(output1, target)
        train_loss2 = self.loss(x_i, target)
        total_train_loss = (0.3 * train_loss1) + train_loss2

        self.log('train_loss', total_train_loss)

        return {'loss': total_train_loss}

    def validation_step(self, batch, batch_idx):
        data, target = batch['waveform'], batch['label']

        # Make prediction
        _, output2_mean, output2_log_var = self(data)

        # Sample from the logits, returning a vector x_i
        x_i = self.model.sample_logits(self.n_logit_samples, output2_mean, output2_log_var, average=True)

        # Apply softmax to obtain probability vector p_i
        p_i = F.softmax(x_i, dim=1)

        val_loss = self.loss(x_i, target)
        acc = FM.accuracy(p_i, target)

        # loss is a tensor. The checkpoint callback is monitoring 'checkpoint_on'
        metrics = {'val_loss': val_loss.item(), 'val_acc': acc.item()}
        self.log('val_acc', acc.item())
        self.log('val_loss', val_loss.item())
        return metrics

    def on_test_epoch_start(self):
        # Enable dropout at test time.
        self.model.enable_dropout()
    def test_step(self, batch, batch_idx, save_to_csv=False):
        data, target = batch['waveform'], batch['label']

        # MC sample using dropout, sample logits for every MC-dropout sample
        predictions, predictions_mean, predictions_var, log_variances_mean, predictions_mean_no_sm = self.model.mc_sample_with_sample_logits(data)

        # Take the exponent to get the variance
        output2_var = log_variances_mean.exp()

        predicted_labels = predictions_mean.argmax(dim=1)
        correct_predictions = torch.eq(predicted_labels, target)

        # MC-dropout variance over the predicted labels (epistemic uncertainty)
        sampled_var = torch.gather(predictions_var, 1, predictions_mean.argmax(dim=1).unsqueeze_(1))[:, 0]

        # Predicted aux-out variance of the predicted label (aleatoric uncertainty)
        predicted_labels_predicted_var = torch.gather(output2_var, 1, predictions_mean.argmax(dim=1).unsqueeze_(1))[:, 0]

        # Total uncertainty
        total_var = predicted_labels_predicted_var + sampled_var

        # Get metrics
        test_loss = self.loss(predictions_mean, target)
        acc = FM.accuracy(predictions_mean, target)

        self.log('test_acc', acc.item())
        self.log('test_loss', test_loss.item())

        self.IDs = torch.cat((self.IDs, batch['id']), 0)
        self.predicted_labels = torch.cat((self.predicted_labels, predicted_labels), 0)
        self.correct_predictions = torch.cat((self.correct_predictions, correct_predictions), 0)
        self.aleatoric_uncertainty = torch.cat((self.aleatoric_uncertainty, predicted_labels_predicted_var), 0)
        self.epistemic_uncertainty = torch.cat((self.epistemic_uncertainty, sampled_var), 0)
        self.total_uncertainty = torch.cat((self.total_uncertainty, total_var), 0)

        return {'test_loss': test_loss.item(), 'test_acc': acc.item()}

    # Initialize optimizer
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)

        return optimizer

    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--model_name', type=str, default='none_auxout')
        parser.add_argument('--n_logit_samples', type=int, default=100)  # Number of logit samples of the auxiliary output
        parser.add_argument('--n_dropout_samples', type=int, default=20)  # Number of dropout samples during MC-Dropout sampling
        parser.add_argument('--sampling_dropout_rate', type=float, default=0.1)  # Dropout rate during MC-Dropout sampling
        parser.add_argument('--ensembling_method', type=bool, default=False)
        return parser

    # Combine results into a single dataframe and save to disk
    def save_results(self):
        results = pd.concat([
            pd.DataFrame(self.IDs.numpy(), columns=['ID']),
            pd.DataFrame(self.predicted_labels.numpy(), columns=['predicted_label']),
            pd.DataFrame(self.correct_predictions.numpy(), columns=['correct_prediction']),
            pd.DataFrame(self.aleatoric_uncertainty.numpy(), columns=['aleatoric_uncertainty']),
            pd.DataFrame(self.epistemic_uncertainty.numpy(), columns=['epistemic_uncertainty']),
            pd.DataFrame(self.total_uncertainty.numpy(), columns=['total_uncertainty']),
        ], axis=1)

        create_results_directory()
        results.to_csv('results/{}_{}_results.csv'.format(self.__class__.__name__, datetime.datetime.now().replace(microsecond=0).isoformat()), index=False)
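
Editor's aside: the uncertainty decomposition performed in test_step above can also be illustrated in isolation. The sketch below uses made-up tensors rather than code from this commit: the MC-dropout variance of the predicted class is treated as epistemic uncertainty, the auxiliary head's predicted variance for that class as aleatoric uncertainty, and the two are summed into a total.

# Self-contained sketch of the uncertainty decomposition in test_step.
# All tensors below are dummies; in the system they come from
# mc_sample_with_sample_logits().
import torch

batch, num_classes = 4, 5
predictions_mean = torch.rand(batch, num_classes)   # mean softmax output over dropout samples
predictions_var = torch.rand(batch, num_classes)    # variance over dropout samples
output2_var = torch.rand(batch, num_classes)        # exp of the mean predicted log-variance

predicted_labels = predictions_mean.argmax(dim=1)

# Variance of the winning class under MC dropout -> epistemic uncertainty
epistemic = torch.gather(predictions_var, 1, predicted_labels.unsqueeze(1))[:, 0]

# Auxiliary-output variance of the winning class -> aleatoric uncertainty
aleatoric = torch.gather(output2_var, 1, predicted_labels.unsqueeze(1))[:, 0]

total = epistemic + aleatoric
print(epistemic.shape, aleatoric.shape, total.shape)  # three tensors of shape [4]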