Skip to content

Commit

Permalink
First commit.
Browse files Browse the repository at this point in the history
  • Loading branch information
gholste committed Aug 23, 2022
1 parent 235bf37 commit 3e87242
Show file tree
Hide file tree
Showing 9 changed files with 112,798 additions and 0 deletions.
146 changes: 146 additions & 0 deletions datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import os

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision

class NIH_CXR_Dataset(torch.utils.data.Dataset):
    """Single-label image dataset driven by ``nih-lt_single-label_{split}.csv``.

    Each CSV row supplies an image file name (column ``id``) and one column
    per entry of ``CLASSES``; the label of a row is the position in
    ``CLASSES`` of the first column holding that row's maximum value.

    Args:
        data_dir: Directory that every ``id`` value is joined onto to build
            the image path.
        label_dir: Directory containing the label CSV.
        split: Split name substituted into the CSV file name. ``'train'``
            additionally enables random flip/rotation augmentation.
    """

    def __init__(self, data_dir, label_dir, split):
        self.data_dir = data_dir
        self.split = split

        # Order matters: a class's position in this list is its integer label.
        self.CLASSES = [
            'No Finding', 'Infiltration', 'Atelectasis', 'Effusion', 'Nodule',
            'Mass', 'Pneumothorax', 'Consolidation', 'Pleural_Thickening',
            'Cardiomegaly', 'Fibrosis', 'Edema', 'Tortuous Aorta', 'Emphysema',
            'Pneumonia', 'Calcification of the Aorta', 'Pneumoperitoneum', 'Hernia',
            'Subcutaneous Emphysema', 'Pneumomediastinum'
        ]

        self.label_df = pd.read_csv(os.path.join(label_dir, f'nih-lt_single-label_{split}.csv'))

        self.img_paths = self.label_df['id'].apply(lambda x: os.path.join(data_dir, x)).values.tolist()
        # Column name of the row-wise maximum, mapped to its index in CLASSES.
        self.labels = self.label_df[self.CLASSES].idxmax(axis=1).apply(lambda x: self.CLASSES.index(x)).values

        # Per-class column sums, in CLASSES order.
        self.cls_num_list = self.label_df[self.CLASSES].sum(0).values.tolist()

        self.transform = self._build_transform(self.split)

    @staticmethod
    def _build_transform(split):
        """Return the preprocessing pipeline; augmentation only for 'train'."""
        steps = [torchvision.transforms.ToPILImage()]
        if split == 'train':
            steps.append(torchvision.transforms.RandomHorizontalFlip())
            steps.append(torchvision.transforms.RandomRotation(15))
        steps.append(torchvision.transforms.ToTensor())
        # Commonly used ImageNet per-channel statistics.
        steps.append(
            torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        )
        return torchvision.transforms.Compose(steps)

    def __len__(self):
        """Return the number of images listed in the label CSV."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load, resize and transform image ``idx``.

        Returns:
            A ``(image, label)`` pair: the transformed image as a float
            tensor and the class index as a long tensor.

        Raises:
            OSError: If the image file is missing or cannot be decoded.
        """
        path = self.img_paths[idx]
        x = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cv2.resize.
        if x is None:
            raise OSError(f'cv2.imread could not read image (missing or unreadable file): {path}')
        # NOTE(review): cv2.imread yields BGR channel order and the array is
        # passed on unchanged; harmless if the images are grayscale - confirm.
        x = cv2.resize(x, (256, 256), interpolation=cv2.INTER_AREA)

        x = self.transform(x)

        y = np.array(self.labels[idx])

        return x.float(), torch.from_numpy(y).long()

