
Commit

Merge branch 'add_feature_vector_dataloader' into simplify_video_domain_adapter
xianyuanliu committed Feb 7, 2022
2 parents 40861fc + ab23896 commit dc4b990
Showing 5 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion examples/action_dann_lightn/config.py
@@ -16,7 +16,7 @@
 # Dataset
 # -----------------------------------------------------------------------------
 _C.DATASET = CN()
-_C.DATASET.ROOT = "J:/Datasets/EgoAction/" # "/shared/tale2/Shared"
+_C.DATASET.ROOT = "F:/Datasets/EgoAction/" # "/shared/tale2/Shared"
 _C.DATASET.SOURCE = "EPIC" # dataset options=["EPIC", "GTEA", "ADL", "KITCHEN"]
 _C.DATASET.SRC_TRAINLIST = "epic_D1_train.pkl"
 _C.DATASET.SRC_TESTLIST = "epic_D1_test.pkl"
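For reference, this default can also be overridden at run time instead of editing config.py. A minimal sketch, assuming the example exposes the usual get_cfg_defaults() helper returning _C.clone() (not shown in this commit); the override value is the shared path from the inline comment.

from config import get_cfg_defaults  # assumed helper returning _C.clone()

cfg = get_cfg_defaults()
# Point the dataset root at the shared location instead of a local drive.
cfg.merge_from_list(["DATASET.ROOT", "/shared/tale2/Shared"])
cfg.freeze()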
4 changes: 2 additions & 2 deletions examples/action_dann_lightn/main.py
@@ -50,7 +50,7 @@ def main():
 
     # ---- setup dataset ----
     seed = cfg.SOLVER.SEED
-    source, target, dict_num_classes = VideoDataset.get_source_target(
+    source, target, num_classes = VideoDataset.get_source_target(
        VideoDataset(cfg.DATASET.SOURCE.upper()), VideoDataset(cfg.DATASET.TARGET.upper()), seed, cfg
    )
    dataset = VideoMultiDomainDatasets(
@@ -69,7 +69,7 @@
     set_seed(seed)  # seed_everything in pytorch_lightning did not set torch.backends.cudnn
     print(f"==> Building model for seed {seed} ......")
     # ---- setup model and logger ----
-    model, train_params = get_model(cfg, dataset, dict_num_classes)
+    model, train_params = get_model(cfg, dataset, num_classes)
     tb_logger = pl_loggers.TensorBoardLogger(cfg.OUTPUT.TB_DIR, name="seed{}".format(seed))
     checkpoint_callback = ModelCheckpoint(
         # dirpath=full_checkpoint_dir,
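To illustrate the rename, a minimal sketch of the updated fragment; it assumes num_classes is a dict keyed by label type (only the "verb" key is consumed downstream), and the count shown is a hypothetical placeholder.

# Fragment under the new name; cfg, seed, dataset and get_model come from main()'s scope.
source, target, num_classes = VideoDataset.get_source_target(
    VideoDataset(cfg.DATASET.SOURCE.upper()), VideoDataset(cfg.DATASET.TARGET.upper()), seed, cfg
)
print(num_classes)  # e.g. {"verb": 8} -- illustrative value, not from this commit
model, train_params = get_model(cfg, dataset, num_classes)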
10 changes: 5 additions & 5 deletions examples/action_dann_lightn/model.py
@@ -55,22 +55,22 @@ def get_config(cfg):
 
 
 # Based on https://github.com/criteo-research/pytorch-ada/blob/master/adalib/ada/utils/experimentation.py
-def get_model(cfg, dataset, dict_num_classes):
+def get_model(cfg, dataset, num_classes):
     """
     Builds and returns a model and associated hyper parameters according to the config object passed.
     Args:
         cfg: A YACS config object.
         dataset: A multi domain dataset consisting of source and target datasets.
-        dict_num_classes (dict): The dictionary of class number for specific dataset.
+        num_classes (dict): The dictionary of class number for specific dataset.
     """
 
     # setup feature extractor
     feature_network, class_feature_dim, domain_feature_dim = get_video_feat_extractor(
-        cfg.MODEL.METHOD.upper(), cfg.DATASET.IMAGE_MODALITY, cfg.MODEL.ATTENTION, dict_num_classes
+        cfg.MODEL.METHOD.upper(), cfg.DATASET.IMAGE_MODALITY, cfg.MODEL.ATTENTION, num_classes
     )
     # setup classifier
-    classifier_network = ClassNetVideo(input_size=class_feature_dim, dict_n_class=dict_num_classes)
+    classifier_network = ClassNetVideo(input_size=class_feature_dim, dict_n_class=num_classes)
 
     config_params = get_config(cfg)
     train_params = config_params["train_params"]
@@ -100,7 +100,7 @@ def get_model(cfg, dataset, dict_num_classes):
     if cfg.DAN.USERANDOM:
         critic_input_size = cfg.DAN.RANDOM_DIM
     else:
-        critic_input_size = domain_feature_dim * dict_num_classes["verb"]
+        critic_input_size = domain_feature_dim * num_classes["verb"]
     critic_network = DomainNetVideo(input_size=critic_input_size)
 
     if cfg.DAN.METHOD == "CDAN":
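As a self-contained check of the CDAN critic sizing above, a sketch with illustrative values; neither number comes from this commit.

# Both values are assumptions for illustration only.
domain_feature_dim = 1024   # e.g. a typical video feature size
num_classes = {"verb": 8}   # hypothetical verb-class count
critic_input_size = domain_feature_dim * num_classes["verb"]
assert critic_input_size == 8192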
6 changes: 3 additions & 3 deletions kale/embed/video_feature_extractor.py
@@ -15,7 +15,7 @@
 from kale.loaddata.video_access import get_image_modality
 
 
-def get_video_feat_extractor(model_name, image_modality, attention, dict_num_classes):
+def get_video_feat_extractor(model_name, image_modality, attention, num_classes):
     """
     Get the feature extractor w/o the pre-trained model and SELayers. The pre-trained models are saved in the path
     ``$XDG_CACHE_HOME/torch/hub/checkpoints/``. For Linux, default path is ``~/.cache/torch/hub/checkpoints/``.
@@ -26,7 +26,7 @@ def get_video_feat_extractor(model_name, image_modality, attention, dict_num_classes):
         model_name (string): The name of the feature extractor. (Choices=["I3D", "R3D_18", "R2PLUS1D_18", "MC3_18"])
         image_modality (string): Image type. (Choices=["rgb", "flow", "joint"])
         attention (string): The attention type. (Choices=["SELayerC", "SELayerT", "SELayerCoC", "SELayerMC", "SELayerCT", "SELayerTC", "SELayerMAC"])
-        dict_num_classes (dict): The class number of specific dataset. (Default: No use)
+        num_classes (dict): The class number of specific dataset.
     Returns:
         feature_network (dictionary): The network to extract features.
@@ -37,7 +37,7 @@ def get_video_feat_extractor(model_name, image_modality, attention, dict_num_classes):
 
     rgb, flow = get_image_modality(image_modality)
     # only use verb class when input is image.
-    num_classes = dict_num_classes["verb"]
+    num_classes = num_classes["verb"]
 
     attention_list = ["SELayerC", "SELayerT", "SELayerCoC", "SELayerMC", "SELayerCT", "SELayerTC", "SELayerMAC"]
     model_list = ["I3D", "R3D_18", "MC3_18", "R2PLUS1D_18"]
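A minimal usage sketch for the renamed signature; the import path comes from this commit's file path, while the argument values are illustrative picks from the documented choices.

from kale.embed.video_feature_extractor import get_video_feat_extractor

# "I3D", "rgb" and "SELayerC" are documented choices; the verb count is a placeholder.
feature_network, class_feature_dim, domain_feature_dim = get_video_feat_extractor(
    "I3D", "rgb", "SELayerC", {"verb": 8}
)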
6 changes: 3 additions & 3 deletions tests/pipeline/test_video_domain_adapter.py
@@ -92,7 +92,7 @@ def test_video_domain_adapter(source_cfg, target_cfg, image_modality, da_method,
     )
 
     # build dataset
-    source, target, dict_num_classes = VideoDataset.get_source_target(
+    source, target, num_classes = VideoDataset.get_source_target(
         VideoDataset(cfg.DATASET.SOURCE.upper()), VideoDataset(cfg.DATASET.TARGET.upper()), seed, cfg
     )
 
@@ -119,7 +119,7 @@ def test_video_domain_adapter(source_cfg, target_cfg, image_modality, da_method,
     feature_network = {"rgb": VideoBoringModel(3), "flow": VideoBoringModel(2)}
 
     # setup classifier
-    classifier_network = ClassNetVideo(input_size=class_feature_dim, dict_n_class=dict_num_classes)
+    classifier_network = ClassNetVideo(input_size=class_feature_dim, dict_n_class=num_classes)
     train_params = testing_training_cfg["train_params"]
     method_params = {}
     method = domain_adapter.Method(da_method)
@@ -144,7 +144,7 @@ def test_video_domain_adapter(source_cfg, target_cfg, image_modality, da_method,
     if cfg.DAN.USERANDOM:
         critic_input_size = 1024
     else:
-        critic_input_size = domain_feature_dim * dict_num_classes["verb"]
+        critic_input_size = domain_feature_dim * num_classes["verb"]
     critic_network = DomainNetVideo(input_size=critic_input_size)
 
     if da_method == "CDAN":
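For completeness, a hedged sketch of the classifier construction the test exercises; the import path is assumed from the PyKale layout, and both argument values are placeholders.

from kale.predict.class_domain_nets import ClassNetVideo  # path assumed, not shown in this commit

classifier_network = ClassNetVideo(input_size=512, dict_n_class={"verb": 8})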
