Source code for homura.vision.data.segmentation

import pathlib

from torchvision import datasets as VD


[docs]class ExtendedVOCSegmentation(object): def __new__(cls, root, train=True, transform=None, download=False): root = pathlib.Path(root).parent if train: return VD.SBDataset(root / "sbd", image_set='train_noval', mode='segmentation', transforms=transform, download=download) else: return VD.VOCSegmentation(root / "voc", image_set="val", transforms=transform, download=download) def __len__(self): ...
# from torchvision's reference
[docs]def cat_list(images, fill_value=0): max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) batch_shape = (len(images),) + max_size batched_imgs = images[0].new(*batch_shape).fill_(fill_value) for img, pad_img in zip(images, batched_imgs): pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img) return batched_imgs
[docs]def seg_collate_fn(batch): images, targets = list(zip(*batch)) batched_imgs = cat_list(images, fill_value=0) batched_targets = cat_list(targets, fill_value=255) return batched_imgs, batched_targets