-
-
Notifications
You must be signed in to change notification settings - Fork 211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tarfiles #27
base: master
Are you sure you want to change the base?
Tarfiles #27
Changes from 5 commits
9f3eb7b
0414355
289747d
303832c
da2490b
8fb3f03
395e5e4
fa69973
abbbfec
2ba6f5d
cacc3ad
ecbf302
04f1a32
90129ab
0aa52a7
349babb
6999bd3
842ddf8
37afa4e
bbd4007
34daddb
727d9ff
d2ae7c0
41b50c4
48d8170
5c4117e
2773c69
7ffc252
68a71ae
ec9777b
29498a6
3674357
ccec9ae
c091734
e56a565
1dde4b7
1628db4
124467c
28816e5
ce38e57
281935c
9c2c9d2
7c6aebb
cb97e70
950ae96
300c9d7
218fa75
b22cead
ae09bc9
30c043b
bfdebf2
0f7ea51
71c7a50
1715b9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,8 @@ | |
from torchvision import transforms | ||
from tqdm import tqdm | ||
import numpy as np | ||
from io import BytesIO | ||
import tarfile | ||
import os,sys,os.path | ||
import pandas as pd | ||
import pickle | ||
|
@@ -195,7 +197,24 @@ def __len__(self): | |
def __getitem__(self, idx): | ||
return self.dataset[self.idxs[idx]] | ||
|
||
|
||
|
||
class TarDataset(Dataset): | ||
def __init__(self, imgpath): | ||
if imgpath.endswith(".tar"): | ||
self.tarred = tarfile.open(imgpath) | ||
self.tar_paths = self.tarred.getmembers() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So you said this takes a lot of time. I see how this approach is super robust. I did some tests for time and it seems like just 30 seconds for the MIMIC data. What about caching this using a dict based on the file path? It would speed things up if multiple objects are created. But it seems like a reasonable price to pay. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ok, using a dict, the second load is around 10x faster on my machine. The dict could also be pickled so there is only one slow load. That option would lead to issues if someone wanted to change the tarfile, although I don't know why they would do that. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Doing hashing and caching a file could be nice, but it could also be annoying to debug and create more issues (for example, if there are no write permissions). |
||
else: | ||
self.tarred = None | ||
def get_image(self, path): | ||
if self.tarred is None: | ||
return imread(os.path.join(self.imgpath, path)) | ||
else: | ||
for tar_path in self.tar_paths: | ||
name = tar_path.name | ||
if name.endswith(path): | ||
bytes = self.tarred.extractfile(name).read() | ||
return np.array(Image.open(BytesIO(bytes))) | ||
|
||
class NIH_Dataset(Dataset): | ||
""" | ||
NIH ChestX-ray8 dataset | ||
|
@@ -293,7 +312,7 @@ def __getitem__(self, idx): | |
|
||
|
||
imgid = self.csv['Image Index'].iloc[idx] | ||
img_path = os.path.join(self.imgpath, imgid) | ||
#img_path = os.path.join(self.imgpath, imgid) | ||
#print(img_path) | ||
img = imread(img_path) | ||
if self.normalize: | ||
|
@@ -838,7 +857,7 @@ def __getitem__(self, idx): | |
|
||
return {"img":img, "lab":self.labels[idx], "idx":idx} | ||
|
||
class MIMIC_Dataset(Dataset): | ||
class MIMIC_Dataset(TarDataset): | ||
""" | ||
Johnson AE, Pollard TJ, Berkowitz S, Greenbaum NR, Lungren MP, Deng CY, Mark RG, Horng S. MIMIC-CXR: A large publicly available database of labeled chest radiographs. arXiv preprint arXiv:1901.07042. 2019 Jan 21. | ||
|
||
|
@@ -850,7 +869,7 @@ class MIMIC_Dataset(Dataset): | |
def __init__(self, imgpath, csvpath,metacsvpath, views=["PA"], transform=None, data_aug=None, | ||
flat_dir=True, seed=0, unique_patients=True): | ||
|
||
super(MIMIC_Dataset, self).__init__() | ||
super(MIMIC_Dataset, self).__init__(imgpath) | ||
np.random.seed(seed) # Reset the seed so all runs are the same. | ||
self.MAXVAL = 255 | ||
|
||
|
@@ -877,8 +896,10 @@ def __init__(self, imgpath, csvpath,metacsvpath, views=["PA"], transform=None, d | |
self.csv = pd.read_csv(self.csvpath) | ||
self.metacsvpath = metacsvpath | ||
self.metacsv = pd.read_csv(self.metacsvpath) | ||
|
||
|
||
|
||
self.csv = self.csv.set_index(['subject_id', 'study_id']) | ||
|
||
self.metacsv = self.metacsv.set_index(['subject_id', 'study_id']) | ||
|
||
self.csv = self.csv.join(self.metacsv).reset_index() | ||
|
@@ -926,9 +947,9 @@ def __getitem__(self, idx): | |
studyid = str(self.csv.iloc[idx]["study_id"]) | ||
dicom_id = str(self.csv.iloc[idx]["dicom_id"]) | ||
|
||
img_path = os.path.join(self.imgpath, "p" + subjectid[:2], "p" + subjectid, "s" + studyid, dicom_id + ".jpg") | ||
img = imread(img_path) | ||
img = normalize(img, self.MAXVAL) | ||
img_fname = os.path.join("p" + subjectid[:2], "p" + subjectid, "s" + studyid, dicom_id + ".jpg") | ||
img = self.get_image(img_fname) | ||
img = normalize(img, self.MAXVAL) | ||
|
||
# Check that images are 2D arrays | ||
if len(img.shape) > 2: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
What about using `tarfile.is_tarfile`? This would allow for compressed files, which people may want to use.
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Good idea, thanks