-
-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify video_domain_adapter #292
base: main
Are you sure you want to change the base?
Changes from 27 commits
7ccd345
d955f73
1cecdf2
f9d0577
046ef98
77f1b0f
8a8581b
23b0e8e
f993f8d
60951d4
76f3e72
feaf72a
f5bc2b7
63c5be9
f89d8fc
b845a88
ef74b72
b43802c
ba6f5c5
bdf9cbb
3ea4678
1540051
de0e6cd
cf1638b
a2b3ce8
4470413
37aeaac
a95a185
ab23896
40861fc
dc4b990
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,32 +48,36 @@ def get_config(cfg): | |
"target": cfg.DATASET.TARGET, | ||
"size_type": cfg.DATASET.SIZE_TYPE, | ||
"weight_type": cfg.DATASET.WEIGHT_TYPE, | ||
"class_type": cfg.DATASET.CLASS_TYPE, | ||
}, | ||
} | ||
return config_params | ||
|
||
|
||
# Based on https://github.com/criteo-research/pytorch-ada/blob/master/adalib/ada/utils/experimentation.py | ||
def get_model(cfg, dataset, num_classes): | ||
def get_model(cfg, dataset, dict_num_classes): | ||
""" | ||
Builds and returns a model and associated hyper parameters according to the config object passed. | ||
|
||
Args: | ||
cfg: A YACS config object. | ||
dataset: A multi domain dataset consisting of source and target datasets. | ||
num_classes: The class number of specific dataset. | ||
dict_num_classes (dict): The dictionary of class number for specific dataset. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it better to implement this as a class (e.g., https://github.com/pykale/pykale/blob/main/kale/pipeline/domain_adapter.py#L81). We can discuss if you need. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A function returning boolean variables is used to control image_modality and class_type. A class may be a better choice, but I have no idea about this. We can talk. |
||
""" | ||
|
||
# setup feature extractor | ||
feature_network, class_feature_dim, domain_feature_dim = get_video_feat_extractor( | ||
cfg.MODEL.METHOD.upper(), cfg.DATASET.IMAGE_MODALITY, cfg.MODEL.ATTENTION, num_classes | ||
cfg.MODEL.METHOD.upper(), cfg.DATASET.IMAGE_MODALITY, cfg.MODEL.ATTENTION, dict_num_classes | ||
) | ||
# setup classifier | ||
classifier_network = ClassNetVideo(input_size=class_feature_dim, n_class=num_classes) | ||
classifier_network = ClassNetVideo(input_size=class_feature_dim, dict_n_class=dict_num_classes) | ||
|
||
config_params = get_config(cfg) | ||
train_params = config_params["train_params"] | ||
train_params_local = deepcopy(train_params) | ||
data_params = config_params["data_params"] | ||
data_params_local = deepcopy(data_params) | ||
class_type = data_params_local["class_type"] | ||
method_params = {} | ||
|
||
method = domain_adapter.Method(cfg.DAN.METHOD) | ||
|
@@ -85,6 +89,7 @@ def get_model(cfg, dataset, num_classes): | |
image_modality=cfg.DATASET.IMAGE_MODALITY, | ||
feature_extractor=feature_network, | ||
task_classifier=classifier_network, | ||
class_type=class_type, | ||
**method_params, | ||
**train_params_local, | ||
) | ||
|
@@ -95,7 +100,7 @@ def get_model(cfg, dataset, num_classes): | |
if cfg.DAN.USERANDOM: | ||
critic_input_size = cfg.DAN.RANDOM_DIM | ||
else: | ||
critic_input_size = domain_feature_dim * num_classes | ||
critic_input_size = domain_feature_dim * dict_num_classes["verb"] | ||
critic_network = DomainNetVideo(input_size=critic_input_size) | ||
|
||
if cfg.DAN.METHOD == "CDAN": | ||
|
@@ -109,6 +114,7 @@ def get_model(cfg, dataset, num_classes): | |
feature_extractor=feature_network, | ||
task_classifier=classifier_network, | ||
critic=critic_network, | ||
class_type=class_type, | ||
**method_params, | ||
**train_params_local, | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same problem as in PR #291.