class MIMIC_CXR_Dataset(torch.utils.data.Dataset):
    """Single-label image dataset driven by ``mimic-lt_single-label_{split}.csv``.

    Each CSV row supplies a relative image path (column ``path``) and one
    column per entry of ``CLASSES``; the label of a row is the position in
    ``CLASSES`` of the first column holding that row's maximum value.

    Args:
        data_dir: Directory that every ``path`` value is joined onto to build
            the image path.
        label_dir: Directory containing the label CSV.
        split: Split name substituted into the CSV file name. ``'train'``
            additionally enables random flip/rotation augmentation.
    """

    def __init__(self, data_dir, label_dir, split):
        # Stored for parity with NIH_CXR_Dataset, which exposes the same attribute.
        self.data_dir = data_dir
        self.split = split

        # Order matters: a class's position in this list is its integer label.
        self.CLASSES = [
            'No Finding', 'Lung Opacity', 'Cardiomegaly', 'Atelectasis',
            'Pleural Effusion', 'Support Devices', 'Edema', 'Pneumonia',
            'Pneumothorax', 'Lung Lesion', 'Fracture', 'Enlarged Cardiomediastinum',
            'Consolidation', 'Pleural Other', 'Calcification of the Aorta',
            'Tortuous Aorta', 'Pneumoperitoneum', 'Subcutaneous Emphysema',
            'Pneumomediastinum'
        ]

        self.label_df = pd.read_csv(os.path.join(label_dir, f'mimic-lt_single-label_{split}.csv'))

        self.img_paths = self.label_df['path'].apply(lambda x: os.path.join(data_dir, x)).values.tolist()
        # Column name of the row-wise maximum, mapped to its index in CLASSES.
        self.labels = self.label_df[self.CLASSES].idxmax(axis=1).apply(lambda x: self.CLASSES.index(x)).values

        # Per-class column sums, in CLASSES order.
        self.cls_num_list = self.label_df[self.CLASSES].sum(0).values.tolist()

        self.transform = self._build_transform(self.split)

    @staticmethod
    def _build_transform(split):
        """Return the preprocessing pipeline; augmentation only for 'train'."""
        steps = [torchvision.transforms.ToPILImage()]
        if split == 'train':
            steps.append(torchvision.transforms.RandomHorizontalFlip())
            steps.append(torchvision.transforms.RandomRotation(15))
        steps.append(torchvision.transforms.ToTensor())
        # Commonly used ImageNet per-channel statistics.
        steps.append(
            torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        )
        return torchvision.transforms.Compose(steps)

    def __len__(self):
        """Return the number of images listed in the label CSV."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load, resize and transform image ``idx``.

        Returns:
            A ``(image, label)`` pair: the transformed image as a float
            tensor and the class index as a long tensor.

        Raises:
            OSError: If the image file is missing or cannot be decoded.
        """
        path = self.img_paths[idx]
        x = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cv2.resize.
        if x is None:
            raise OSError(f'cv2.imread could not read image (missing or unreadable file): {path}')
        # NOTE(review): cv2.imread yields BGR channel order and the array is
        # passed on unchanged; harmless if the images are grayscale - confirm.
        x = cv2.resize(x, (256, 256), interpolation=cv2.INTER_AREA)
        x = self.transform(x)

        y = np.array(self.labels[idx])

        return x.float(), torch.from_numpy(y).long()

## CREDIT TO https://github.com/agaldran/balanced_mixup ##

# pytorch-wrapping-multi-dataloaders/blob/master/wrapping_multi_dataloaders.py
class ComboIter(object):
    """Iterator that pulls one batch from every wrapped loader per step.

    Args:
        my_loader: Object exposing ``loaders`` (an iterable of iterables),
            ``combine_batch(batches)`` and ``__len__``.
    """

    def __init__(self, my_loader):
        self.my_loader = my_loader
        # One fresh iterator per wrapped loader, in the order given.
        self.loader_iters = [iter(loader) for loader in self.my_loader.loaders]

    def __iter__(self):
        return self

    def __next__(self):
        """Return the combined batch built from the next batch of each loader.

        Raises:
            StopIteration: As soon as any wrapped loader is exhausted.
        """
        # When the shortest loader (the one with the fewest batches) is
        # exhausted, this iterator terminates as well: the `StopIteration`
        # raised by that loader's iterator propagates out of this method.
        # The builtin `next()` is used because Python 3 iterators have no
        # `.next()` method.
        batches = [next(loader_iter) for loader_iter in self.loader_iters]
        return self.my_loader.combine_batch(batches)

    def __len__(self):
        return len(self.my_loader)

class ComboLoader:
    """Bundle several loaders so they can be iterated in lockstep.

    Iterating yields, per step, whatever `combine_batch` makes of the list
    holding one batch from each wrapped loader, so the object can be used
    in a `for batch in loader:` loop like a single loader.

    Args:
        loaders: a list or tuple of loader objects (each iterable and sized)
    """

    def __init__(self, loaders):
        self.loaders = loaders

    def __len__(self):
        # Bounded by the wrapped loader with the fewest batches.
        return min(map(len, self.loaders))

    def __iter__(self):
        return ComboIter(self)

    def combine_batch(self, batches):
        # Hook for customizing how per-loader batches are merged;
        # by default the list is handed back untouched.
        return batches
Loading

0 comments on commit 3e87242

Please sign in to comment.