
PR1: Add mkmmd loss #168

Open · wants to merge 181 commits into main
Conversation

@sanaAyrml (Collaborator) commented Jun 7, 2024

PR Type

[Feature]

Short Description

This is a simplified, updated version of the MK-MMD loss implementation, with additions to only the Ditto and MR-MTL models.

One MK-MMD loss is implemented to minimize the distance between the features of the local model and those of the aggregated global model. In this setting, the betas are updated even during training, after a certain number of steps controlled by beta_update_interval. Also, to prevent feature values from blowing up, an l2-norm regularizer has been added, which can be controlled via the feature_l2_norm parameter.
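For readers new to the loss, below is a minimal sketch of a multi-kernel MMD computation between a batch of local features and a batch of global-model features. The names, the Gaussian kernel family, and the fixed sigmas/betas are illustrative assumptions, not the PR's actual implementation (which, per the description above, also updates the betas during training):

import torch

def gaussian_kernel(x: torch.Tensor, y: torch.Tensor, sigma: float) -> torch.Tensor:
    # Pairwise squared Euclidean distances between rows of x and rows of y.
    squared_dists = torch.cdist(x, y) ** 2
    return torch.exp(-squared_dists / (2 * sigma**2))

def mkmmd_loss(
    local_features: torch.Tensor,   # shape (batch, feature_dim)
    global_features: torch.Tensor,  # shape (batch, feature_dim)
    sigmas: list,                   # bandwidth per kernel (hypothetical fixed values)
    betas: list,                    # weight per kernel (fixed here for simplicity)
) -> torch.Tensor:
    # MK-MMD^2 = sum_k beta_k * [E k(x, x') + E k(y, y') - 2 E k(x, y)]
    loss = local_features.new_zeros(())
    for sigma, beta in zip(sigmas, betas):
        loss = loss + beta * (
            gaussian_kernel(local_features, local_features, sigma).mean()
            + gaussian_kernel(global_features, global_features, sigma).mean()
            - 2.0 * gaussian_kernel(local_features, global_features, sigma).mean()
        )
    return loss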

I also generalized the loss reporting function in fl4health/utils/losses, as it was failing to report correctly due to the addition of new loss keys during later steps in the MOON setup.
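A minimal sketch of key-tolerant loss aggregation of the kind described, assuming per-step losses arrive as dictionaries; the names here are hypothetical, not the PR's actual code:

from typing import Dict, List

def aggregate_losses(per_step_losses: List[Dict[str, float]]) -> Dict[str, float]:
    # Average each loss key over only the steps in which it appeared, so keys
    # introduced partway through training (as in the MOON setup) still report.
    totals: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    for losses in per_step_losses:
        for key, value in losses.items():
            totals[key] = totals.get(key, 0.0) + value
            counts[key] = counts.get(key, 0) + 1
    return {key: total / counts[key] for key, total in totals.items()}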

A Flamby MK-MMD loss experimental setup has also been added to the research/flamby/fed_isic2019 folder.

Tests Added

Describe the tests that have been added to ensure the code's correctness, if applicable.

@sanaAyrml changed the title from "Add mkmmd loss" to "PR1: Add mkmmd loss" on Jun 7, 2024
@sanaAyrml requested a review from @emersodb on June 7, 2024 06:38
@emersodb (Collaborator) left a comment:

Most of my changes are very cosmetic. Only a few more substantive changes, but overall everything looks good. The hooks implementation is really clever and perhaps worth having everyone on the team take a look at, just to learn how you did it!
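(For anyone who wants the general idea before reading the PR: the usual pattern is PyTorch forward hooks that cache intermediate layer outputs. A minimal illustrative sketch, not the PR's code:)

import torch.nn as nn

class FeatureExtractor:
    # Caches the outputs of named submodules via forward hooks.
    def __init__(self, model: nn.Module, layer_names: set) -> None:
        self.features: dict = {}
        for name, module in model.named_modules():
            if name in layer_names:
                module.register_forward_hook(self._make_hook(name))

    def _make_hook(self, name: str):
        # Each hook fires during model(x) and stores that layer's output.
        def hook(module: nn.Module, inputs, output) -> None:
            self.features[name] = output
        return hook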

fl4health/clients/basic_client.py (resolved)
checkpointer: Optional[ClientCheckpointModule] = None,
lam: float = 1.0,
mkmmd_loss_weight: float = 10.0,
flatten_feature_extraction_layers: Dict[str, bool] = {},
@emersodb (Collaborator):

Setting a default empty dictionary here is dangerous (because Python is dumb), since dictionaries are mutable. Consider the code below and the resulting output:

from typing import Dict

class A:
    def __init__(self, d: Dict[str, str] = {}) -> None:
        self.d = d

a1 = A()
a2 = A()

print(a1.d)
print(a2.d)
print("Looks okay so far")

a1.d["Hello"] = "Oh No!"

print(a1.d)
print(a2.d)
print("NOOOOO!")

Output:

{}
{}
Looks okay so far
{'Hello': 'Oh No!'}
{'Hello': 'Oh No!'}
NOOOOO!

Python evaluates default argument values once, when the function is defined, so all instances of the class end up sharing the same dictionary object. That's why changes to a1 unexpectedly affect a2.

@emersodb (Collaborator):

The standard way to implement a default value for these is to use Optional[Dict[...]] = None and then, in the __init__, do:

if flatten_feature_extraction_layers:
    self.flatten_feature_extraction_layers = flatten_feature_extraction_layers
else:
    self.flatten_feature_extraction_layers = {}
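With that pattern, each instance gets its own fresh dictionary. Applied to the toy class above (illustrative):

from typing import Dict, Optional

class A:
    def __init__(self, d: Optional[Dict[str, str]] = None) -> None:
        # None is an immutable sentinel, so instances no longer share state.
        self.d = d if d is not None else {}

a1 = A()
a2 = A()
a1.d["Hello"] = "Fine now"
print(a2.d)  # {} -- a2 is unaffected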

@sanaAyrml (Collaborator, Author):

Oh nooo that is really scary 😱

) -> None:
"""
This client implements the MK-MMD loss function in the Ditto framework. The MK-MMD loss is a measure of the
distance between the distributions of the features of the local model and init global of each round. The MK-MMD
@emersodb (Collaborator):

...initial global model of each round...

self.init_global_feature_extractor._maybe_register_hooks()

def _should_optimize_betas(self, step: int) -> bool:
assert self.beta_global_update_interval is not None
@emersodb (Collaborator):

I don't think you need this assertion, as its type is just int?

@sanaAyrml (Collaborator, Author):

That is true.
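For context, a plausible shape for this check once the assertion is dropped, assuming the interval semantics from the PR description (illustrative, not necessarily the PR's actual code):

def _should_optimize_betas(self, step: int) -> bool:
    # Re-optimize the kernel betas every beta_global_update_interval steps.
    return step % self.beta_global_update_interval == 0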

fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py (outdated, resolved)
log(INFO, f"Lambda: {args.lam}")
log(INFO, f"Mu: {args.mu}")
log(INFO, f"Feature L2 Norm Weight: {args.l2}")
log(INFO, f"MKMMD Loss Depth: {args.mkmmd_loss_depth}")
@emersodb (Collaborator):

Super minor, but probably worth reporting the beta_update_interval here as well?
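e.g. something like the following, assuming the argument is exposed as args.beta_update_interval:

log(INFO, f"Beta Update Interval: {args.beta_update_interval}")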

# /h/demerson/vector_repositories/fl4health_env/ \
# 0.0001 \
# 0.01 \
# 0.1 \
@emersodb (Collaborator):

It seems like perhaps mkmmd_loss_depth isn't present in this input?

@sanaAyrml (Collaborator, Author):

It is actually impressive how you detected this error 😅

@emersodb (Collaborator):

It's less impressive when I say that I just counted the arguments and realized we were 1 short 😂

# /h/demerson/vector_repositories/fl4health_env/ \
# 0.0001 \
# 0.01 \
# 0.1 \
@emersodb (Collaborator):

Seems like mkmmd_loss_depth isn't present in this example command?

research/flamby/fed_isic2019/mr_mtl_mkmmd/run_hp_sweep.sh (outdated, resolved)