Simplify video_domain_adapter #292

Open
wants to merge 31 commits into base: main
Changes from 1 commit
Commits (31)
7ccd345 update .gitignore (xianyuanliu, Jan 20, 2022)
d955f73 update .gitignore (xianyuanliu, Jan 20, 2022)
1cecdf2 change root dir (xianyuanliu, Jan 22, 2022)
f9d0577 add EPIC100DatasetAccess (xianyuanliu, Jan 22, 2022)
046ef98 change transform_kind to transform (xianyuanliu, Jan 22, 2022)
77f1b0f add NUM_SEGMENTS (xianyuanliu, Jan 22, 2022)
8a8581b add INPUT_TYPE (xianyuanliu, Jan 22, 2022)
23b0e8e add functions in VideoDatasetAccess for feature vector input (xianyuanliu, Jan 22, 2022)
f993f8d add get_class_type (xianyuanliu, Jan 22, 2022)
60951d4 add CLASS_TYPE (xianyuanliu, Jan 22, 2022)
76f3e72 change num_classes to dict_num_classes (xianyuanliu, Jan 22, 2022)
feaf72a update ClassNetVideo for dual-class task (xianyuanliu, Jan 22, 2022)
f5bc2b7 update test (xianyuanliu, Jan 22, 2022)
63c5be9 Merge branch 'main' into add_feature_vector_dataloader (xianyuanliu, Jan 22, 2022)
f89d8fc change output folder to tb_logs (xianyuanliu, Jan 22, 2022)
b845a88 add get_class_type test (xianyuanliu, Jan 22, 2022)
ef74b72 update test_video_access (xianyuanliu, Jan 22, 2022)
b43802c update config (xianyuanliu, Jan 22, 2022)
ba6f5c5 test bug fixes (xianyuanliu, Jan 23, 2022)
bdf9cbb add VideoFeatureRecord in Videos.py & improve doc (xianyuanliu, Jan 23, 2022)
3ea4678 add epic100 test & bug fixes (xianyuanliu, Jan 23, 2022)
1540051 test bug fixes (xianyuanliu, Jan 23, 2022)
de0e6cd test bug fixes (xianyuanliu, Jan 23, 2022)
cf1638b add BaseAdaptTrainerVideo (xianyuanliu, Jan 23, 2022)
a2b3ce8 bug fixes (xianyuanliu, Jan 23, 2022)
4470413 add CLASS_TYPE (xianyuanliu, Jan 23, 2022)
37aeaac add conditional function for class type (xianyuanliu, Jan 23, 2022)
a95a185 rename to num_classes (xianyuanliu, Feb 7, 2022)
ab23896 change root dir (xianyuanliu, Feb 7, 2022)
40861fc Update doc (xianyuanliu, Feb 7, 2022)
dc4b990 Merge branch 'add_feature_vector_dataloader' into simplify_video_doma… (xianyuanliu, Feb 7, 2022)
Commit cf1638b2ecb8a9f2b33f3066d31fa4b492fb29c0: add BaseAdaptTrainerVideo
xianyuanliu committed Jan 23, 2022
223 changes: 155 additions & 68 deletions in kale/pipeline/video_domain_adapter.py
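
This commit folds the Lightning plumbing shared by the four video trainers (dataloaders, training_step, validation/test epoch-end logging, metric naming) into a new BaseAdaptTrainerVideo class; each trainer then inherits from it alongside its method-specific parent. A minimal sketch of the cooperative-inheritance pattern, using toy stand-in classes rather than the real PyKale API:

class BaseAdaptTrainer:
    def compute_loss(self, batch, split_name="train"):
        raise NotImplementedError

class BaseAdaptTrainerVideo(BaseAdaptTrainer):
    # Shared plumbing: subclasses only need to provide compute_loss.
    def training_step(self, batch, batch_nb):
        task_loss, adv_loss, _ = self.compute_loss(batch, split_name="train")
        return task_loss + adv_loss

class DANNTrainer(BaseAdaptTrainer):
    # Method-specific loss, returning (task_loss, adv_loss, log_metrics).
    def compute_loss(self, batch, split_name="train"):
        return 1.0, 0.5, {}

# Listing the video mixin first means the MRO resolves training_step from
# BaseAdaptTrainerVideo, while compute_loss still comes from DANNTrainer.
class DANNTrainerVideo(BaseAdaptTrainerVideo, DANNTrainer):
    pass

assert DANNTrainerVideo().training_step(batch=None, batch_nb=0) == 1.5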
@@ -12,9 +12,11 @@
import kale.predict.losses as losses
from kale.loaddata.video_access import get_image_modality
from kale.pipeline.domain_adapter import (
    BaseAdaptTrainer,
    BaseMMDLike,
    CDANTrainer,
    DANNTrainer,
    get_aggregated_metrics,
    get_aggregated_metrics_from_dict,
    get_metrics_from_parameter_dict,
    GradReverse,
@@ -100,7 +102,108 @@ def create_dann_like_video(
        raise ValueError(f"Unsupported method: {method}")


class BaseMMDLikeVideo(BaseMMDLike):
class BaseAdaptTrainerVideo(BaseAdaptTrainer):
"""Base class for all domain adaptation architectures on videos. Inherited from BaseAdaptTrainer."""

    def train_dataloader(self):
        dataloader = self._dataset.get_domain_loaders(split="train", batch_size=self._batch_size)
        self._nb_training_batches = len(dataloader)
        return dataloader

    def val_dataloader(self):
        dataloader = self._dataset.get_domain_loaders(split="valid", batch_size=self._batch_size)
        return dataloader

    def test_dataloader(self):
        dataloader = self._dataset.get_domain_loaders(split="test", batch_size=self._batch_size)
        # dataloader, target_batch_size = self._dataset.get_domain_loaders(split="test", batch_size=500)
        return dataloader

    def training_step(self, batch, batch_nb):
        # print("tr src{} tgt{}".format(len(batch[0][2]), len(batch[1][2])))

        self._update_batch_epoch_factors(batch_nb)

        task_loss, adv_loss, log_metrics = self.compute_loss(batch, split_name="train")
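        # Warm-up: for the first self._init_epochs epochs only the task loss is
        # optimized; after that the adversarial loss is added, scaled by lamb_da.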
        if self.current_epoch < self._init_epochs:
            loss = task_loss
        else:
            loss = task_loss + self.lamb_da * adv_loss

        log_metrics = get_aggregated_metrics_from_dict(log_metrics)
        log_metrics.update(get_metrics_from_parameter_dict(self.get_parameters_watch_list(), loss.device))
        log_metrics["train_total_loss"] = loss
        log_metrics["train_adv_loss"] = adv_loss
        log_metrics["train_task_loss"] = task_loss

        for key in log_metrics:
            self.log(key, log_metrics[key])

        return {"loss": loss}

    def validation_epoch_end(self, outputs):
        metrics_to_log = self.create_metrics_log("val")
        return self._validation_epoch_end(outputs, metrics_to_log)

    def test_epoch_end(self, outputs):
        metrics_at_test = self.create_metrics_log("test")

        # Uncomment to save output to json file for EPIC UDA 2021 challenge. (3/3)
        # save_results_to_json(
        #     self.y_hat, self.y_t_hat, self.s_id, self.tu_id, self.y_hat_noun, self.y_t_hat_noun, self.verb, self.noun
        # )

        log_dict = get_aggregated_metrics(metrics_at_test, outputs)

        for key in log_dict:
            self.log(key, log_dict[key], prog_bar=True)

    def create_metrics_log(self, split_name):
        metrics_to_log = (
            "{}_loss".format(split_name),
            "{}_task_loss".format(split_name),
            "{}_adv_loss".format(split_name),
            "{}_source_acc".format(split_name),
            "{}_source_top1_acc".format(split_name),
            "{}_source_top3_acc".format(split_name),
            "{}_target_acc".format(split_name),
            "{}_target_top1_acc".format(split_name),
            "{}_target_top3_acc".format(split_name),
            "{}_source_domain_acc".format(split_name),
            "{}_target_domain_acc".format(split_name),
            "{}_domain_acc".format(split_name),
        )
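        # MMD-based methods log an MMD statistic instead of the three domain
        # accuracies; at test time the task/adv loss breakdown and the plain
        # source accuracy are dropped.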
        if self.method.is_mmd_method():
            metrics_to_log = metrics_to_log[:-3] + ("{}_mmd".format(split_name),)
        if split_name == "test":
            metrics_to_log = metrics_to_log[:1] + metrics_to_log[4:]
        return metrics_to_log

    def get_loss_log_metrics(self, split_name, y_hat, y_t_hat, y_s, y_tu, dok):
        """Get the loss, top-k accuracy and metrics for a given split."""

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat, y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat, y_tu)
        prec1_src, prec3_src = losses.topk_accuracy(y_hat, y_s, topk=(1, 3))
        prec1_tgt, prec3_tgt = losses.topk_accuracy(y_t_hat, y_tu, topk=(1, 3))
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_source_top1_acc": prec1_src,
            f"{split_name}_source_top3_acc": prec3_src,
            f"{split_name}_target_top1_acc": prec1_tgt,
            f"{split_name}_target_top3_acc": prec3_tgt,
        }
        if self.method.is_mmd_method():
            log_metrics.update({f"{split_name}_mmd": dok})
        else:
            log_metrics.update({f"{split_name}_domain_acc": dok})
        return task_loss, log_metrics


class BaseMMDLikeVideo(BaseAdaptTrainerVideo, BaseMMDLike):
"""Common API for MME-based domain adaptation on video data: DAN, JAN"""

    def __init__(
@@ -152,15 +255,8 @@ def compute_loss(self, batch, split_name="valid"):
        # print('rgb_s:{}, flow_s:{}, rgb_f:{}, flow_f:{}'.format(y_s, y_s_flow, y_tu, y_tu_flow))
        # print('equal: {}/{}'.format(torch.all(torch.eq(y_s, y_s_flow)), torch.all(torch.eq(y_tu, y_tu_flow))))

        # ok is abbreviation for (all) correct
        loss_cls, ok_src = losses.cross_entropy_logits(y_hat[0], y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat[0], y_tu)
        task_loss = loss_cls
        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": mmd,
        }
        task_loss, log_metrics = self.get_loss_log_metrics(split_name, y_hat, y_t_hat, y_s, y_tu, mmd)

        return task_loss, mmd, log_metrics


@@ -218,7 +314,7 @@ def _compute_mmd(self, phi_s, phi_t, y_hat, y_t_hat):
        return losses.compute_mmd_loss(joint_kernels, batch_size)


class DANNTrainerVideo(DANNTrainer):
class DANNTrainerVideo(BaseAdaptTrainerVideo, DANNTrainer):
"""This is an implementation of DANN for video data."""

def __init__(
@@ -300,43 +396,27 @@ def compute_loss(self, batch, split_name="valid"):
        loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(d_t_hat, torch.ones(batch_size))
        dok = torch.cat((dok_src, dok_tgt))

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat[0], y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat[0], y_tu)
        # loss_cls, ok_src = losses.cross_entropy_logits(y_hat[0], y_s)
        # _, ok_tgt = losses.cross_entropy_logits(y_t_hat[0], y_tu)
        # adv_loss = loss_dmn_src + loss_dmn_tgt  # adv_loss = src + tgt
        # task_loss = loss_cls
        #
        # log_metrics = {
        #     f"{split_name}_source_acc": ok_src,
        #     f"{split_name}_target_acc": ok_tgt,
        #     f"{split_name}_domain_acc": dok,
        #     f"{split_name}_source_domain_acc": dok_src,
        #     f"{split_name}_target_domain_acc": dok_tgt,
        # }

        task_loss, log_metrics = self.get_loss_log_metrics(split_name, y_hat, y_t_hat, y_s, y_tu, dok)
        adv_loss = loss_dmn_src + loss_dmn_tgt  # adv_loss = src + tgt
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": dok,
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
        }
        log_metrics.update({f"{split_name}_source_domain_acc": dok_src, f"{split_name}_target_domain_acc": dok_tgt})

        return task_loss, adv_loss, log_metrics

    def training_step(self, batch, batch_nb):
        self._update_batch_epoch_factors(batch_nb)

        task_loss, adv_loss, log_metrics = self.compute_loss(batch, split_name="train")
        if self.current_epoch < self._init_epochs:
            loss = task_loss
        else:
            loss = task_loss + self.lamb_da * adv_loss

        log_metrics = get_aggregated_metrics_from_dict(log_metrics)
        log_metrics.update(get_metrics_from_parameter_dict(self.get_parameters_watch_list(), loss.device))
        log_metrics["train_total_loss"] = loss
        log_metrics["train_adv_loss"] = adv_loss
        log_metrics["train_task_loss"] = task_loss

        for key in log_metrics:
            self.log(key, log_metrics[key])

        return {"loss": loss}


class CDANTrainerVideo(CDANTrainer):
class CDANTrainerVideo(BaseAdaptTrainerVideo, CDANTrainer):
"""This is an implementation of CDAN for video data."""

def __init__(
@@ -463,23 +543,26 @@ def compute_loss(self, batch, split_name="valid"):
        loss_dmn_tgt, dok_tgt = losses.cross_entropy_logits(d_t_hat, torch.ones(batch_size))
        dok = torch.cat((dok_src, dok_tgt))

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat[0], y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat[0], y_tu)
        # loss_cls, ok_src = losses.cross_entropy_logits(y_hat[0], y_s)
        # _, ok_tgt = losses.cross_entropy_logits(y_t_hat[0], y_tu)
        # adv_loss = loss_dmn_src + loss_dmn_tgt  # adv_loss = src + tgt
        # task_loss = loss_cls
        #
        # log_metrics = {
        #     f"{split_name}_source_acc": ok_src,
        #     f"{split_name}_target_acc": ok_tgt,
        #     f"{split_name}_domain_acc": dok,
        #     f"{split_name}_source_domain_acc": dok_src,
        #     f"{split_name}_target_domain_acc": dok_tgt,
        # }

        task_loss, log_metrics = self.get_loss_log_metrics(split_name, y_hat, y_t_hat, y_s, y_tu, dok)
        adv_loss = loss_dmn_src + loss_dmn_tgt  # adv_loss = src + tgt
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": dok,
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
        }

        log_metrics.update({f"{split_name}_source_domain_acc": dok_src, f"{split_name}_target_domain_acc": dok_tgt})
        return task_loss, adv_loss, log_metrics


class WDGRLTrainerVideo(WDGRLTrainer):
class WDGRLTrainerVideo(BaseAdaptTrainerVideo, WDGRLTrainer):
"""This is an implementation of WDGRL for video data."""

def __init__(
@@ -572,19 +655,23 @@ def compute_loss(self, batch, split_name="valid"):
        wasserstein_distance = d_hat.mean() - (1 + self._beta_ratio) * d_t_hat.mean()
        dok = torch.cat((dok_src, dok_tgt))

        loss_cls, ok_src = losses.cross_entropy_logits(y_hat[0], y_s)
        _, ok_tgt = losses.cross_entropy_logits(y_t_hat[0], y_tu)
        # loss_cls, ok_src = losses.cross_entropy_logits(y_hat[0], y_s)
        # _, ok_tgt = losses.cross_entropy_logits(y_t_hat[0], y_tu)
        # adv_loss = wasserstein_distance
        # task_loss = loss_cls
        #
        # log_metrics = {
        #     f"{split_name}_source_acc": ok_src,
        #     f"{split_name}_target_acc": ok_tgt,
        #     f"{split_name}_domain_acc": dok,
        #     f"{split_name}_source_domain_acc": dok_src,
        #     f"{split_name}_target_domain_acc": dok_tgt,
        #     f"{split_name}_wasserstein_dist": wasserstein_distance,
        # }

        task_loss, log_metrics = self.get_loss_log_metrics(split_name, y_hat, y_t_hat, y_s, y_tu, dok)
        adv_loss = wasserstein_distance
        task_loss = loss_cls

        log_metrics = {
            f"{split_name}_source_acc": ok_src,
            f"{split_name}_target_acc": ok_tgt,
            f"{split_name}_domain_acc": dok,
            f"{split_name}_source_domain_acc": dok_src,
            f"{split_name}_target_domain_acc": dok_tgt,
            f"{split_name}_wasserstein_dist": wasserstein_distance,
        }
        log_metrics.update({f"{split_name}_source_domain_acc": dok_src, f"{split_name}_target_domain_acc": dok_tgt})
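        # Note: the shared helper path no longer logs f"{split_name}_wasserstein_dist",
        # which the commented-out block above used to include.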
        return task_loss, adv_loss, log_metrics

    def configure_optimizers(self):