Commit 358e86e

ef35a38ad29c@2024-03-28_09-57-57: add ape-ti with vit-ti backbone, support fsdp and vit-e

shenyunhang committed Mar 28, 2024
1 parent 2e8fc59 · commit 358e86e
Showing 57 changed files with 4,266 additions and 74 deletions.
1 change: 1 addition & 0 deletions ape/checkpoint/__init__.py
@@ -2,5 +2,6 @@


from .detection_checkpoint import DetectionCheckpointer
from .detection_checkpoint import FSDPDetectionCheckpointer

__all__ = ["DetectionCheckpointer"]
48 changes: 47 additions & 1 deletion ape/checkpoint/detection_checkpoint.py
@@ -2,11 +2,13 @@
import logging
import os
import pickle
from collections import defaultdict
from typing import IO, Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, cast

import numpy as np
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp import FullStateDictConfig

from detectron2.checkpoint import DetectionCheckpointer as DetectionCheckpointer_d2

@@ -43,3 +45,47 @@ def _convert_ndarray_to_tensor(self, state_dict: Dict[str, Any]) -> None:
raise ValueError("Unsupported type found in checkpoint! {}: {}".format(k, type(v)))
if not isinstance(v, torch.Tensor):
state_dict[k] = torch.from_numpy(v)


class FSDPDetectionCheckpointer(DetectionCheckpointer):

# def __init__(self, skip_key="", **kwargs):
# super().__init__(**kwargs)
# self.skip_key = skip_key

def save(self, name: str, **kwargs: Any) -> None:
"""
Dump model and checkpointables to a file.
Args:
name (str): name of the file.
kwargs (dict): extra arbitrary data to save.
"""
# if not self.save_dir or not self.save_to_disk:
# return

data = {}

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(
self.model, StateDictType.FULL_STATE_DICT, save_policy
):
data["model"] = self.model.state_dict()

if not self.save_dir or not self.save_to_disk:
return

# data["model"] = self.model.state_dict()
for key, obj in self.checkpointables.items():
data[key] = obj.state_dict()
data.update(kwargs)

basename = "{}.pth".format(name)
save_file = os.path.join(self.save_dir, basename)
assert os.path.basename(save_file) == basename, basename
self.logger.info("Saving checkpoint to {}".format(save_file))
with self.path_manager.open(save_file, "wb") as f:
# pyre-fixme[22]: The cast is redundant.
torch.save(data, cast(IO[bytes], f))
self.tag_last_checkpoint(basename)
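
The key difference from the stock DetectionCheckpointer is that save() assembles a full, CPU-offloaded state dict under FSDP.state_dict_type() before checking whether this rank writes to disk, so every rank has to enter the call. A minimal usage sketch follows; the model, optimizer, and output-directory names are illustrative and not part of this commit.

# Hypothetical usage sketch; only FSDPDetectionCheckpointer comes from this commit.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from ape.checkpoint import FSDPDetectionCheckpointer

# Assume torch.distributed is already initialized and the model is FSDP-wrapped.
model = FSDP(nn.Linear(16, 16).cuda())
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

checkpointer = FSDPDetectionCheckpointer(
    model,
    save_dir="output/ape_ti",     # illustrative output directory
    optimizer=optimizer,          # extra checkpointables are saved next to the model
)

# All ranks must call save(): the FULL_STATE_DICT context gathers the sharded
# parameters onto CPU, and rank 0 alone writes the checkpoint file.
checkpointer.save("model_0010000", iteration=10_000)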

1 change: 1 addition & 0 deletions ape/data/__init__.py
@@ -11,6 +11,7 @@
build_detection_train_loader_multi_dataset_copypaste,
get_detection_dataset_dicts_multi_dataset_copypaste,
)
from .build import build_detection_test_loader
from .dataset_mapper import DatasetMapper_ape
from .dataset_mapper_copypaste import DatasetMapper_copypaste
from .dataset_mapper_detr_instance import DatasetMapper_detr_instance
135 changes: 135 additions & 0 deletions ape/data/build.py
@@ -0,0 +1,135 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import logging
import numpy as np
import operator
import pickle
from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.utils.data as torchdata
from tabulate import tabulate
from termcolor import colored

from detectron2.config import configurable
from detectron2.structures import BoxMode
from detectron2.utils.comm import get_world_size
from detectron2.utils.env import seed_all_rng
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import _log_api_usage, log_first_n

from detectron2.data.build import get_detection_dataset_dicts, trivial_batch_collator

from detectron2.data.common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.detection_utils import check_metadata_consistency
from detectron2.data.samplers import (
RandomSubsetTrainingSampler,
RepeatFactorTrainingSampler,
TrainingSampler,
)

from .samplers import (
InferenceSampler,
)

"""
This file contains the default logic to build a dataloader for training or testing.
"""

__all__ = [
"build_detection_test_loader",
]


def _test_loader_from_config(cfg, dataset_name, mapper=None):
"""
Uses the given `dataset_name` argument (instead of the names in cfg), because the
standard practice is to evaluate each test set individually (not combining them).
"""
if isinstance(dataset_name, str):
dataset_name = [dataset_name]

dataset = get_detection_dataset_dicts(
dataset_name,
filter_empty=False,
proposal_files=[
cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name
]
if cfg.MODEL.LOAD_PROPOSALS
else None,
)
if mapper is None:
mapper = DatasetMapper(cfg, False)
return {
"dataset": dataset,
"mapper": mapper,
"num_workers": cfg.DATALOADER.NUM_WORKERS,
"sampler": InferenceSampler(len(dataset))
if not isinstance(dataset, torchdata.IterableDataset)
else None,
}


@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(
dataset: Union[List[Any], torchdata.Dataset],
*,
mapper: Callable[[Dict[str, Any]], Any],
sampler: Optional[torchdata.Sampler] = None,
batch_size: int = 1,
num_workers: int = 0,
collate_fn: Optional[Callable[[List[Any]], Any]] = None,
) -> torchdata.DataLoader:
"""
Similar to `build_detection_train_loader`, with default batch size = 1,
and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
to produce the exact set of all samples.
Args:
dataset: a list of dataset dicts,
or a pytorch dataset (either map-style or iterable). They can be obtained
by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
mapper: a callable which takes a sample (dict) from dataset
and returns the format to be consumed by the model.
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
sampler: a sampler that produces
indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
which splits the dataset across all workers. Sampler must be None
if `dataset` is iterable.
batch_size: the batch size of the data loader to be created.
Default to 1 image per worker since this is the standard when reporting
inference time in papers.
num_workers: number of parallel data loading workers
collate_fn: same as the argument of `torch.utils.data.DataLoader`.
Defaults to do no collation and return a list of data.
Returns:
DataLoader: a torch DataLoader, that loads the given detection
dataset, with test-time transformation and batching.
Examples:
::
data_loader = build_detection_test_loader(
DatasetRegistry.get("my_test"),
mapper=DatasetMapper(...))
# or, instantiate with a CfgNode:
data_loader = build_detection_test_loader(cfg, "my_test")
"""
if isinstance(dataset, list):
dataset = DatasetFromList(dataset, copy=False)
if mapper is not None:
dataset = MapDataset(dataset, mapper)
if isinstance(dataset, torchdata.IterableDataset):
assert sampler is None, "sampler must be None if dataset is IterableDataset"
else:
if sampler is None:
sampler = InferenceSampler(len(dataset))
return torchdata.DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
drop_last=False,
num_workers=num_workers,
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
)
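
This loader mirrors detectron2's build_detection_test_loader but uses the InferenceSampler added later in this commit. A brief sketch of the two ways to call it; cfg, model, dataset_dicts, and the dataset name are illustrative assumptions, not taken from this commit.

# Hypothetical evaluation sketch.
import torch
from detectron2.data.dataset_mapper import DatasetMapper

from ape.data import build_detection_test_loader

# Config-driven form: @configurable routes (cfg, dataset_name) through
# _test_loader_from_config, which builds the dataset dicts and an InferenceSampler.
data_loader = build_detection_test_loader(cfg, "coco_2017_val")

# Explicit form: pass dataset dicts and a mapper directly.
# data_loader = build_detection_test_loader(
#     dataset=dataset_dicts, mapper=DatasetMapper(cfg, is_train=False)
# )

model.eval()
with torch.no_grad():
    for batch in data_loader:  # trivial_batch_collator yields a list of per-image dicts
        outputs = model(batch)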
3 changes: 2 additions & 1 deletion ape/data/samplers/__init__.py
@@ -1,5 +1,6 @@
- from .distributed_sampler_multi_dataset import MultiDatasetTrainingSampler
+ from .distributed_sampler_multi_dataset import MultiDatasetTrainingSampler, InferenceSampler

__all__ = [
"MultiDatasetTrainingSampler",
"InferenceSampler",
]
39 changes: 39 additions & 0 deletions ape/data/samplers/distributed_sampler_multi_dataset.py
@@ -135,3 +135,42 @@ def _infinite_indices(self):
yield from indices[randperm].tolist()
else:
yield from indices.tolist()


class InferenceSampler(Sampler):
"""
Produce indices for inference across all workers.
Inference needs to run on the __exact__ set of samples,
therefore when the total number of samples is not divisible by the number of workers,
this sampler produces different number of samples on different workers.
"""

def __init__(self, size: int):
"""
Args:
size (int): the total number of data of the underlying dataset to sample from
"""
self._size = size
assert size > 0
self._rank = comm.get_rank()
self._world_size = comm.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

@staticmethod
def _get_local_indices(total_size, world_size, rank):
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[: rank + 1]), total_size)
if end - begin < max(shard_sizes):
assert begin > 0
begin = begin - 1
return range(begin, end)

def __iter__(self):
yield from self._local_indices

def __len__(self):
return len(self._local_indices)
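
Compared with detectron2's stock InferenceSampler, which leaves short shards short, this variant moves the start of an under-sized shard back by one index so every rank iterates the same number of samples, at the cost of re-evaluating one boundary sample per short shard (presumably to keep collectives in lockstep under FSDP). A standalone re-implementation of the sharding arithmetic, purely for illustration:

# Illustrative re-implementation of InferenceSampler._get_local_indices.
def get_local_indices(total_size: int, world_size: int, rank: int) -> range:
    shard_size = total_size // world_size
    left = total_size % world_size
    shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
    begin = sum(shard_sizes[:rank])
    end = min(sum(shard_sizes[: rank + 1]), total_size)
    if end - begin < max(shard_sizes):
        # Extend a short shard backwards by one index; the boundary sample
        # is shared with the previous rank, keeping all shard lengths equal.
        assert begin > 0
        begin = begin - 1
    return range(begin, end)

if __name__ == "__main__":
    # 10 samples across 4 ranks -> nominal shards [3, 3, 2, 2]
    for rank in range(4):
        print(rank, list(get_local_indices(10, 4, rank)))
    # rank 0: [0, 1, 2]
    # rank 1: [3, 4, 5]
    # rank 2: [5, 6, 7]  (index 5 repeated from rank 1)
    # rank 3: [7, 8, 9]  (index 7 repeated from rank 2)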