PR1: Add mkmmd loss #168
base: main
Conversation
for more information, see https://pre-commit.ci
Most of my changes are very cosmetic. Only a few more substantive changes, but overall everything looks good. The hooks implementation is really clever and perhaps worth having everyone on the team take a look at, just to learn how you did it!
checkpointer: Optional[ClientCheckpointModule] = None,
lam: float = 1.0,
mkmmd_loss_weight: float = 10.0,
flatten_feature_extraction_layers: Dict[str, bool] = {},
Setting a default empty dictionary here is dangerous (because Python is dumb), since dictionaries are mutable. Consider the code below and the resulting output:
from typing import Dict

class A:
    def __init__(self, d: Dict[str, str] = {}) -> None:
        self.d = d

a1 = A()
a2 = A()
print(a1.d)
print(a2.d)
print("Looks okay so far")
a1.d["Hello"] = "Oh No!"
print(a1.d)
print(a2.d)
print("NOOOOO!")
Output:
{}
{}
Looks okay so far
{'Hello': 'Oh No!'}
{'Hello': 'Oh No!'}
NOOOOO!
The reason is that Python evaluates default argument values only once, at function definition time, so every instance of the class shares a reference to the same dictionary object. So changes to a1 unexpectedly affect a2.
The standard way to implement a default value for these is to use an Optional[Dict[...]] = None, then in the __init__ to do:
if flatten_feature_extraction_layers:
    self.flatten_feature_extraction_layers = flatten_feature_extraction_layers
else:
    self.flatten_feature_extraction_layers = {}
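Putting it together, a minimal self-contained sketch of that pattern (the class name here is illustrative only, not the actual client):

from typing import Dict, Optional

class MkMmdClient:  # hypothetical name, for illustration
    def __init__(self, flatten_feature_extraction_layers: Optional[Dict[str, bool]] = None) -> None:
        # None is immutable, so each instance safely gets its own fresh dictionary
        if flatten_feature_extraction_layers:
            self.flatten_feature_extraction_layers = flatten_feature_extraction_layers
        else:
            self.flatten_feature_extraction_layers = {}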
Oh nooo that is really scary 😱
) -> None:
    """
    This client implements the MK-MMD loss function in the Ditto framework. The MK-MMD loss is a measure of the
    distance between the distributions of the features of the local model and init global of each round. The MK-MMD
...initial global model of each round...
self.init_global_feature_extractor._maybe_register_hooks()

def _should_optimize_betas(self, step: int) -> bool:
    assert self.beta_global_update_interval is not None
I don't think you need this assertion, as its type is just int?
That is true.
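For reference, a minimal sketch of what the method could reduce to once the assertion is dropped. The modulo check is an assumption based on the beta_update_interval semantics described in this PR, not the actual implementation:

def _should_optimize_betas(self, step: int) -> bool:
    # beta_global_update_interval is typed as int, so no None check is needed
    return step % self.beta_global_update_interval == 0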
log(INFO, f"Lambda: {args.lam}") | ||
log(INFO, f"Mu: {args.mu}") | ||
log(INFO, f"Feature L2 Norm Weight: {args.l2}") | ||
log(INFO, f"MKMMD Loss Depth: {args.mkmmd_loss_depth}") |
Super minor, but probably worth reporting the beta_update_interval here as well?
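Something like the line below would match the surrounding logging pattern (assuming the argument is exposed as args.beta_update_interval, which is a guess on my part):

log(INFO, f"Beta Update Interval: {args.beta_update_interval}")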
# /h/demerson/vector_repositories/fl4health_env/ \
# 0.0001 \
# 0.01 \
# 0.1 \
It seems like perhaps mkmmd_loss_depth isn't present in this input?
It is actually impressive how you detected this error 😅
It's less impressive when I say that I just counted the arguments and realized we were 1 short 😂
# /h/demerson/vector_repositories/fl4health_env/ \
# 0.0001 \
# 0.01 \
# 0.1 \
Seems like mkmmd_loss_depth isn't present in this example command?
PR Type
[Feature]
Short Description
This is a simplified, updated version of the MK-MMD loss implementation, with additions to only the Ditto and MR-MTL models.
One MK-MMD loss is implemented to minimize the distance between the features of the local model and those of the aggregated global model. In this setting, betas are updated even during training, after a certain number of steps, based on beta_update_interval. Also, to prevent feature values from blowing up, an L2-norm regularizer has been added, which can be controlled via the feature_l2_norm parameter.
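In other words, the per-step training objective is presumably composed along the lines of the sketch below. The names mkmmd_loss_weight and feature_l2_norm come from this PR; the exact composition and function signature are my assumptions:

import torch

def combined_loss(
    base_loss: torch.Tensor,
    mkmmd_loss: torch.Tensor,
    local_features: torch.Tensor,
    mkmmd_loss_weight: float,
    feature_l2_norm: float,
) -> torch.Tensor:
    # The weighted MK-MMD term pulls local features toward the global model's
    # features; the L2 penalty keeps feature magnitudes from blowing up.
    return base_loss + mkmmd_loss_weight * mkmmd_loss + feature_l2_norm * torch.linalg.norm(local_features)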
I also generalized the loss reporting function in fl4health/utils/losses, as it was failing to report correctly due to the addition of new keys during later steps in the MOON setup.
A Flamby MK-MMD loss experimental setup has also been added to the research/flamby/fedisic2019 folder.
Tests Added
Describe the tests that have been added to ensure the code's correctness, if applicable.