
Save and load samples from disk in ImageDataset #108

Closed
romainVala opened this issue Apr 1, 2020 · 8 comments

Comments

@romainVala
Contributor

Hi there,

Is your feature request related to a problem? Please describe.
Training with random motion takes too much time: even with 20 workers (num_workers=20) and 180 GB of RAM, I cannot generate samples fast enough to keep the GPU busy during training, and when testing different model parameters it is just insane to require so much computation ...
In any case, it is just too slow.

Describe the solution you'd like
The solution is to compute the transformed samples first and save them to disk, and then add an option to ImageDataset to load the samples back from disk.
You may think there is no time gain, because you first need to save the samples to disk, but you do gain if you want to test different models, or if you have access to a cluster (which allows very fast sample generation).

The solution
It is actually very simple (and efficient) to implement. I could try a PR if you are interested, but since it is a small change I just paste the modified code from ImageDataset below:

def __init__(
        self,
        subjects: Sequence[Subject],
        transform: Optional[Callable] = None,
        check_nans: bool = True,
        save_to_dir: Optional[str] = None,  # directory where transformed samples are written
        load_from_dir: bool = False,        # if True, subjects is a list of saved sample files
        load_image_data: bool = True,
        ):
    self.save_to_dir = save_to_dir
    self.load_from_dir = load_from_dir
    if not load_from_dir:
        self._parse_subjects_list(subjects)
    self.subjects = subjects
    ...

def __getitem__(self, index: int) -> dict:
    if not isinstance(index, int):
        raise TypeError(f'Index "{index}" must be int, not {type(index)}')

    if self.load_from_dir:
        # Load a sample that was previously transformed and saved with torch.save below
        sample = torch.load(self.subjects[index])
    else:
        subject = self.subjects[index]
        sample = self.get_sample_dict_from_subject(subject)

    # Apply transform (this is usually the bottleneck)
    if self._transform is not None:
        sample = self._transform(sample)

    if self.save_to_dir is not None:
        # Cache the transformed sample so it can be reloaded later with load_from_dir
        sample.pop('image_orig', None)  # drop the untransformed copy to keep files small
        fname = '{}/sample{:05d}_sample.pt'.format(self.save_to_dir, index)
        torch.save(sample, fname)

    return sample

Note the good part: you can still apply some transforms even after loading from disk (very convenient for quick transforms).
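
For example, usage could look something like this (just a sketch: the class is called ImagesDataset here, and subjects, slow_transform, cheap_transform and the paths are placeholders):

from pathlib import Path

# Pass 1: apply the expensive transforms once and cache the resulting samples
dataset = ImagesDataset(subjects, transform=slow_transform, save_to_dir='/tmp/samples')
for index in range(len(dataset)):
    dataset[index]  # triggers the transform and writes /tmp/samples/sampleXXXXX_sample.pt

# Pass 2: train from the cached samples, optionally applying cheap extra transforms
sample_files = sorted(Path('/tmp/samples').glob('sample*_sample.pt'))
cached_dataset = ImagesDataset(sample_files, transform=cheap_transform, load_from_dir=True)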

Personally I did not use the save_to_dir argument, because I implemented the saving outside the dataset, to properly handle the index in the cluster case (i.e. multiple instances running with different subsets of the dataset, but then producing the same indices ...),
but if you do it locally it works fine.
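
On the cluster I do something along these lines instead (just an illustration, not my exact script; the names are made up):

import torch

def generate_samples(dataset, out_dir, index_offset):
    # One job handles one subset of the subjects; index_offset makes the saved
    # file names globally unique across jobs.
    for local_index in range(len(dataset)):
        sample = dataset[local_index]  # applies the (slow) random transforms
        global_index = index_offset + local_index
        torch.save(sample, '{}/sample{:05d}_sample.pt'.format(out_dir, global_index))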

I hope it helps

@fepegar
Owner

fepegar commented Apr 1, 2020

Hi @romainVala,

Thanks for reporting this. Isn't this just offline data augmentation? Can't you use ImagesDataset.save_sample or torchio.utils.apply_transform_to_file or the CLI tool torchio-transform to generate all the samples you want and then instantiate the ImagesDataset using the transformed files?

@romainVala
Contributor Author

It is not exactly the same: what you propose is to save the transformed images as NIfTI files, whereas I save the sample dictionary structure (since I need some special keys written by the transform during training).

@fepegar
Owner

fepegar commented Apr 1, 2020

By "special keys" do you mean the random parameters?

@romainVala
Contributor Author

Yes.
Actually I use my own class RandomMotionFromTimeCourse, where I add some specific metrics:
when applying a motion transform it also computes the similarity with the original data, and this is what I try to learn ...

@fepegar
Owner

fepegar commented Apr 1, 2020

I don't love the idea of pickling dictionaries. What about saving those parameters in a text file and creating a dataset that inherits from ImagesDataset as shown in this example: eb0244c

@romainVala
Contributor Author

Not sure I understand what pickling is ... I just do a torch.save ...? (so simple)

For the example you link, I do not see how it solves the problem, which is (if I follow correctly) adding extra information to the images dictionary (from a CSV file for instance?).
In that example you do not create any class?

@fepegar
Owner

fepegar commented Apr 1, 2020

Not sure I understand what pickling is ... I just do a torch.save ...? (so simple)

torch.save uses the Python pickle module to serialize objects and save them to disk. I don't know the theory behind it, but I guess it's a bit like saving binary files to disk, without any specific format, only readable again by Python in a very specific context.
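
For instance (just to illustrate what happens with the sample dictionary; the values are made up):

import torch

sample = {'image': torch.rand(1, 8, 8, 8), 'random_motion': {'degrees': 3.0}}
torch.save(sample, 'sample.pt')     # pickles the whole dictionary to disk
restored = torch.load('sample.pt')  # only readable back from Python/PyTorch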

For the example you link, I do not see how it solves the problem, which is (if I follow correctly) adding extra information to the images dictionary (from a CSV file for instance?).
In that example you do not create any class?

Sorry, I meant the example in the commit message:

class MyDataset(torchio.ImagesDataset):
    def get_image_dict_from_image(self, image):
        image_dict = super().get_image_dict_from_image(image)
        subject_id = image.path.name.split('_')[0]
        image_dict['subject_id'] = subject_id
        return image_dict

You could, for example, do:

class RandomMotionDataset(torchio.ImagesDataset):
    def get_image_dict_from_image(self, image):  # overrides ImagesDataset.get_image_dict_from_image
        image_dict = super().get_image_dict_from_image(image)  # standard image_dict
        motion_parameters = self.get_motion_parameters(image.path)
        image_dict['random_motion'] = motion_parameters
        return image_dict

    def get_motion_parameters(self, path):
        parameters_path = path.parent / path.name.replace('.nii.gz', '.json')
        parameters = read_json(parameters_path)  # defined somewhere else
        return parameters
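
where read_json could be as simple as (just a sketch of the helper mentioned above):

import json

def read_json(parameters_path):
    # Load the motion parameters that were written next to the image as a JSON file
    with open(parameters_path) as f:
        return json.load(f)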

@fepegar
Owner

fepegar commented May 5, 2020

Closing, as there is now a history attribute in Subject from which the random parameters can be retrieved. Feel free to reopen if needed.
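
For example, with a recent version (just a sketch; the exact contents of history depend on the version):

import torchio

subject = torchio.datasets.Colin27()  # any Subject works here
transform = torchio.RandomMotion()
transformed = transform(subject)
print(transformed.history)  # the transforms that were applied, with their random parameters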

@fepegar fepegar closed this as completed May 5, 2020