PR1: Add mkmmd loss #168

Open. Wants to merge 185 commits into base: main.

Commits (185)
fb43681
Add mkmmd loss
sanaAyrml Jan 23, 2024
9c9d70a
code cleaning
sanaAyrml Jan 23, 2024
81f37e3
code cleaning
sanaAyrml Jan 23, 2024
c3278a4
Update poetry
sanaAyrml Jan 23, 2024
538578c
Merge branch 'main' into mkmdd_loss
sanaAyrml Jan 23, 2024
cb5013b
Cleaning code
sanaAyrml Jan 23, 2024
374d0aa
Add fedisic experiments
sanaAyrml Jan 25, 2024
a578838
update time
sanaAyrml Jan 25, 2024
3b58481
update
sanaAyrml Jan 25, 2024
f42d777
edit
sanaAyrml Jan 25, 2024
81257ae
Fix mkmmd
sanaAyrml Jan 26, 2024
d903f50
update mkmmd fedisic
sanaAyrml Jan 26, 2024
96f1914
update mkmmd
sanaAyrml Jan 26, 2024
39c1e8f
Proposed changes to the mkmmd loss function
emersodb Jan 29, 2024
44d091c
Small fix to the documentation of the call for get_best_vertex_for_ob…
emersodb Jan 30, 2024
083174d
Update beta optimization
sanaAyrml Feb 2, 2024
66c1735
Add mkmmd to moon and fenda
sanaAyrml Feb 5, 2024
7b11f84
Update contrastive loss
sanaAyrml Feb 5, 2024
95ca579
Add beta_with_extreme_kernel_base_values computation
sanaAyrml Feb 5, 2024
d18a926
Merge branch 'main' into mkmdd_loss
sanaAyrml Feb 5, 2024
50bcf59
resolve conflicts of main merge
sanaAyrml Feb 5, 2024
6149021
fix tests
sanaAyrml Feb 5, 2024
598b98a
Fix main merge conflicts
sanaAyrml Feb 5, 2024
0b3368d
fix tests
sanaAyrml Feb 5, 2024
43bd85c
edit tests
sanaAyrml Feb 5, 2024
6efaf35
change tests
sanaAyrml Feb 5, 2024
2e285ff
fix tests
sanaAyrml Feb 5, 2024
be52d0a
fix tests
sanaAyrml Feb 5, 2024
299e939
Add moon mkmmd example
sanaAyrml Feb 5, 2024
15c25e4
Merge branch 'main' into mkmdd_loss
sanaAyrml Feb 8, 2024
616d727
debug moon client
sanaAyrml Feb 8, 2024
fd9529d
update moon
sanaAyrml Feb 8, 2024
77b5f54
update moon
sanaAyrml Feb 8, 2024
96ab534
Merge branch 'main' into mkmdd_loss
sanaAyrml Feb 9, 2024
637a825
Merge branch 'main' into mkmdd_loss
sanaAyrml Feb 9, 2024
2853cae
Merge branch 'sa-separate' into mkmdd_loss
sanaAyrml Feb 12, 2024
1ff8efa
Merge branch 'sa-separate' into mkmdd_loss
sanaAyrml Feb 12, 2024
5553fe2
Merge branch 'mkmdd_loss' of https://github.com/VectorInstitute/FL4He…
sanaAyrml Feb 16, 2024
19895c5
Merge branch 'main' into mkmdd_loss
emersodb Feb 16, 2024
67b6c88
Merge branch 'main' of https://github.com/VectorInstitute/FL4Health
sanaAyrml Feb 20, 2024
62e1536
Merge branch 'main' into mkmdd_loss
sanaAyrml Feb 20, 2024
db8396e
adding qpth to poetry
sanaAyrml Feb 20, 2024
ccf1522
Apply David's comments
sanaAyrml Feb 20, 2024
0f45556
update fenda test
sanaAyrml Feb 20, 2024
112e47e
add interval beta updates
sanaAyrml Feb 23, 2024
adca64c
Update mkmmd implementation for moon and fenda
sanaAyrml Feb 24, 2024
f612392
print optimized betas
sanaAyrml Feb 24, 2024
5289678
debug mkmmd
sanaAyrml Feb 28, 2024
cf07955
update mkmmd
sanaAyrml Feb 28, 2024
82d861a
more debug
sanaAyrml Feb 28, 2024
e5a7c97
Check again
sanaAyrml Mar 1, 2024
f0d2445
checl qp function
sanaAyrml Mar 1, 2024
9d83f43
Merge branch 'main' into mkmdd_loss
sanaAyrml Mar 12, 2024
6174732
fixing poetry lock
sanaAyrml Mar 12, 2024
f22b555
update lock file
sanaAyrml Mar 12, 2024
0590c0e
add moon mkmmd to fed ixi
sanaAyrml Mar 12, 2024
dcb1f21
Fail to add qpth in the poetry
sanaAyrml Mar 12, 2024
e2d4ce6
fixing poetry
sanaAyrml Mar 12, 2024
6011c5d
add distributed training for fl
sanaAyrml Mar 14, 2024
b82abea
fix client and server files
sanaAyrml Mar 14, 2024
0b42d78
edit server
sanaAyrml Mar 14, 2024
d373138
check fl cluster
sanaAyrml Mar 14, 2024
068e009
add better logging path
sanaAyrml Mar 14, 2024
0b6c65f
update client and server
sanaAyrml Mar 14, 2024
89c2200
maybe that is the bug
sanaAyrml Mar 14, 2024
2066889
update server
sanaAyrml Mar 14, 2024
0b67c58
add run making dir
sanaAyrml Mar 14, 2024
0bc3d5d
update experimemt dir making
sanaAyrml Mar 14, 2024
d0e0541
minor commit
sanaAyrml Mar 15, 2024
56303ef
Merge branch 'main' into mkmdd_loss
sanaAyrml Mar 24, 2024
8b5f479
Delete extra files
sanaAyrml Mar 24, 2024
bff9b1e
fix pytests and add qpth to poetry
sanaAyrml Mar 24, 2024
b579a6c
FIx aggregate loss in utils
sanaAyrml Mar 24, 2024
afea507
Fix precommit errors
sanaAyrml Mar 24, 2024
b55c3cd
Fix tests and mkmmd
sanaAyrml Mar 24, 2024
c81dbfa
add mkmmd for ditto
sanaAyrml Mar 24, 2024
71d80b8
add ditto mkmmd experimetns
sanaAyrml Mar 24, 2024
2a53c58
solve infeasibility issue
sanaAyrml Mar 25, 2024
6802811
change range
sanaAyrml Mar 25, 2024
b7e4088
fix typo and pytests
sanaAyrml Mar 25, 2024
e492c16
fix fedisic experiment
sanaAyrml Mar 25, 2024
aba0bee
Update moon model assertion for ditto
sanaAyrml Mar 25, 2024
a994b8b
update ditto mkmmd
sanaAyrml Mar 25, 2024
4dd7fc2
update errore
sanaAyrml Mar 25, 2024
408e6d2
Update beta optimization
sanaAyrml Mar 25, 2024
515d79f
fix model saving
sanaAyrml Mar 25, 2024
a63d079
add l2 feature norms
sanaAyrml Mar 26, 2024
638e69e
Update pytests for clients
sanaAyrml Mar 26, 2024
d204c2c
Update syntax errors
sanaAyrml Mar 26, 2024
87b1fef
Update ditto mkmmd client
sanaAyrml Mar 26, 2024
db443e6
fix ditto mkmmd
sanaAyrml Mar 26, 2024
3c2acb4
Update script run
sanaAyrml Mar 26, 2024
75f5bbe
delete extra log
sanaAyrml Mar 26, 2024
ae387b5
update ditto mkmmd loss
sanaAyrml Mar 27, 2024
df1a780
fix mkmmd loss computing
sanaAyrml Mar 27, 2024
3793143
Merge branch 'main' into mkmdd_loss
sanaAyrml Mar 27, 2024
459aa2c
update norm computatio
sanaAyrml Mar 27, 2024
3e8dacc
fix l2 norm
sanaAyrml Mar 27, 2024
610c472
Update fenda mkmmd client
sanaAyrml Mar 27, 2024
75373df
Update research files
sanaAyrml Mar 27, 2024
fa036e2
Update doc strings
sanaAyrml Apr 9, 2024
f825301
fix pre commit issues
sanaAyrml Apr 9, 2024
90dcebf
fix static code checks failing
sanaAyrml Apr 10, 2024
1c6aec3
Merge branch 'main' into mkmdd_loss
sanaAyrml Apr 16, 2024
bb6a322
Updating lock file
sanaAyrml Apr 17, 2024
70336d5
Detach ditto mkmmd from ditto
sanaAyrml Apr 17, 2024
055d903
Fix ditto mkmmd client comments
sanaAyrml Apr 17, 2024
eeb3d55
separate mkmmd additions to moon
sanaAyrml Apr 18, 2024
d61a77f
Separate fenda mkmmd with fenda
sanaAyrml Apr 18, 2024
25b3974
Address some more comments
sanaAyrml Apr 18, 2024
ac594a2
Fix fenda and moon mkmmd doc strings
sanaAyrml Apr 18, 2024
ad9b31a
Merge branch 'main' into mkmdd_loss
sanaAyrml Apr 18, 2024
eae0065
Update mkmmd tests
sanaAyrml Apr 18, 2024
aeb67dd
Update research files
sanaAyrml Apr 18, 2024
2b41f70
update research clients
sanaAyrml Apr 18, 2024
68632c0
fix moon_mkmmd
sanaAyrml Apr 18, 2024
dea3b57
Add mr_mtl_mkkmmd_client.py
sanaAyrml Apr 19, 2024
fef1303
Add mr_mtl_research files
sanaAyrml Apr 19, 2024
7e37b72
Add normalization to mkmmd
sanaAyrml Apr 19, 2024
606ec7f
add local beta optimization option
sanaAyrml Apr 19, 2024
99ef911
Add feature extractor base model
sanaAyrml Apr 23, 2024
b12fe77
fix static code failure
sanaAyrml Apr 23, 2024
05d3610
fix bug in feature extractor
sanaAyrml Apr 23, 2024
d012dad
update mrmtl model implementation
sanaAyrml Apr 23, 2024
db2ac05
change feature extractor name
sanaAyrml Apr 23, 2024
3d91119
Add feature extractor buffer
sanaAyrml Apr 24, 2024
da1b61b
shift mkmmd loss computation to one front
sanaAyrml Apr 24, 2024
e6c17af
fix fed isic ditto mkmmd experiment
sanaAyrml Apr 24, 2024
0f9a22a
fix mrmtl example
sanaAyrml Apr 25, 2024
cc72a88
add fedheartdisease tests
sanaAyrml Apr 25, 2024
3e8c354
fix mrmtl bug
sanaAyrml Apr 25, 2024
8444667
fix mr_mtl bug
sanaAyrml Apr 25, 2024
18767c9
fix moon runs
sanaAyrml Apr 25, 2024
cecaf60
delete extra fed heart files
sanaAyrml Apr 25, 2024
3234c06
delete test
sanaAyrml Apr 25, 2024
fb91d4b
fix experiments
sanaAyrml Apr 25, 2024
ebbec10
add some extra logging
sanaAyrml Apr 25, 2024
bef1b4a
fix new slrm files
sanaAyrml Apr 26, 2024
5c9efff
fix bash file
sanaAyrml Apr 26, 2024
ce46dfb
Fix feature extractor
sanaAyrml Apr 26, 2024
2395e12
Fix pytest failing
sanaAyrml Apr 26, 2024
3850739
check smoke test failure
sanaAyrml Apr 27, 2024
f56390d
check smoke test failure
sanaAyrml Apr 27, 2024
2f7dbc3
check smoke test failure
sanaAyrml Apr 27, 2024
f5cfdd9
check smoke test failure
sanaAyrml Apr 27, 2024
f670b02
Fix smoke tests
sanaAyrml Apr 30, 2024
dffa1ad
add extra logging
sanaAyrml May 1, 2024
c284473
Fix logging
sanaAyrml May 1, 2024
9ca1704
Fix beta update
sanaAyrml May 3, 2024
2cede39
Merge branch 'main' into mkmdd_loss
sanaAyrml May 9, 2024
afd69df
Merge main
sanaAyrml May 9, 2024
45ca9ba
add deep_mmd_client
sanaAyrml May 27, 2024
701d710
Merge branch 'main' into mkmdd_loss
sanaAyrml May 27, 2024
c5ab04d
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
5894dd8
fix precommit changes
sanaAyrml May 27, 2024
2952364
Merge branch 'mkmdd_loss' of https://github.com/VectorInstitute/FL4He…
sanaAyrml May 27, 2024
c7a3da7
change deepmmd name
sanaAyrml May 27, 2024
333d9e9
fix deep mmd loss implementation
sanaAyrml May 27, 2024
e1f3de1
fix deep mmd
sanaAyrml May 27, 2024
80f3c5e
fix deep mmd loss
sanaAyrml May 27, 2024
27447fb
Fix deep mmd codes
sanaAyrml May 29, 2024
cc4579b
fix deep mmd loss
sanaAyrml May 29, 2024
b1bd794
add file for fedisic
sanaAyrml May 29, 2024
5654a8e
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2024
1125de2
fix deep mmd clients
sanaAyrml May 30, 2024
a0bd92b
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2024
9b5dfc3
Merge branch 'mkmdd_loss' into add_mkmmd_loss
sanaAyrml Jun 6, 2024
6cbcc25
Delete extra files
sanaAyrml Jun 6, 2024
255ab52
Merge branch 'main' into add_mkmmd_loss
sanaAyrml Jun 6, 2024
970ccea
delete extra changes
sanaAyrml Jun 6, 2024
f69bba1
locking poetry
sanaAyrml Jun 7, 2024
f8854f2
fixing tests
sanaAyrml Jun 7, 2024
cc05d6f
fix privacy estimation
sanaAyrml Jun 7, 2024
f14dad2
try to pass test
sanaAyrml Jun 7, 2024
ab064f2
Fix pytest failing
sanaAyrml Jun 7, 2024
6fe5280
Seeing if I can pip list within this test.
emersodb Jun 12, 2024
9d36c0d
Merge branch 'main' into add_mkmmd_loss
emersodb Jun 12, 2024
a49aa70
Reverting changes because github is being a pain.
emersodb Jun 12, 2024
d5381c3
Trying to dump pip in a test again.
emersodb Jun 12, 2024
e3a997d
Putting code back to the way it was.
emersodb Jun 12, 2024
02bf3ef
Apply david's comments changes
sanaAyrml Jun 20, 2024
2c84089
Merge branch 'main' into add_mkmmd_loss
sanaAyrml Jun 27, 2024
31aef6b
Add tests for feature extractor buffer
sanaAyrml Jun 28, 2024
bf34662
Merge branch 'main' into add_mkmmd_loss
sanaAyrml Jun 28, 2024
14df7da
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2024
3 changes: 2 additions & 1 deletion fl4health/clients/basic_client.py
@@ -1018,7 +1018,8 @@ def update_after_step(self, step: int) -> None:
"""
Hook method called after local train step on client. step is an integer that represents
the local training step that was most recently completed. For example, used by the APFL
method to update the alpha value after a training a step.
method to update the alpha value after a training step. Also used by MOON, FENDA,
and Ditto to update the optimized beta values for the MK-MMD loss after n steps.

Args:
step (int): The step number in local training that was most recently completed.
2 changes: 0 additions & 2 deletions fl4health/clients/ditto_client.py
@@ -220,8 +220,6 @@ def train_step(
preds, features = self.predict(input)

# Compute all relevant losses
# NOTE: features here should be a blank dictionary, as we're not using them
assert len(features) == 0
losses = self.compute_training_loss(preds, features, target)

# Take a step with the global model vanilla loss
274 changes: 274 additions & 0 deletions fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py
@@ -0,0 +1,274 @@
from logging import ERROR, INFO
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from flwr.common.logger import log
from flwr.common.typing import Config, Scalar

from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule
from fl4health.clients.basic_client import TorchInputType
from fl4health.clients.ditto_client import DittoClient
from fl4health.losses.mkmmd_loss import MkMmdLoss
from fl4health.model_bases.feature_extractor_buffer import FeatureExtractorBuffer
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import Metric


class DittoMkmmdClient(DittoClient):
Collaborator:

Very minor: the acronym capitalization we used for the loss function was MkMmd (vs. Mkmmd), since it's MK-MMD as an acronym and we treated the '-' as a space. We can do either, I don't really care, but maybe we can be consistent? We can also do the same for MR-MTL as either MrMtl or Mrmtl (I like the first a little better looking at it now).

Collaborator Author:

I updated these, but MrMtlMkMmd looks so spiky :)

Collaborator:

Yeah, it doesn't look great 😂. Still, it follows our convention I guess.

def __init__(
self,
data_path: Path,
metrics: Sequence[Metric],
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpointer: Optional[ClientCheckpointModule] = None,
lam: float = 1.0,
mkmmd_loss_weight: float = 10.0,
flatten_feature_extraction_layers: Dict[str, bool] = {},
Collaborator:

Setting a default empty dictionary here is dangerous (because python is dumb), since dictionaries are mutable. Consider the code below and the resulting output

from typing import Dict

class A:
    def __init__(self, d: Dict[str, str] = {}) -> None:
        self.d = d

a1 = A()
a2 = A()

print(a1.d)
print(a2.d)
print("Looks okay so far")

a1.d["Hello"] = "Oh No!"

print(a1.d)
print(a2.d)
print("NOOOOO!")

Output:

{}
{}
Looks okay so far
{'Hello': 'Oh No!'}
{'Hello': 'Oh No!'}
NOOOOO!

Python evaluates default argument values only once, when the function is defined, so the same dictionary object is shared by every call and therefore by every instance of the class. Changes to a1 thus unexpectedly affect a2.

Collaborator:

The standard way to implement a default value for these is to use Optional[Dict[...]] = None and then, in the __init__, do

if flatten_feature_extraction_layers:
    self.flatten_feature_extraction_layers = flatten_feature_extraction_layers
else:
    self.flatten_feature_extraction_layers = {}

Collaborator Author:

Oh nooo that is really scary 😱
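
A minimal sketch of the None-default pattern described above, applied to a hypothetical constructor (the names here are illustrative, not the final signature in this PR):

from typing import Dict, Optional

class FeatureConfig:
    def __init__(self, flatten_feature_extraction_layers: Optional[Dict[str, bool]] = None) -> None:
        # None is immutable, so each instance builds its own fresh dictionary here,
        # avoiding the shared mutable default shown above.
        self.flatten_feature_extraction_layers = (
            flatten_feature_extraction_layers if flatten_feature_extraction_layers is not None else {}
        )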

feature_l2_norm_weight: float = 0.0,
beta_global_update_interval: int = 20,
) -> None:
"""
This client implements the MK-MMD loss function in the Ditto framework. The MK-MMD loss is a measure of the
distance between the distributions of the features of the local model and init global of each round. The MK-MMD
Collaborator:

...initial global model of each round...

loss is added to the local loss to penalize the local model for drifting away from the global model.
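
For reference, a generic sketch of a multi-kernel MMD estimate (a beta-weighted sum of per-kernel MMD^2 terms over a bank of Gaussian kernels). This is illustrative only and is not the MkMmdLoss implementation in this repository, which additionally normalizes features and optimizes the betas:

from typing import List

import torch

def gaussian_kernel(x: torch.Tensor, y: torch.Tensor, sigma: float) -> torch.Tensor:
    # Pairwise Gaussian kernel values between the rows of x and y.
    squared_dists = torch.cdist(x, y) ** 2
    return torch.exp(-squared_dists / (2.0 * sigma**2))

def mk_mmd_squared(x: torch.Tensor, y: torch.Tensor, sigmas: List[float], betas: List[float]) -> torch.Tensor:
    # Biased estimate of MMD^2 under the multi-kernel k = sum_u beta_u * k_u:
    # MMD^2 = E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)], estimated with sample means.
    mmd2 = torch.zeros((), device=x.device)
    for sigma, beta in zip(sigmas, betas):
        k_xx = gaussian_kernel(x, x, sigma).mean()
        k_yy = gaussian_kernel(y, y, sigma).mean()
        k_xy = gaussian_kernel(x, y, sigma).mean()
        mmd2 = mmd2 + beta * (k_xx + k_yy - 2.0 * k_xy)
    return mmd2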

Args:
data_path (Path): path to the data to be used to load the data for client-side training
metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model
device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or
'cuda'
loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over
each batch. Defaults to LossMeterType.AVERAGE.
checkpointer (Optional[TorchCheckpointer], optional): Checkpointer to be used for client-side
checkpointing. Defaults to None.
metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics
during the execution. Defaults to an instance of MetricsReporter with default init parameters.
lam (float, optional): weight applied to the Ditto drift loss. Defaults to 1.0.
mkmmd_loss_weight (float, optional): weight applied to the MK-MMD loss. Defaults to 10.0.
flatten_feature_extraction_layers (Dict[str, bool], optional): Dictionary of layers to extract features
Collaborator:

Maybe useful just to state that the keys here are the model layer names as extracted from the named_modules?

from them and whether to flatten them. Defaults to {}.
feature_l2_norm_weight (float, optional): weight applied to the L2 norm of the features.
Defaults to 0.0.
beta_global_update_interval (int, optional): interval at which to update the betas for the
MK-MMD loss. Defaults to 20. If set to -1, the betas will be updated for each individual batch.
If set to 0, the betas will not be updated.
"""
super().__init__(
data_path=data_path,
metrics=metrics,
device=device,
loss_meter_type=loss_meter_type,
checkpointer=checkpointer,
lam=lam,
)
self.mkmmd_loss_weight = mkmmd_loss_weight
if self.mkmmd_loss_weight == 0:
log(
ERROR,
"MK-MMD loss weight is set to 0. As MK-MMD loss will not be computed, ",
"please use vanilla DittoClient instead.",
)

self.feature_l2_norm_weight = feature_l2_norm_weight
self.beta_global_update_interval = beta_global_update_interval
if self.beta_global_update_interval == -1:
log(INFO, "Betas for the MK-MMD loss will be updated for each individual batch.")
elif self.beta_global_update_interval == 0:
log(INFO, "Betas for the MK-MMD loss will not be updated.")
elif self.beta_global_update_interval > 0:
log(INFO, f"Betas for the MK-MMD loss will be updated every {self.beta_global_update_interval} steps.")
else:
raise ValueError("Invalid beta_global_update_interval. It should be either -1, 0 or a positive integer.")
self.flatten_feature_extraction_layers = flatten_feature_extraction_layers
self.mkmmd_losses = {}
for layer in self.flatten_feature_extraction_layers.keys():
self.mkmmd_losses[layer] = MkMmdLoss(
device=self.device, minimize_type_two_error=True, normalize_features=True, layer_name=layer
).to(self.device)

self.init_global_model: nn.Module
self.local_feature_extractor: FeatureExtractorBuffer
self.init_global_feature_extractor: FeatureExtractorBuffer

def setup_client(self, config: Config) -> None:
super().setup_client(config)
self.local_feature_extractor = FeatureExtractorBuffer(
model=self.model,
flatten_feature_extraction_layers=self.flatten_feature_extraction_layers,
)

def update_before_train(self, current_server_round: int) -> None:
super().update_before_train(current_server_round)
assert isinstance(self.global_model, nn.Module)
# Register hooks to extract features from the local model if not already registered
self.local_feature_extractor._maybe_register_hooks()
Collaborator:

Very minor, but is there any reason we don't register the local hooks in setup_client where we create self.local_feature_extractor since that will only be run once anyway?

Collaborator:

Ah, I see, if the hooks are removed for checkpointing, we need to put them back.

Collaborator:

Maybe adding a note here would be good. That is, something like

# Hooks have to be removed to checkpoint the model, so we check if they need to be re-registered each time.
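
For context, a minimal sketch of the forward-hook pattern under discussion, using hypothetical names rather than the FeatureExtractorBuffer API: register a hook once, capture the layer's outputs into a buffer, and remove the hook via its handle before checkpointing.

from typing import List, Optional

import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle

class MinimalFeatureBuffer:
    def __init__(self, model: nn.Module, layer_name: str) -> None:
        self.model = model
        self.layer_name = layer_name
        self.outputs: List[torch.Tensor] = []
        self.handle: Optional[RemovableHandle] = None

    def _hook(self, module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
        # Called on every forward pass of the hooked layer; store a detached copy.
        self.outputs.append(output.detach())

    def maybe_register_hook(self) -> None:
        # Registration is idempotent: skip if the hook is already attached.
        if self.handle is None:
            layer = dict(self.model.named_modules())[self.layer_name]
            self.handle = layer.register_forward_hook(self._hook)

    def remove_hook(self) -> None:
        # Remove before checkpointing, using the handle returned at registration time.
        if self.handle is not None:
            self.handle.remove()
            self.handle = None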

# Clone and freeze the initial weights GLOBAL MODEL. These are used to form the Ditto local
# update penalty term.
self.init_global_model = self.clone_and_freeze_model(self.global_model)
self.init_global_feature_extractor = FeatureExtractorBuffer(
model=self.init_global_model,
flatten_feature_extraction_layers=self.flatten_feature_extraction_layers,
)
# Register hooks to extract features from the init global model if not already registered
self.init_global_feature_extractor._maybe_register_hooks()

def _should_optimize_betas(self, step: int) -> bool:
assert self.beta_global_update_interval is not None
Collaborator:

I don't think you need this assertion, as its type is just int?

Collaborator Author:

That is true.

step_at_interval = (step - 1) % self.beta_global_update_interval == 0
valid_components_present = self.init_global_model is not None
return step_at_interval and valid_components_present

def update_after_step(self, step: int) -> None:
if self.beta_global_update_interval > 0 and self._should_optimize_betas(step):
Collaborator:

If beta_global_update_interval is -1 it looks like we'll always skip the optimization of the betas, which, I think would contradict the documentation above?

Collaborator Author:

That is true. When beta_global_update_interval is -1, I update it in line 253 as:

            if self.beta_global_update_interval == -1:
                # Update betas for the MK-MMD loss based on computed features during training
                for layer in self.flatten_feature_extraction_layers.keys():
                    self.mkmmd_losses[layer].betas = self.mkmmd_losses[layer].optimize_betas(
                        X=features[layer], Y=features[" ".join(["init_global", layer])], lambda_m=1e-5
                    )

The reason to do so is that I don't want to update the buffer with the whole dataset; I only pass in the computed features of that batch.

Collaborator:

Yeah I see what you mean there. Maybe we can just add a comment to make that obvious?
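
To summarize the three cases being discussed, an illustrative helper (not code from this PR) with the semantics described in the __init__ docstring:

def should_update_betas(step: int, beta_global_update_interval: int) -> bool:
    # -1: re-optimize the betas on every batch, using only that batch's features.
    # 0: never re-optimize the betas.
    # N > 0: re-optimize every N local steps, using features gathered over the whole train loader.
    if beta_global_update_interval == -1:
        return True
    if beta_global_update_interval == 0:
        return False
    return (step - 1) % beta_global_update_interval == 0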

assert self.init_global_model is not None
# Get the feature distribution of the local and init global features with evaluation mode
local_distributions, init_global_distributions = self.update_buffers(self.model, self.init_global_model)
# Update betas for the MK-MMD loss based on gathered features during training
if self.mkmmd_loss_weight != 0:
for layer in self.flatten_feature_extraction_layers.keys():
Collaborator:

Rather than using the keys of self.flatten_feature_extraction_layers here, it seems cleaner to do

for layer, layer_mkmmd_loss in self.mkmmd_losses.items():
    layer_mkmmd_loss.betas = ...

self.mkmmd_losses[layer].betas = self.mkmmd_losses[layer].optimize_betas(
X=local_distributions[layer], Y=init_global_distributions[layer], lambda_m=1e-5
)

return super().update_after_step(step)

def update_buffers(
self, local_model: torch.nn.Module, init_global_model: torch.nn.Module
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
"""Update the feature buffer of the local and global features."""

self.local_feature_extractor.clear_buffers()
self.init_global_feature_extractor.clear_buffers()

self.local_feature_extractor.enable_accumulating_features()
Collaborator:

I may be misunderstanding the accumulating-feature functionality here: it looks like we enable accumulation, add the features, and then turn it off immediately after, and there are no other places where we flip these switches. Is there any reason we wouldn't just accumulate by default?

Collaborator Author:

I disable it because of the memory issues we might hit if we always accumulate. Since the model hooks are applied and called whenever the model receives an input, the features would keep accumulating if we didn't disable it.

Collaborator:

Ah I think I see. If you don't disable accumulation, there will be accumulation during training as well. I was thinking that you clear the buffers immediately after filling them, so no need to worry about it, but that's not the case.
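
One plausible reading of the toggle being described, sketched with hypothetical names (this is not the FeatureExtractorBuffer implementation): the hook always fires on a forward pass, but outputs are only retained across calls while accumulation is explicitly enabled, so regular training steps do not grow memory.

import torch

class AccumulatingBuffer:
    def __init__(self) -> None:
        self.accumulate = False
        self.features: list = []

    def hook(self, module, inputs, output: torch.Tensor) -> None:
        if self.accumulate:
            # While populating the buffer (e.g. to optimize betas), keep every batch's features.
            self.features.append(output.detach())
        else:
            # During normal training, keep only the latest output so memory stays bounded.
            self.features = [output.detach()]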

self.init_global_feature_extractor.enable_accumulating_features()

# Save the initial state of the local model to restore it after the buffer is populated,
# however as init global model is already cloned and frozen, we don't need to save its state.
init_state_local_model = local_model.training

# Set local model to evaluation mode, as we don't want to create a computational graph
# for the local model when populating the local buffer with features to compute optimal
# betas for the MK-MMD loss
local_model.eval()

# Make sure the local model is in evaluation mode before populating the local buffer
assert not local_model.training

# Make sure the init global model is in evaluation mode before populating the global buffer
# as it is already cloned and frozen from the global model
assert not init_global_model.training

with torch.no_grad():
for i, (input, _) in enumerate(self.train_loader):
input = input.to(self.device)
# Pass the input through the local model to populate the local_feature_extractor buffer
_ = local_model(input)
# Pass the input through the init global model to populate the local_feature_extractor buffer
_ = init_global_model(input)
local_distributions: Dict[str, torch.Tensor] = self.local_feature_extractor.get_extracted_features()
init_global_distributions: Dict[str, torch.Tensor] = (
self.init_global_feature_extractor.get_extracted_features()
)
# Restore the initial state of the local model
if init_state_local_model:
local_model.train()

self.local_feature_extractor.disable_accumulating_features()
self.init_global_feature_extractor.disable_accumulating_features()

self.local_feature_extractor.clear_buffers()
self.init_global_feature_extractor.clear_buffers()

return local_distributions, init_global_distributions

def predict(
self,
input: TorchInputType,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
"""
Computes the predictions for both the GLOBAL and LOCAL models and packs them into the prediction dictionary.

Args:
input (Union[torch.Tensor, Dict[str, torch.Tensor]]): Inputs to be fed into both models.

Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple in which the first element
contains predictions indexed by name and the second element contains intermediate activations
indexed by name.

Raises:
ValueError: Occurs when something other than a tensor or dict of tensors is returned by the model
forward.
"""

# We use features from init_global_model to compute the MK-MMD loss not the global_model
global_preds = self.global_model(input)
local_preds = self.model(input)
features = self.local_feature_extractor.get_extracted_features()
if self.mkmmd_loss_weight != 0:
# Compute the features of the init_global_model
_ = self.init_global_model(input)
init_global_features = self.init_global_feature_extractor.get_extracted_features()
for key in init_global_features.keys():
features[" ".join(["init_global", key])] = init_global_features[key]

return {"global": global_preds, "local": local_preds}, features

def _maybe_checkpoint(self, loss: float, metrics: Dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None:
# Hooks need to be removed before checkpointing the model
self.local_feature_extractor.remove_hooks()
super()._maybe_checkpoint(loss=loss, metrics=metrics, checkpoint_mode=checkpoint_mode)

def compute_loss_and_additional_losses(
self,
preds: Dict[str, torch.Tensor],
features: Dict[str, torch.Tensor],
target: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Computes the loss and any additional losses given predictions of the model and ground truth data.

Args:
preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name.
features (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name.
target (torch.Tensor): Ground truth data to evaluate predictions against.

Returns:
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple with:
- The tensor for the total loss
- A dictionary with `local_loss`, `global_loss`, `total_loss` and, based on client attributes set
from server config, also `mkmmd_loss`, `feature_l2_norm_loss` keys and their respective calculated
values.
"""
total_loss, additional_losses = super().compute_loss_and_additional_losses(preds, features, target)

if self.mkmmd_loss_weight != 0:
if self.beta_global_update_interval == -1:
Collaborator:

I sort of see what's happening when beta_global_update_interval is -1. If it's > 0, we use the whole training set to update the betas at the correct interval. If it's -1, it's updated only using the current batch statistics. I'm okay with leaving this option, but maybe we can be a bit more detailed in the documentation to explain this a bit more?

Collaborator Author:

Exactly, sure I will try to add more documentation.

Collaborator (emersodb, Jun 20, 2024):

I think you can just add it to where you discuss what -1 is above in the __init__ docstring.

# Update betas for the MK-MMD loss based on computed features during training
for layer in self.flatten_feature_extraction_layers.keys():
self.mkmmd_losses[layer].betas = self.mkmmd_losses[layer].optimize_betas(
X=features[layer], Y=features[" ".join(["init_global", layer])], lambda_m=1e-5
)
# Compute MK-MMD loss
total_mkmmd_loss = torch.tensor(0.0, device=self.device)
for layer in self.flatten_feature_extraction_layers.keys():
layer_mkmmd_loss = self.mkmmd_losses[layer](
features[layer], features[" ".join(["init_global", layer])]
)
additional_losses["_".join(["mkmmd_loss", layer])] = layer_mkmmd_loss
total_mkmmd_loss += layer_mkmmd_loss
total_loss += self.mkmmd_loss_weight * total_mkmmd_loss
additional_losses["mkmmd_loss_total"] = total_mkmmd_loss
if self.feature_l2_norm_weight:
# Compute the average L2 norm of the features over the batch
feature_l2_norm_loss = torch.linalg.norm(features["features"]) / len(features["features"])
total_loss += self.feature_l2_norm_weight * feature_l2_norm_loss
additional_losses["feature_l2_norm_loss"] = feature_l2_norm_loss

additional_losses["total_loss"] = total_loss

return total_loss, additional_losses