Added MCDropout_AuxOut implementation
Jeroen Vranken committed Feb 23, 2021
1 parent 17e2bbd commit 308d04d
Showing 5 changed files with 371 additions and 3 deletions.
2 changes: 1 addition & 1 deletion project/ecgresnet_config.json
@@ -13,6 +13,6 @@
"test_labels_csv": "/training/template/datasets/template_test.csv",
"learning_rate": 0.0005,
"batch_size": 16,
"max_epochs": 2
"max_epochs": 1
}

3 changes: 2 additions & 1 deletion project/main.py
@@ -24,6 +24,7 @@
from systems.ecgresnet_varinf import ECGResNetVariationalInferenceSystem
from systems.ecgresnet_ensemble_auxout import ECGResNetEnsemble_AuxOutSystem
from systems.ecgresnet_ssensemble_auxout import ECGResNetSnapshotEnsemble_AuxOutSystem
from systems.ecgresnet_mcdropout_auxout import ECGResNetMCDropout_AuxOutSystem
from utils.dataloader import CPSC2018Dataset
from utils.transforms import ToTensor, Resample
from utils.transforms import ApplyGain
@@ -139,7 +140,7 @@ def get_model_class(args):

elif temp_args.epistemic_method == 'mcdropout':
# mcdropout_auxout
-         return ECGResNetAuxOutput_MCDropoutSystem
+         return ECGResNetMCDropout_AuxOutSystem

elif temp_args.epistemic_method == 'ssensemble':
# ssensemble_auxout
181 changes: 181 additions & 0 deletions project/network/ecgresnet_mcdropout_auxout.py
@@ -0,0 +1,181 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import sys
import datetime
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchnet as tnt
from torchnet.engine import Engine

import numpy as np
import pandas as pd

from network.ecgresnet import BasicBlock, Flatten
from utils.helpers import convert_predictions_to_expert_categories, convert_variances_to_expert_categories

class ECGResNet_MCDropout_AuxOutput(nn.Module):
"""
This class implements the ECG-ResNet in PyTorch.
It handles the different layers and parameters of the model.
Once initialized, a ResNet object can perform a forward pass.
"""
def __init__(self, in_length, in_channels, n_grps, N, num_classes, dropout, first_width,
stride, dilation, n_dropout_samples, sampling_dropout_rate, n_logit_samples, train=False):
"""
Initializes ECGResNet object.
Args:
in_channels: number of channels of input
n_grps: number of ResNet groups
N: number of blocks per groups
num_classes: number of classes of the classification problem
stride: tuple with stride value per block per group
"""
super().__init__()
num_branches = 2
first_width = first_width * num_branches
stem = [nn.Conv1d(in_channels, first_width // 2, kernel_size=7, padding=3,
stride = 2, dilation = 1, bias=False),
nn.BatchNorm1d(first_width // 2), nn.ReLU(),
nn.Conv1d(first_width // 2, first_width, kernel_size = 1,
padding = 0, stride = 1, bias = False),
nn.BatchNorm1d(first_width), nn.ReLU(), nn.Dropout(dropout),
nn.Conv1d(first_width, first_width, kernel_size = 5,
padding = 2, stride = 1, bias = False)]

layers = []

# Double feature depth at each group, after the first
widths = [first_width]
for grp in range(n_grps):
widths.append((first_width)*2**grp)
for grp in range(n_grps):
layers += self._make_group(N, widths[grp], widths[grp+1],
stride, dropout, dilation, num_branches)

layers += [nn.BatchNorm1d(widths[-1]), nn.ReLU(inplace=True)]
fclayers1 = [nn.Linear(20096, 256), nn.ReLU(inplace = True),
nn.Dropout(dropout), nn.Linear(256, num_classes)]
fclayers2 = [nn.Linear(5120, 256), nn.ReLU(inplace = True),
nn.Dropout(dropout), nn.Linear(256, 2*num_classes)]

self.stem = nn.Sequential(*stem)
aux_point = (len(layers) - 2) // 2
self.features1 = nn.Sequential(*layers[:aux_point])
self.features2 = nn.Sequential(*layers[aux_point:])
self.flatten = Flatten()
self.fc1 = nn.Sequential(*fclayers1)
self.fc2 = nn.Sequential(*fclayers2)
self.num_classes = num_classes
self.n_dropout_samples = n_dropout_samples
self.sampling_dropout_rate = sampling_dropout_rate # Dropout during MCDropout sampling
self.n_logit_samples = n_logit_samples # Number of logit samples of the auxiliary output
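# Standard multivariate normal used to draw noise for the logit-sampling step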
self.Gauss = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(num_classes), torch.eye(num_classes))

def _make_group(self, N, in_channels, out_channels, stride, dropout, dilation, num_branches):
"""
Builds a group of blocks.
Args:
N: number of blocks per group
in_channels: number of channels of input
out_channels: number of channels of output
stride: stride of convolutions
dropout: probability of dropping out activations in the dropout layers
dilation: dilation of the convolutions
num_branches: number of branches in each block
"""
group = list()
for i in range(N):
blk = BasicBlock(in_channels=(in_channels if i == 0 else out_channels),
out_channels=out_channels, stride=stride[i],
dropout = dropout, dilation = dilation,
num_branches = num_branches)
group.append(blk)
return group

def forward(self, x):
x = self.stem(x)
x1 = self.features1(x)
x1out = self.flatten(x1)
x2 = self.features2(x1)
x2out = self.flatten(x2)

# Get logits
logits = self.fc2(x2out)

# Split into mean and log_variance
output2_mean = logits[:, 0:self.num_classes]
output2_log_var = logits[:, self.num_classes:]
return self.fc1(x1out), output2_mean, output2_log_var


# Turn on the dropout layers
def enable_dropout(self):
for module in self.modules():
if module.__class__.__name__.startswith('Dropout'):
# Turn on dropout
module.train()

# Set dropout rate
module.p = self.sampling_dropout_rate

# Takes n Monte Carlo samples
def mc_sample_with_sample_logits(self, data):
predictions = torch.empty((data.shape[0], self.n_dropout_samples, self.num_classes))
predictions_no_sm = torch.empty((data.shape[0], self.n_dropout_samples, self.num_classes))
log_variances = torch.empty((data.shape[0], self.n_dropout_samples, self.num_classes))

for i in range(self.n_dropout_samples):
# Forward pass through the network
_, output2_mean, output2_log_var = self(data)

# Sample from logits, returning a vector x_i
x_i = self.sample_logits(self.n_logit_samples, output2_mean, output2_log_var, average=True)

# Apply softmax to obtain probability vector p_i
p_i = F.softmax(x_i, dim=1)

# Save results
predictions[:, i] = p_i
predictions_no_sm[:, i] = x_i
log_variances[:, i] = output2_log_var

# Calculate mean and variance over the predictions, mean over log_variances, return results
predictions_mean = predictions.mean(dim=1)
predictions_mean_no_sm = predictions_no_sm.mean(dim=1)
predictions_var = predictions.var(dim=1)
log_variances_mean = log_variances.mean(dim=1)

return predictions, predictions_mean, predictions_var, log_variances_mean, predictions_mean_no_sm

# Takes T samples from the logits by corrupting the network output with
# Gaussian noise whose variance is determined by the network's auxiliary
# outputs.
# As in "What uncertainties do we need in Bayesian deep learning for
# computer vision?", equation (12), first part.
# "In practice, we train the network to predict the log variance!"
def sample_logits(self, T, input, log_var, average=True):

# Take the exponent to get the variance
variance = log_var.exp()

# Go from shape: [batch x num_classes] -> [batch x T x num_classes]
sigma = variance[:, None, :].repeat(1, T, 1)
f = input[:, None, :].repeat(1, T, 1)

# Take T samples from the Gaussian distribution
epsilon = self.Gauss.sample([input.shape[0], T])

# Multiply Gaussian noise with variance, and add to the prediction
x_i = f + (sigma * epsilon)

if average:
return x_i.mean(dim=1)
else:
return x_i
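
For intuition, a standalone sketch (not part of this commit; the shapes and values are illustrative assumptions) of the aggregation that mc_sample_with_sample_logits performs over the stochastic forward passes: the mean of the per-sample softmax outputs is the prediction, and their variance over samples is used as the epistemic uncertainty.

import torch
import torch.nn.functional as F

batch, n_dropout_samples, num_classes = 4, 20, 5
# Stand-in for the sampled logits x_i that the network produces per dropout sample
sampled_logits = torch.randn(batch, n_dropout_samples, num_classes)
probs = F.softmax(sampled_logits, dim=2)

predictions_mean = probs.mean(dim=1)    # predictive distribution per input
predictions_var = probs.var(dim=1)      # per-class variance over dropout samples
predicted_labels = predictions_mean.argmax(dim=1)
# Variance of the predicted class, as used downstream in test_step
epistemic = torch.gather(predictions_var, 1, predicted_labels.unsqueeze(1))[:, 0]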

172 changes: 172 additions & 0 deletions project/systems/ecgresnet_mcdropout_auxout.py
@@ -0,0 +1,172 @@
import sys
import os
import torch
import pandas as pd
import datetime
from argparse import ArgumentParser
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

from network.ecgresnet_mcdropout_auxout import ECGResNet_MCDropout_AuxOutput
from utils.helpers import create_results_directory
from utils.focalloss_weights import FocalLoss

class ECGResNetMCDropout_AuxOutSystem(pl.LightningModule):

def __init__(self, in_length, in_channels, n_grps, N,
num_classes, dropout, first_width, stride,
dilation, learning_rate, n_dropout_samples, n_logit_samples, sampling_dropout_rate, loss_weights=None,
**kwargs):
super().__init__()
self.save_hyperparameters()
self.learning_rate = learning_rate
self.n_dropout_samples = n_dropout_samples
self.n_logit_samples = n_logit_samples

self.IDs = torch.empty(0).type(torch.LongTensor)
self.predicted_labels = torch.empty(0).type(torch.LongTensor)
self.correct_predictions = torch.empty(0).type(torch.BoolTensor)
self.aleatoric_uncertainty = torch.empty(0).type(torch.FloatTensor)
self.epistemic_uncertainty = torch.empty(0).type(torch.FloatTensor)
self.total_uncertainty = torch.empty(0).type(torch.FloatTensor)

self.model = ECGResNet_MCDropout_AuxOutput(in_length, in_channels,
n_grps, N, num_classes,
dropout, first_width,
stride, dilation, n_dropout_samples, sampling_dropout_rate, n_logit_samples)


if loss_weights is not None:
weights = torch.tensor(loss_weights, dtype = torch.float)
else:
weights = loss_weights

self.loss = FocalLoss(gamma=1, weights = weights)

def forward(self, x):
output1, output2_mean, output2_log_var = self.model(x)
return output1, output2_mean, output2_log_var

def training_step(self, batch, batch_idx):
"""Performs a training step.
Args:
batch (dict): Output of the dataloader.
batch_idx (int): Index no. of this batch.
Returns:
tensor: Total loss for this step.
"""
data, target = batch['waveform'], batch['label']

# Make prediction
output1, output2_mean, output2_log_var = self(data)

# Sample from logits, returning a vector x_i
x_i = self.model.sample_logits(self.n_logit_samples, output2_mean, output2_log_var, average=True)

# Apply softmax to obtain probability vector p_i
# p_i = F.softmax(x_i, dim=1)

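# Combine the losses of both output heads; the early (intermediate) head's loss is down-weighted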
train_loss1 = self.loss(output1, target)
train_loss2 = self.loss(x_i, target)
total_train_loss = (0.3 * train_loss1) + train_loss2

self.log('train_loss', total_train_loss)

return {'loss': total_train_loss}

def validation_step(self, batch, batch_idx):
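"""Performs a validation step.
Args:
batch (dict): Output of the dataloader.
batch_idx (int): Index no. of this batch.
Returns:
dict: Validation loss and accuracy for this step.
"""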
data, target = batch['waveform'], batch['label']

# Make prediction
_, output2_mean, output2_log_var = self(data)

# Sample from logits, returning a vector x_i
x_i = self.model.sample_logits(self.n_logit_samples, output2_mean, output2_log_var, average=True)

# Apply softmax to obtain probability vector p_i
p_i = F.softmax(x_i, dim=1)

val_loss = self.loss(x_i, target)
acc = FM.accuracy(p_i, target)

# loss is tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
metrics = {'val_loss': val_loss.item(), 'val_acc': acc.item()}
self.log('val_acc', acc.item())
self.log('val_loss', val_loss.item())
return metrics

def on_test_epoch_start(self):
# Enable dropout at test time.
self.model.enable_dropout()

def test_step(self, batch, batch_idx, save_to_csv=False):
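"""Performs a test step with MC dropout enabled, decomposing the predictive
uncertainty into an epistemic part (variance over dropout samples) and an
aleatoric part (predicted variance from the auxiliary output).
Args:
batch (dict): Output of the dataloader.
batch_idx (int): Index no. of this batch.
save_to_csv (bool): Unused here; results are accumulated and written by save_results().
Returns:
dict: Test loss and accuracy for this step.
"""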
data, target = batch['waveform'], batch['label']

# MC sample using dropout, sample logits for every mc-dropout sample
predictions, predictions_mean, predictions_var, log_variances_mean, predictions_mean_no_sm = self.model.mc_sample_with_sample_logits(data)

# Take exponent to get the variance
output2_var = log_variances_mean.exp()

predicted_labels = predictions_mean.argmax(dim=1)
correct_predictions = torch.eq(predicted_labels, target)

# MC dropout variance over predicted labels (epistemic uncertainty)
sampled_var = torch.gather(predictions_var, 1, predictions_mean.argmax(dim=1).unsqueeze_(1))[:, 0]

# Predicted aux-out variance of the predicted label (aleatoric uncertainty)
predicted_labels_predicted_var = torch.gather(output2_var, 1, predictions_mean.argmax(dim=1).unsqueeze_(1))[:, 0]

# Total uncertainty
total_var = predicted_labels_predicted_var + sampled_var

# Get metrics
test_loss = self.loss(predictions_mean, target)
acc = FM.accuracy(predictions_mean, target)

self.log('test_acc', acc.item())
self.log('test_loss', test_loss.item())

self.IDs = torch.cat((self.IDs, batch['id']), 0)
self.predicted_labels = torch.cat((self.predicted_labels, predicted_labels), 0)
self.correct_predictions = torch.cat((self.correct_predictions, correct_predictions), 0)
self.aleatoric_uncertainty = torch.cat((self.aleatoric_uncertainty, predicted_labels_predicted_var), 0)
self.epistemic_uncertainty = torch.cat((self.epistemic_uncertainty, sampled_var), 0)
self.total_uncertainty = torch.cat((self.total_uncertainty, total_var), 0)

return {'test_loss': test_loss.item(), 'test_acc': acc.item()}

# Initialize optimizer
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)

return optimizer

def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--model_name', type=str, default='none_auxout')
parser.add_argument('--n_logit_samples', type=int, default=100) # Number of logit samples of the auxiliary output
parser.add_argument('--n_dropout_samples', type=int, default=20) # Number of dropout samples during MCDropout sampling
parser.add_argument('--sampling_dropout_rate', type=float, default=0.1) # Dropout rate during MCDropout sampling
parser.add_argument('--ensembling_method', type=bool, default=False)
return parser

# Combine results into single dataframe and save to disk
def save_results(self):
results = pd.concat([
pd.DataFrame(self.IDs.numpy(), columns= ['ID']),
pd.DataFrame(self.predicted_labels.numpy(), columns= ['predicted_label']),
pd.DataFrame(self.correct_predictions.numpy(), columns= ['correct_prediction']),
pd.DataFrame(self.aleatoric_uncertainty.numpy(), columns= ['aleatoric_uncertainty']),
pd.DataFrame(self.epistemic_uncertainty.numpy(), columns= ['epistemic_uncertainty']),
pd.DataFrame(self.total_uncertainty.numpy(), columns= ['total_uncertainty']),
], axis=1)

create_results_directory()
results.to_csv('results/{}_{}_results.csv'.format(self.__class__.__name__, datetime.datetime.now().replace(microsecond=0).isoformat()), index=False)
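
As a usage note, a minimal sketch (an assumption about the wiring, following the PyTorch Lightning convention used elsewhere in this repository) of how add_model_specific_args extends an argument parser such as the one built in main.py:

from argparse import ArgumentParser
from systems.ecgresnet_mcdropout_auxout import ECGResNetMCDropout_AuxOutSystem

parser = ArgumentParser()
# main.py is assumed to add its own shared arguments here (e.g. --epistemic_method)
parser = ECGResNetMCDropout_AuxOutSystem.add_model_specific_args(parser)
args = parser.parse_args(['--n_dropout_samples', '30'])
print(args.n_dropout_samples, args.sampling_dropout_rate)  # 30 0.1

get_model_class (see the main.py diff above) returns ECGResNetMCDropout_AuxOutSystem when --epistemic_method is 'mcdropout', and the parsed arguments are presumably passed on to the system's constructor.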
