Tarfiles #27

Open · wants to merge 54 commits into master

Changes from 1 commit

Commits (54)
9f3eb7b
Open MIMIC from tarfile
bganglia Aug 3, 2020
0414355
Merge branch 'master' of https://github.com/ieee8023/torchxrayvision …
bganglia Aug 6, 2020
289747d
Merge branch 'master' of https://github.com/ieee8023/torchxrayvision …
bganglia Aug 6, 2020
303832c
revert whitespace
bganglia Aug 8, 2020
da2490b
don't use get_image() in NIH_Dataset
bganglia Aug 8, 2020
8fb3f03
NIH_Dataset extends TarDataset
bganglia Aug 8, 2020
395e5e4
Store tarfiles in dictionary
bganglia Aug 8, 2020
fa69973
use getnames instead of getmembers
bganglia Aug 8, 2020
abbbfec
use O(n) method for determining imgid from tar_path
bganglia Aug 9, 2020
2ba6f5d
random data in MIMIC format
bganglia Aug 9, 2020
cacc3ad
script for generating random MIMIC data
bganglia Aug 9, 2020
ecbf302
track random MIMIC data
bganglia Aug 9, 2020
04f1a32
tarfile test using random MIMIC data
bganglia Aug 9, 2020
90129ab
fix test directory
bganglia Aug 9, 2020
0aa52a7
use .close() on tarfile and regenerate test directory
bganglia Aug 9, 2020
349babb
support for tarfiles in NIH dataset
bganglia Aug 9, 2020
6999bd3
Inherit from TarDataset in PC_Dataset
bganglia Aug 10, 2020
842ddf8
Storage-agnostic dataset
bganglia Aug 10, 2020
37afa4e
Inherit from storage agnostic loader
bganglia Aug 10, 2020
bbd4007
tidy up tarfile code
bganglia Aug 10, 2020
34daddb
remove previous TarDataset, ZipDataset classes
bganglia Aug 10, 2020
727d9ff
Scripts for generating test data
bganglia Aug 13, 2020
d2ae7c0
Test data
bganglia Aug 13, 2020
41b50c4
Tests for zip, tar in MIMIC, NIH, and PC
bganglia Aug 13, 2020
48d8170
clean up storage classes
bganglia Aug 13, 2020
5c4117e
save progress
bganglia Aug 26, 2020
2773c69
inherit from Dataset in NIH_Dataset
bganglia Aug 26, 2020
7ffc252
Add code for automated tests with script-generated data
bganglia Aug 26, 2020
68a71ae
script for writing random data
bganglia Aug 26, 2020
ec9777b
fall back on .index() instead of trying to load a cached version in .…
bganglia Aug 26, 2020
29498a6
support multiprocessing
bganglia Aug 27, 2020
3674357
Clean up new code for tests and format interfaces
bganglia Aug 27, 2020
ccec9ae
write partial metadata files with subset of columns
bganglia Aug 27, 2020
c091734
Improve caching
bganglia Aug 27, 2020
e56a565
fix tests
bganglia Aug 28, 2020
1dde4b7
fix error in data-generation script
bganglia Aug 28, 2020
1628db4
create .torchxrayvision if it does not already exist
bganglia Aug 28, 2020
124467c
fix line adding .torchxrayvision
bganglia Aug 28, 2020
28816e5
Commit sample data for testing NLM_TB datasets, instead of auto-gener…
bganglia Aug 28, 2020
ce38e57
Commit covid test cases
bganglia Aug 28, 2020
281935c
Include parallel tests again
bganglia Aug 28, 2020
9c2c9d2
trycatch on reading/writing stored_mappings, with disk_unwriteable_ou…
bganglia Aug 28, 2020
7c6aebb
work when .torchxrayvision is not writeable
bganglia Aug 28, 2020
cb97e70
remove some print statements
bganglia Aug 28, 2020
950ae96
add test simulating an unwriteable disk
bganglia Aug 28, 2020
300c9d7
use filesystem instead of dictionary
bganglia Aug 28, 2020
218fa75
rewrite data generation scripts as python, not bash scripts; add para…
bganglia Aug 30, 2020
b22cead
cleanup: better variable names and use blake2b instead of hash (works…
bganglia Aug 31, 2020
ae09bc9
Add test for asserting a dataset loads faster the second time
bganglia Aug 31, 2020
30c043b
Don't invoke duration test, to avoid spurious errors
bganglia Aug 31, 2020
bfdebf2
Call on new data generation script
bganglia Aug 31, 2020
0f7ea51
simplify and improve documentation
bganglia Sep 5, 2020
71c7a50
reorganize
bganglia Sep 19, 2020
1715b9d
Fix path length in CheX_Dataset
bganglia Sep 19, 2020
simplify and improve documentation
bganglia committed Sep 5, 2020
commit 0f7ea51b19eb151bc3382744156cadd3f2649922
177 changes: 123 additions & 54 deletions torchxrayvision/datasets.py
@@ -206,6 +206,11 @@ def __getitem__(self, idx):
return self.dataset[self.idxs[idx]]

def last_n_in_filepath(filepath, n):
"""
Return the last n pieces of a path (takes a string, not a Path object).
For example:
last_n_in_filepath("a/b/c",2) -> "b/c"
"""
if n < 1:
return ""
start_part, end_part = os.path.split(filepath)
@@ -214,66 +219,90 @@ def last_n_in_filepath(filepath, n):
end_part = os.path.join(middle_part, end_part)
return end_part

class Interface:
def get_filename_mapping_path(imgpath, path_length):
"""
This class has abstract methods for extracting files from an archive based on a partial path.
See child classes TarInterface, ZipInterface, FolderInterface, and ArchiveFolder.
Create a hash of (imgpath, last_modification, path_length_for_mapping_key)
and use it to return the filepath for a cached index.
"""
path_length = 0
def load_dataset(self, filename, save_to_cache=True, load_from_cache=True):
"Load the dataset's index from the cache if available, else create a new one."
imgpath = os.path.abspath(str(imgpath))
timestamp = os.path.getmtime(imgpath)
length = path_length
key = (imgpath, timestamp, length)

filename = os.path.abspath(str(filename))
timestamp = os.path.getmtime(filename)
length = self.path_length
key = (filename, timestamp, length)
print(key)
cache_filename = str(blake2b(pickle.dumps(key)).hexdigest()) + ".pkl"

key_filename = str(blake2b(pickle.dumps(key)).hexdigest()) + ".pkl"
file_mapping_cache_folder = os.path.expanduser(os.path.join(
"~", ".torchxrayvision", "filename-mapping-cache"
))

filename_mapping_folder = os.path.expanduser(os.path.join(
"~", ".torchxrayvision", "filename-mapping-cache"
))
filename_mapping_path = os.path.join(file_mapping_cache_folder, cache_filename)

mapping_filename = os.path.join(filename_mapping_folder, key_filename)
return filename_mapping_path

print(mapping_filename)
def load_filename_mapping(imgpath, path_length):
"If a cached filename mapping exists, return it. Otherwise, return None"

if os.path.exists(mapping_filename):
print("Loading indexed file paths from cache")
with open(mapping_filename, "rb") as handle:
mapping = pickle.load(handle)
compressed = self.get_archive(filename)
else:
print("Indexing file paths (one-time). The next load will be faster")
compressed, mapping = self.index(filename)
try:
os.makedirs(filename_mapping_folder, exist_ok=True)
with open(mapping_filename, "wb") as handle:
pickle.dump(mapping, handle)
except:
pass
return compressed, mapping
def convert_to_image(self, filename, bytes):
"Convert an image byte array to a numpy array. If the filename ends with .dcm, use pydicom."
if str(filename).endswith(".dcm"):
return pydicom.filereader.dcmread(BytesIO(bytes), force=True).pixel_array
else:
out = np.array(Image.open(BytesIO(bytes)))
return out
filename_mapping_path = get_filename_mapping_path(imgpath, path_length)

if os.path.exists(filename_mapping_path):
print("Loading indexed file paths from cache")
with open(filename_mapping_path, "rb") as handle:
filename_mapping = pickle.load(handle)
else:
filename_mapping = None

return filename_mapping

def save_filename_mapping(imgpath, path_length, filename_mapping):
"Load the dataset's index from the cache if available, else create a new one."

filename_mapping_path = get_filename_mapping_path(imgpath, path_length)

try:
#Pickle filename_mapping.
os.makedirs(os.path.dirname(filename_mapping_path), exist_ok=True)
with open(filename_mapping_path, "wb") as handle:
pickle.dump(filename_mapping, handle)
return True

except:
raise
return False
#return compressed, mapping

def convert_to_image(filename, bytes):
"Convert an image byte array to a numpy array. If the filename ends with .dcm, use pydicom."
if str(filename).endswith(".dcm"):
return pydicom.filereader.dcmread(BytesIO(bytes), force=True).pixel_array
else:
return np.array(Image.open(BytesIO(bytes)))

class Interface(object):
pass

class TarInterface(Interface):
"This class supports extracting files from a tar archive based on a partial path."
@classmethod
def matches(cls, filename):
"Return whether the given path is a tar archive."
return not os.path.isdir(filename) and tarfile.is_tarfile(filename)
def __init__(self, imgpath, path_length, save_to_cache=True, load_from_cache=True):
def __init__(self, imgpath, path_length):
"Store the archive path, and the length of the partial paths within the archive"
self.path_length = path_length
self.imgpath = imgpath
compressed, self.filename_mapping = self.load_dataset(imgpath, save_to_cache, load_from_cache)
self.all_compressed = {multiprocessing.current_process():compressed}

#Load archive and filename mapping
compressed = None
self.filename_mapping = load_filename_mapping(imgpath, path_length)
#If the filename mapping could not be loaded, create it and save it
if self.filename_mapping is None:
compressed, self.filename_mapping = self.index(imgpath)
save_filename_mapping(imgpath, path_length, self.filename_mapping)
#If the compressed file has still not been loaded, load it.
if compressed is None:
compressed = self.get_archive(imgpath)
self.all_compressed = {multiprocessing.current_process().name:compressed}

def get_image(self, imgid):
"Return the image object for the partial path provided."
archive_path = self.filename_mapping[imgid]
@@ -282,13 +311,14 @@ def get_image(self, imgid):
# check and reset number of open files if too many
if len(self.all_compressed.keys()) > 64:
self.all_compressed = {}
self.all_compressed[multiprocessing.current_process().name] = self.get_archive(self.imgpath)
self.all_compressed[multiprocessing.current_process().name] = tarfile.open(self.imgpath)
bytes = self.all_compressed[multiprocessing.current_process().name].extractfile(archive_path).read()
return self.convert_to_image(archive_path, bytes)
return convert_to_image(archive_path, bytes)
def get_archive(self, imgpath):
return tarfile.open(imgpath)
def index(self, imgpath):
"Create a dictionary mapping imgpath -> path within archive"
print("Indexing file paths (one-time). The next load will be faster")
compressed = tarfile.open(imgpath)
tar_infos = compressed.getmembers()
filename_mapping = {}
@@ -309,12 +339,23 @@ class ZipInterface(Interface):
def matches(cls, filename):
"Return whether the given path is a zip archive."
return not os.path.isdir(filename) and zipfile.is_zipfile(filename)
def __init__(self, imgpath, path_length, save_to_cache=True, load_from_cache=True):
def __init__(self, imgpath, path_length):
"Store the archive path, and the length of the partial paths within the archive"
self.path_length = path_length
self.imgpath = imgpath
compressed, self.filename_mapping = self.load_dataset(imgpath, save_to_cache, load_from_cache)

#Load archive and filename mapping
compressed = None
self.filename_mapping = load_filename_mapping(imgpath, path_length)
#If the filename mapping could not be loaded, create it and save it
if self.filename_mapping is None:
compressed, self.filename_mapping = self.index(imgpath)
save_filename_mapping(imgpath, path_length, self.filename_mapping)
#If the compressed file has still not been loaded, load it.
if compressed is None:
compressed = zipfile.ZipFile(imgpath)
self.all_compressed = {multiprocessing.current_process().name:compressed}

def get_image(self, imgid):
"Return the image object for the partial path provided."
archive_path = self.filename_mapping[imgid]
@@ -325,11 +366,12 @@ def get_image(self, imgid):
self.all_compressed = {}
self.all_compressed[multiprocessing.current_process().name] = zipfile.ZipFile(self.imgpath)
bytes = self.all_compressed[multiprocessing.current_process().name].open(archive_path).read()
return self.convert_to_image(archive_path, bytes)
return convert_to_image(archive_path, bytes)
def get_archive(self, imgpath):
return zipfile.ZipFile(imgpath)
def index(self, imgpath):
"Create a dictionary mapping imgpath -> path within archive"
print("Indexing file paths (one-time). The next load will be faster")
compressed = zipfile.ZipFile(imgpath)
zip_infos = compressed.infolist()
filename_mapping = {}
@@ -346,24 +388,33 @@ def close(self):

class FolderInterface(Interface):
"This class supports drawing files from a folder based on a partial path."

@classmethod
def matches(cls, filename):
"Return whether the given path is a zip archive."
return os.path.isdir(filename)
def __init__(self, imgpath, path_length, save_to_cache=True, load_from_cache=True):

def __init__(self, imgpath, path_length):
"Store the archive path, and the length of the partial paths within the archive"
self.path_length = path_length
self.path, self.filename_mapping = self.load_dataset(imgpath, save_to_cache, load_from_cache)

self.filename_mapping = load_filename_mapping(imgpath, path_length)
#If the filename mapping could not be loaded, create it and save it
if self.filename_mapping is None:
_, self.filename_mapping = self.index(imgpath)
save_filename_mapping(imgpath, path_length, self.filename_mapping)

def get_archive(self, imgid):
pass
def get_image(self, imgid):
"Return the image object for the partial path provided."
archive_path = self.filename_mapping[imgid]
with open(archive_path,"rb") as handle:
image = self.convert_to_image(archive_path, handle.read())
image = convert_to_image(archive_path, handle.read())
return image
def index(self, imgpath):
"Create a dictionary mapping imgpath -> path within archive"
print("Indexing file paths (one-time). The next load will be faster")
filename_mapping = {}
for path in Path(imgpath).rglob("*"):
if not os.path.isdir(path):
@@ -384,8 +435,10 @@ def is_archive(filename):
"Return whether the given filename is a tarfile or zipfile."
return any(interface.matches(filename) for interface in archive_interfaces)


class ArchiveFolder(Interface):
"This class supports extracting files from multiple tar or zip archives under the same root directory."

@classmethod
def matches(cls, filename):
for item in Path(filename).rglob("*"):
@@ -394,20 +447,34 @@ def matches(cls, filename):
if is_archive(item):
return True
return False
def __init__(self, imgpath, path_length, save_to_cache=True, load_from_cache=True):

def __init__(self, imgpath, path_length):
"Store the archive path, and the length of the partial paths within the archive"
self.path_length = path_length
self.archives, self.filename_mapping = self.load_dataset(imgpath, save_to_cache, load_from_cache)
self.archives = None
self.filename_mapping = load_filename_mapping(imgpath, path_length)
#If the filename mapping could not be loaded, create it and save it
if self.filename_mapping is None:
self.archives, self.filename_mapping = self.index(imgpath)
save_filename_mapping(imgpath, path_length, self.filename_mapping)
#If the compressed file has still not been loaded, load it.
if self.archives is None:
self.archives = self.get_archive(imgpath)

def get_archive(self, imgpath):
return get_interface(imgpath)

def get_image(self, imgid):
"Return the image object for the partial path provided."
path_to_archive = self.filename_mapping[imgid]
return self.archives[path_to_archive].get_image(imgid)

def index(self, filename):
"""
Create a dictionary mapping imgid -> path to the sub-archive
containing the corresponding file.
Create a dictionary mapping imgid -> containing sub-archive.
The archives are identified by their filenames. The sub-archive
will then be queried itself.

This is different from the index method of ZipInterface and
TarInterface, where the dictionary values are the actual file
paths.
@@ -418,13 +485,15 @@ def index(self, filename):
for path_in_csv, path_in_archive in archive.filename_mapping.items():
filename_mapping[path_in_csv] = path_to_archive
return archives, filename_mapping

def get_archive(self, filename):
archives = {}
for path_to_archive in Path(filename).rglob("*"):
if is_archive(path_to_archive):
archive = create_interface(path_to_archive, self.path_length)
archives[path_to_archive] = archive
return archives

def close(self):
"Recursively close all open archives."
for archive_path, archive in self.archives.items():
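
Note for reviewers: the sketch below is not part of the commit; it only illustrates how the refactored storage interfaces are intended to be used, based on the TarInterface constructor and get_image shown in the diff above. The archive path and partial image path are hypothetical placeholders.

# Illustrative usage sketch (not part of this commit). Assumes the module-level
# TarInterface defined in this diff; "images.tar" and the partial image path
# below are hypothetical placeholders.
from torchxrayvision.datasets import TarInterface

# path_length=2 means files inside the archive are addressed by the last two
# components of their path, matching how the dataset CSV refers to images.
archive = TarInterface("images.tar", path_length=2)

# The first construction indexes the tar once and pickles the filename mapping
# under ~/.torchxrayvision/filename-mapping-cache (keyed by a blake2b hash of
# the absolute path, its mtime, and path_length); later constructions load the
# cached mapping instead of re-indexing.
img = archive.get_image("patient00001/view1_frontal.png")  # numpy array;
# .dcm entries would be decoded with pydicom instead of PIL.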