Setting NUM_SAMPLES when using sampler with Queue #1108

Open
nicoloesch opened this issue Sep 20, 2023 · 3 comments
Labels
enhancement New feature or request

Comments

@nicoloesch
Contributor

nicoloesch commented Sep 20, 2023

🚀 Feature
If a torchio.Sampler is used together with a torchio.Queue, the Queue reads the NUM_SAMPLES attribute of each torchio.Subject in _fill:

from itertools import islice


def _fill(self) -> None:
    assert self.sampler is not None

    if self._incomplete_subject is not None:
        subject = self._incomplete_subject
        iterable = self.sampler(subject)
        patches = list(islice(iterable, self._num_patches_incomplete))
        self.patches_list.extend(patches)
        self._incomplete_subject = None

    while True:
        subject = self._get_next_subject()
        iterable = self.sampler(subject)
        num_samples = self._get_subject_num_samples(subject)  # <- HERE IS THE ATTRIBUTE CALL
        num_free_slots = self.max_length - len(self.patches_list)
        if num_free_slots < num_samples:
            self._incomplete_subject = subject
            self._num_patches_incomplete = num_samples - num_free_slots
        num_samples = min(num_samples, num_free_slots)
        patches = list(islice(iterable, num_samples))
        self.patches_list.extend(patches)
        self._num_sampled_subjects += 1
        list_full = len(self.patches_list) >= self.max_length
        all_sampled = self._num_sampled_subjects >= self.num_subjects
        if list_full or all_sampled:
            break

However, the maximum number of samples/patches per subject usually depends on the augmentations applied, and it is reflected in the number of non-zero entries in the probability map, which the respective sampler turns into a CDF to draw patches in _generate_patches. As a result, the number of samples is only known AFTER the probability_map has been calculated, but the current Queue implementation requests the attribute BEFORE the number of samples is known.
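The ordering problem follows from Python generator semantics: calling a generator function only creates the generator object, and its body (where the probability map is computed) does not run until the first next(). In _fill that first next() happens inside islice, after _get_subject_num_samples has already read the attribute. A minimal sketch, where toy_generate_patches is a hypothetical stand-in for a sampler's _generate_patches and a plain dict stands in for the subject:

```python
from itertools import islice

NUM_SAMPLES = "num_samples"  # hypothetical stand-in for torchio's NUM_SAMPLES key


def toy_generate_patches(subject: dict):
    # This body only runs when the generator is first advanced.
    num_nonzero = sum(1 for p in subject["probability_map"] if p > 0)
    subject[NUM_SAMPLES] = num_nonzero
    while True:
        yield "patch"


subject = {"probability_map": [0, 1, 0, 2, 1]}
iterable = toy_generate_patches(subject)  # like self.sampler(subject): nothing computed yet
seen_before = subject.get(NUM_SAMPLES)    # like _get_subject_num_samples: attribute missing
patches = list(islice(iterable, 2))       # the generator body runs here
seen_after = subject.get(NUM_SAMPLES)

print(seen_before, seen_after)  # None 3
```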
If one rewrote __call__ (Sampler), _generate_patches (Sampler), and _get_subject_num_samples (Queue) as shown in the following, the probability_map could be obtained before the generator is created in _generate_patches, so num_samples would be set before the Queue requests the attribute.

Motivation

Allowing the user to set the number of samples retrieved from each subject, or setting it automatically, makes the Queue more robust and more functional, and reduces the sampling of duplicates (e.g. if the probability_map allows only 5 patches but the user requested 10, each one is sampled about twice on average).
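The duplicate example can be made concrete with a short calculation. Assuming the 5 allowed positions are drawn uniformly with replacement (a non-uniform probability map would give even fewer distinct patches), the expected number of distinct positions after k draws from n positions is n * (1 - (1 - 1/n)^k). The function name below is hypothetical:

```python
def expected_distinct(num_positions: int, num_draws: int) -> float:
    """Expected number of distinct positions hit by uniform draws with replacement."""
    # Probability that one given position is never drawn in num_draws attempts.
    miss_probability = (1 - 1 / num_positions) ** num_draws
    return num_positions * (1 - miss_probability)


distinct = expected_distinct(5, 10)
print(round(distinct, 2))  # 4.46 distinct patches out of 10 requested
print(10 / 5)              # 2.0 draws per allowed position on average
```

So roughly 5.5 of the 10 requested patches would be repeats.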

Pitch

Rewrite __call__ of torchio.Sampler as follows:

def __call__(
    self,
    subject: Subject,
    num_patches: Optional[int] = None,
) -> Generator[Subject, None, None]:
    subject.check_consistent_space()
    if np.any(self.patch_size > subject.spatial_shape):
        message = (
            f'Patch size {tuple(self.patch_size)} cannot be'
            f' larger than image size {tuple(subject.spatial_shape)}'
        )
        raise RuntimeError(message)

    probability_map = self.get_probability_map(subject)
    num_max_patches = int(torch.count_nonzero(probability_map))
    setattr(subject, NUM_SAMPLES, num_max_patches)

    # This is optional
    if num_patches is None:
        num_patches = getattr(subject, NUM_SAMPLES)
    return self._generate_patches(subject, probability_map, num_patches)
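A torchio-free sketch of the eager pattern proposed above: the probability map is computed in __call__ (outside the generator), so the count is stored on the subject as soon as the sampler is called, before any patch is drawn. ToySubject, ToySampler, and the "num_samples" string are hypothetical stand-ins for torchio.Subject, torchio.Sampler, and the NUM_SAMPLES constant, and a "patch" here is just the index of the sampled center:

```python
from typing import Generator, Optional

import numpy as np

NUM_SAMPLES = "num_samples"  # hypothetical stand-in for torchio's NUM_SAMPLES constant


class ToySubject:
    """Hypothetical stand-in for torchio.Subject: only holds a probability map."""

    def __init__(self, probability_map):
        self.probability_map = np.asarray(probability_map, dtype=float)


class ToySampler:
    """Hypothetical sampler following the proposed eager __call__."""

    def __init__(self, seed: int = 0):
        self.rng = np.random.default_rng(seed)

    def get_probability_map(self, subject: ToySubject) -> np.ndarray:
        return subject.probability_map

    def __call__(
        self,
        subject: ToySubject,
        num_patches: Optional[int] = None,
    ) -> Generator[int, None, None]:
        # Eager part: runs as soon as the sampler is called.
        probability_map = self.get_probability_map(subject)
        setattr(subject, NUM_SAMPLES, int(np.count_nonzero(probability_map)))
        if num_patches is None:
            num_patches = getattr(subject, NUM_SAMPLES)
        return self._generate_patches(probability_map, num_patches)

    def _generate_patches(
        self,
        probability_map: np.ndarray,
        num_patches: int,
    ) -> Generator[int, None, None]:
        # Lazy part: runs only when the returned generator is iterated.
        weights = probability_map / probability_map.sum()
        for _ in range(num_patches):
            yield int(self.rng.choice(len(weights), p=weights))


subject = ToySubject([0, 1, 0, 2, 1])
patches = ToySampler()(subject)
print(getattr(subject, NUM_SAMPLES))  # 3, already set before any patch is drawn
centers = list(patches)
print(len(centers))  # 3
```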

Rewrite _generate_patches of the samplers as follows (my example uses the sampler in weighted.py, but the same change is needed in every other sampler that overrides this method):

def _generate_patches(
    self,
    subject: Subject,
    probability_map: torch.Tensor,
    num_patches: Optional[int] = None,
) -> Generator[Subject, None, None]:
    # The only change: the probability map is passed in instead of being computed here
    probability_map_array = self.process_probability_map(
        probability_map,
        subject,
    )
    cdf = self.get_cumulative_distribution_function(probability_map_array)

    patches_left = num_patches if num_patches is not None else True
    while patches_left:
        yield self.extract_patch(subject, probability_map_array, cdf)
        if num_patches is not None:
            patches_left -= 1
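The patches_left counter above mixes a bool and an int to express "endless when num_patches is None, otherwise exactly num_patches". An equivalent formulation picks the iterator up front; repeat_draws and its draw callable are hypothetical names for illustration, not torchio code:

```python
from itertools import count, islice
from typing import Callable, Generator, Iterable, Optional


def repeat_draws(
    draw: Callable[[], int],
    num_patches: Optional[int] = None,
) -> Generator[int, None, None]:
    """Yield draw() forever if num_patches is None, else exactly num_patches times."""
    steps: Iterable[int] = count() if num_patches is None else range(num_patches)
    for _ in steps:
        yield draw()


finite = list(repeat_draws(lambda: 7, num_patches=3))
endless = list(islice(repeat_draws(lambda: 7), 5))  # islice bounds the endless case
print(finite, endless)  # [7, 7, 7] [7, 7, 7, 7, 7]
```

As with the original loop, num_patches=0 yields nothing, and the endless case relies on the caller (here islice, as in Queue._fill) to stop.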

Finally, adapt the method _get_subject_num_samples of torchio.Queue as follows:

def _get_subject_num_samples(self, subject):
    num_samples = getattr(
        subject,
        NUM_SAMPLES,
        self.samples_per_volume,
    )
    # <- Prevents sampling more patches than there are in a subject
    return min(num_samples, self.samples_per_volume)
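The effect of the proposed Queue-side cap can be checked in isolation. In this sketch, get_subject_num_samples is a hypothetical free-function version of the method above, and types.SimpleNamespace objects stand in for subjects:

```python
from types import SimpleNamespace

NUM_SAMPLES = "num_samples"  # hypothetical stand-in for torchio's NUM_SAMPLES constant


def get_subject_num_samples(subject, samples_per_volume: int) -> int:
    # Fall back to samples_per_volume when the sampler did not set the attribute.
    num_samples = getattr(subject, NUM_SAMPLES, samples_per_volume)
    # Never request more patches than the subject can provide.
    return min(num_samples, samples_per_volume)


few = SimpleNamespace(num_samples=5)    # e.g. only 5 valid slices
many = SimpleNamespace(num_samples=50)  # more candidates than requested
unset = SimpleNamespace()               # sampler never set the attribute
counts = [get_subject_num_samples(s, 10) for s in (few, many, unset)]
print(counts)  # [5, 10, 10]
```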

Alternatives

The section marked "# This is optional" in __call__ should be kept to prevent an endless loop when num_patches=None is passed to _generate_patches. Alternatively, num_patches could be required to be an integer in all cases (I can't think of a scenario that needs endless sampling), i.e. remove the Optional and the is not None checks.
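A sketch of the alternative, where num_patches is mandatory and the generator is therefore always finite. generate_patches, _draw_n, and draw are hypothetical names; the validation sits in a plain function that returns the generator, because a check placed inside a generator body would only run on the first next(), not when the function is called:

```python
from typing import Callable, Generator, TypeVar

T = TypeVar("T")


def generate_patches(draw: Callable[[], T], num_patches: int) -> Generator[T, None, None]:
    """Alternative sketch: num_patches is required, so sampling always terminates."""
    # Validate eagerly, at call time (bool is excluded because it subclasses int).
    if isinstance(num_patches, bool) or not isinstance(num_patches, int) or num_patches < 0:
        raise ValueError(f"num_patches must be a non-negative int, got {num_patches!r}")
    return _draw_n(draw, num_patches)


def _draw_n(draw: Callable[[], T], num_patches: int) -> Generator[T, None, None]:
    for _ in range(num_patches):
        yield draw()


print(list(generate_patches(lambda: "patch", 2)))  # ['patch', 'patch']
```

Under this variant, __call__ would always hand _generate_patches an int (for example the stored NUM_SAMPLES count).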

Remarks
The layout of the Queue might change depending on the outcome of #1096. Furthermore, there needs to be a way to handle a subject with zero available patches (for whatever reason). Currently, I ensure in get_probability_map that this does not happen, but the whole approach might break down if a subject has zero patches (I have not tested this yet).

@nicoloesch nicoloesch added the enhancement New feature or request label Sep 20, 2023
@romainVala
Contributor

Hello

I am curious about the cases where this happens.
I work on whole-brain MRI segmentation, where the probability maps are often very large (i.e. the number of possible patch centers >> NUM_SAMPLES).

But let's imagine I want to focus on a small region of only five (connected) voxels. If I choose a large enough patch size, it seems to me that even the five possible distinct patches (each centered on one of the five voxels of my CDF) will already be very similar. So taking only five patches will not solve the issue of having almost identical patches... no?

Maybe your proposal makes sense if the five voxels are spatially distinct... but single-voxel regions look weird to me...
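The similarity argument can be quantified with a rough calculation, assuming cubic patches of side P voxels whose centers differ by d voxels along a single axis: the two patches share a fraction (P - d) / P of their voxels. The side of 64 and offsets of 1 and 4 below are illustrative values, not taken from the thread (4 is the largest offset within a straight run of five connected voxels):

```python
def shared_fraction(patch_side: int, offset: int) -> float:
    """Fraction of voxels shared by two cubic patches shifted by `offset` along one axis."""
    if offset >= patch_side:
        return 0.0
    # Overlap is (P - d) x P x P voxels out of P x P x P.
    return (patch_side - offset) / patch_side


print(shared_fraction(64, 1))  # 0.984375
print(shared_fraction(64, 4))  # 0.9375
```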

@nicoloesch
Contributor Author

Hi,

This came up because I am using torchio for 2D samples, with each slice acting as a patch in the classical sense. As a result, I sometimes have fewer available slices than the patches_per_subj value set upfront. I am aware that this case is uncommon, but it raises the question: why have the mechanism in place if it is not checked anyway? Why not remove the NUM_SAMPLES attribute altogether if it is never set, or set at the wrong time?
For my application it makes sense to set the attribute, but I understand this is not always the case. However, if the check and the mechanism are already there, why not use them as intended? I am happy to use my own classes that override this functionality, but the question remains why the mechanism exists but is not used.

@romainVala
Contributor

I was just questioning the use case, but now I better understand yours, so it makes sense.
