PR1: Add mkmmd loss #168

Open · wants to merge 185 commits into base: main

Changes from 1 commit (185 commits total)
fb43681
Add mkmmd loss
sanaAyrml Jan 23, 2024
9c9d70a
code cleaning
sanaAyrml Jan 23, 2024
81f37e3
code cleaning
sanaAyrml Jan 23, 2024
c3278a4
Update poetry
sanaAyrml Jan 23, 2024
538578c
Merge branch 'main' into mkmdd_loss
sanaAyrml Jan 23, 2024
cb5013b
Cleaning code
sanaAyrml Jan 23, 2024
374d0aa
Add fedisic experiments
sanaAyrml Jan 25, 2024
a578838
update time
sanaAyrml Jan 25, 2024
3b58481
update
sanaAyrml Jan 25, 2024
f42d777
edit
sanaAyrml Jan 25, 2024
81257ae
Fix mkmmd
sanaAyrml Jan 26, 2024
d903f50
update mkmmd fedisic
sanaAyrml Jan 26, 2024
96f1914
update mkmmd
sanaAyrml Jan 26, 2024
39c1e8f
Proposed changes to the mkmmd loss function
emersodb Jan 29, 2024
44d091c
Small fix to the documentation of the call for get_best_vertex_for_ob…
emersodb Jan 30, 2024
083174d
Update beta optimization
sanaAyrml Feb 2, 2024
66c1735
Add mkmmd to moon and fenda
sanaAyrml Feb 5, 2024
7b11f84
Update contrastive loss
sanaAyrml Feb 5, 2024
95ca579
Add beta_with_extreme_kernel_base_values computation
sanaAyrml Feb 5, 2024
d18a926
Merge branch 'main' into mkmdd_loss
sanaAyrml Feb 5, 2024
50bcf59
resolve conflicts of main merge
sanaAyrml Feb 5, 2024
6149021
fix tests
sanaAyrml Feb 5, 2024
598b98a
Fix main merge conflicts
sanaAyrml Feb 5, 2024
0b3368d
fix tests
sanaAyrml Feb 5, 2024
43bd85c
edit tests
sanaAyrml Feb 5, 2024
6efaf35
change tests
sanaAyrml Feb 5, 2024
2e285ff
fix tests
sanaAyrml Feb 5, 2024
be52d0a
fix tests
sanaAyrml Feb 5, 2024
299e939
Add moon mkmmd example
sanaAyrml Feb 5, 2024
15c25e4
Merge branch 'main' into mkmdd_loss
sanaAyrml Feb 8, 2024
616d727
debug moon client
sanaAyrml Feb 8, 2024
fd9529d
update moon
sanaAyrml Feb 8, 2024
77b5f54
update moon
sanaAyrml Feb 8, 2024
96ab534
Merge branch 'main' into mkmdd_loss
sanaAyrml Feb 9, 2024
637a825
Merge branch 'main' into mkmdd_loss
sanaAyrml Feb 9, 2024
2853cae
Merge branch 'sa-separate' into mkmdd_loss
sanaAyrml Feb 12, 2024
1ff8efa
Merge branch 'sa-separate' into mkmdd_loss
sanaAyrml Feb 12, 2024
5553fe2
Merge branch 'mkmdd_loss' of https://github.com/VectorInstitute/FL4He…
sanaAyrml Feb 16, 2024
19895c5
Merge branch 'main' into mkmdd_loss
emersodb Feb 16, 2024
67b6c88
Merge branch 'main' of https://github.com/VectorInstitute/FL4Health
sanaAyrml Feb 20, 2024
62e1536
Merge branch 'main' into mkmdd_loss
sanaAyrml Feb 20, 2024
db8396e
adding qpth to poetry
sanaAyrml Feb 20, 2024
ccf1522
Apply David's comments
sanaAyrml Feb 20, 2024
0f45556
update fenda test
sanaAyrml Feb 20, 2024
112e47e
add interval beta updates
sanaAyrml Feb 23, 2024
adca64c
Update mkmmd implementation for moon and fenda
sanaAyrml Feb 24, 2024
f612392
print optimized betas
sanaAyrml Feb 24, 2024
5289678
debug mkmmd
sanaAyrml Feb 28, 2024
cf07955
update mkmmd
sanaAyrml Feb 28, 2024
82d861a
more debug
sanaAyrml Feb 28, 2024
e5a7c97
Check again
sanaAyrml Mar 1, 2024
f0d2445
checl qp function
sanaAyrml Mar 1, 2024
9d83f43
Merge branch 'main' into mkmdd_loss
sanaAyrml Mar 12, 2024
6174732
fixing poetry lock
sanaAyrml Mar 12, 2024
f22b555
update lock file
sanaAyrml Mar 12, 2024
0590c0e
add moon mkmmd to fed ixi
sanaAyrml Mar 12, 2024
dcb1f21
Fail to add qpth in the poetry
sanaAyrml Mar 12, 2024
e2d4ce6
fixing poetry
sanaAyrml Mar 12, 2024
6011c5d
add distributed training for fl
sanaAyrml Mar 14, 2024
b82abea
fix client and server files
sanaAyrml Mar 14, 2024
0b42d78
edit server
sanaAyrml Mar 14, 2024
d373138
check fl cluster
sanaAyrml Mar 14, 2024
068e009
add better logging path
sanaAyrml Mar 14, 2024
0b6c65f
update client and server
sanaAyrml Mar 14, 2024
89c2200
maybe that is the bug
sanaAyrml Mar 14, 2024
2066889
update server
sanaAyrml Mar 14, 2024
0b67c58
add run making dir
sanaAyrml Mar 14, 2024
0bc3d5d
update experimemt dir making
sanaAyrml Mar 14, 2024
d0e0541
minor commit
sanaAyrml Mar 15, 2024
56303ef
Merge branch 'main' into mkmdd_loss
sanaAyrml Mar 24, 2024
8b5f479
Delete extra files
sanaAyrml Mar 24, 2024
bff9b1e
fix pytests and add qpth to poetry
sanaAyrml Mar 24, 2024
b579a6c
FIx aggregate loss in utils
sanaAyrml Mar 24, 2024
afea507
Fix precommit errors
sanaAyrml Mar 24, 2024
b55c3cd
Fix tests and mkmmd
sanaAyrml Mar 24, 2024
c81dbfa
add mkmmd for ditto
sanaAyrml Mar 24, 2024
71d80b8
add ditto mkmmd experimetns
sanaAyrml Mar 24, 2024
2a53c58
solve infeasibility issue
sanaAyrml Mar 25, 2024
6802811
change range
sanaAyrml Mar 25, 2024
b7e4088
fix typo and pytests
sanaAyrml Mar 25, 2024
e492c16
fix fedisic experiment
sanaAyrml Mar 25, 2024
aba0bee
Update moon model assertion for ditto
sanaAyrml Mar 25, 2024
a994b8b
update ditto mkmmd
sanaAyrml Mar 25, 2024
4dd7fc2
update errore
sanaAyrml Mar 25, 2024
408e6d2
Update beta optimization
sanaAyrml Mar 25, 2024
515d79f
fix model saving
sanaAyrml Mar 25, 2024
a63d079
add l2 feature norms
sanaAyrml Mar 26, 2024
638e69e
Update pytests for clients
sanaAyrml Mar 26, 2024
d204c2c
Update syntax errors
sanaAyrml Mar 26, 2024
87b1fef
Update ditto mkmmd client
sanaAyrml Mar 26, 2024
db443e6
fix ditto mkmmd
sanaAyrml Mar 26, 2024
3c2acb4
Update script run
sanaAyrml Mar 26, 2024
75f5bbe
delete extra log
sanaAyrml Mar 26, 2024
ae387b5
update ditto mkmmd loss
sanaAyrml Mar 27, 2024
df1a780
fix mkmmd loss computing
sanaAyrml Mar 27, 2024
3793143
Merge branch 'main' into mkmdd_loss
sanaAyrml Mar 27, 2024
459aa2c
update norm computatio
sanaAyrml Mar 27, 2024
3e8dacc
fix l2 norm
sanaAyrml Mar 27, 2024
610c472
Update fenda mkmmd client
sanaAyrml Mar 27, 2024
75373df
Update research files
sanaAyrml Mar 27, 2024
fa036e2
Update doc strings
sanaAyrml Apr 9, 2024
f825301
fix pre commit issues
sanaAyrml Apr 9, 2024
90dcebf
fix static code checks failing
sanaAyrml Apr 10, 2024
1c6aec3
Merge branch 'main' into mkmdd_loss
sanaAyrml Apr 16, 2024
bb6a322
Updating lock file
sanaAyrml Apr 17, 2024
70336d5
Detach ditto mkmmd from ditto
sanaAyrml Apr 17, 2024
055d903
Fix ditto mkmmd client comments
sanaAyrml Apr 17, 2024
eeb3d55
separate mkmmd additions to moon
sanaAyrml Apr 18, 2024
d61a77f
Separate fenda mkmmd with fenda
sanaAyrml Apr 18, 2024
25b3974
Address some more comments
sanaAyrml Apr 18, 2024
ac594a2
Fix fenda and moon mkmmd doc strings
sanaAyrml Apr 18, 2024
ad9b31a
Merge branch 'main' into mkmdd_loss
sanaAyrml Apr 18, 2024
eae0065
Update mkmmd tests
sanaAyrml Apr 18, 2024
aeb67dd
Update research files
sanaAyrml Apr 18, 2024
2b41f70
update research clients
sanaAyrml Apr 18, 2024
68632c0
fix moon_mkmmd
sanaAyrml Apr 18, 2024
dea3b57
Add mr_mtl_mkkmmd_client.py
sanaAyrml Apr 19, 2024
fef1303
Add mr_mtl_research files
sanaAyrml Apr 19, 2024
7e37b72
Add normalization to mkmmd
sanaAyrml Apr 19, 2024
606ec7f
add local beta optimization option
sanaAyrml Apr 19, 2024
99ef911
Add feature extractor base model
sanaAyrml Apr 23, 2024
b12fe77
fix static code failure
sanaAyrml Apr 23, 2024
05d3610
fix bug in feature extractor
sanaAyrml Apr 23, 2024
d012dad
update mrmtl model implementation
sanaAyrml Apr 23, 2024
db2ac05
change feature extractor name
sanaAyrml Apr 23, 2024
3d91119
Add feature extractor buffer
sanaAyrml Apr 24, 2024
da1b61b
shift mkmmd loss computation to one front
sanaAyrml Apr 24, 2024
e6c17af
fix fed isic ditto mkmmd experiment
sanaAyrml Apr 24, 2024
0f9a22a
fix mrmtl example
sanaAyrml Apr 25, 2024
cc72a88
add fedheartdisease tests
sanaAyrml Apr 25, 2024
3e8c354
fix mrmtl bug
sanaAyrml Apr 25, 2024
8444667
fix mr_mtl bug
sanaAyrml Apr 25, 2024
18767c9
fix moon runs
sanaAyrml Apr 25, 2024
cecaf60
delete extra fed heart files
sanaAyrml Apr 25, 2024
3234c06
delete test
sanaAyrml Apr 25, 2024
fb91d4b
fix experiments
sanaAyrml Apr 25, 2024
ebbec10
add some extra logging
sanaAyrml Apr 25, 2024
bef1b4a
fix new slrm files
sanaAyrml Apr 26, 2024
5c9efff
fix bash file
sanaAyrml Apr 26, 2024
ce46dfb
Fix feature extractor
sanaAyrml Apr 26, 2024
2395e12
Fix pytest failing
sanaAyrml Apr 26, 2024
3850739
check smoke test failure
sanaAyrml Apr 27, 2024
f56390d
check smoke test failure
sanaAyrml Apr 27, 2024
2f7dbc3
check smoke test failure
sanaAyrml Apr 27, 2024
f5cfdd9
check smoke test failure
sanaAyrml Apr 27, 2024
f670b02
Fix smoke tests
sanaAyrml Apr 30, 2024
dffa1ad
add extra logging
sanaAyrml May 1, 2024
c284473
Fix logging
sanaAyrml May 1, 2024
9ca1704
Fix beta update
sanaAyrml May 3, 2024
2cede39
Merge branch 'main' into mkmdd_loss
sanaAyrml May 9, 2024
afd69df
Merge main
sanaAyrml May 9, 2024
45ca9ba
add deep_mmd_client
sanaAyrml May 27, 2024
701d710
Merge branch 'main' into mkmdd_loss
sanaAyrml May 27, 2024
c5ab04d
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
5894dd8
fix precommit changes
sanaAyrml May 27, 2024
2952364
Merge branch 'mkmdd_loss' of https://github.com/VectorInstitute/FL4He…
sanaAyrml May 27, 2024
c7a3da7
change deepmmd name
sanaAyrml May 27, 2024
333d9e9
fix deep mmd loss implementation
sanaAyrml May 27, 2024
e1f3de1
fix deep mmd
sanaAyrml May 27, 2024
80f3c5e
fix deep mmd loss
sanaAyrml May 27, 2024
27447fb
Fix deep mmd codes
sanaAyrml May 29, 2024
cc4579b
fix deep mmd loss
sanaAyrml May 29, 2024
b1bd794
add file for fedisic
sanaAyrml May 29, 2024
5654a8e
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2024
1125de2
fix deep mmd clients
sanaAyrml May 30, 2024
a0bd92b
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2024
9b5dfc3
Merge branch 'mkmdd_loss' into add_mkmmd_loss
sanaAyrml Jun 6, 2024
6cbcc25
Delete extra files
sanaAyrml Jun 6, 2024
255ab52
Merge branch 'main' into add_mkmmd_loss
sanaAyrml Jun 6, 2024
970ccea
delete extra changes
sanaAyrml Jun 6, 2024
f69bba1
locking poetry
sanaAyrml Jun 7, 2024
f8854f2
fixing tests
sanaAyrml Jun 7, 2024
cc05d6f
fix privacy estimation
sanaAyrml Jun 7, 2024
f14dad2
try to pass test
sanaAyrml Jun 7, 2024
ab064f2
Fix pytest failing
sanaAyrml Jun 7, 2024
6fe5280
Seeing if I can pip list within this test.
emersodb Jun 12, 2024
9d36c0d
Merge branch 'main' into add_mkmmd_loss
emersodb Jun 12, 2024
a49aa70
Reverting changes because github is being a pain.
emersodb Jun 12, 2024
d5381c3
Trying to dump pip in a test again.
emersodb Jun 12, 2024
e3a997d
Putting code back to the way it was.
emersodb Jun 12, 2024
02bf3ef
Apply david's comments changes
sanaAyrml Jun 20, 2024
2c84089
Merge branch 'main' into add_mkmmd_loss
sanaAyrml Jun 27, 2024
31aef6b
Add tests for feature extractor buffer
sanaAyrml Jun 28, 2024
bf34662
Merge branch 'main' into add_mkmmd_loss
sanaAyrml Jun 28, 2024
14df7da
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2024
Add moon mkmmd example
sanaAyrml committed Feb 5, 2024
commit 299e9394336de5dff5f4ffdfe39b7d8da4d3f15d
2 changes: 1 addition & 1 deletion research/flamby/fed_isic2019/fenda_mkmmd/client.py
@@ -23,7 +23,7 @@
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import BalancedAccuracy, Metric
from fl4health.utils.random import set_all_random_seeds
-from research.flamby.fed_isic2019.fenda.fenda_model import FedIsic2019FendaModel
+from research.flamby.fed_isic2019.fenda_mkmmd.fenda_model import FedIsic2019FendaModel
from research.flamby.flamby_data_utils import construct_fedisic_train_val_datasets


@@ -69,8 +69,8 @@ MU=$7
# Create the artifact directory
mkdir "${ARTIFACT_DIR}"

-RUN_NAMES=( "Run1" "Run2")
-SEEDS=(2021 2022)
+RUN_NAMES=( "Run1" )
+SEEDS=(2021 )

echo "Python Venv Path: ${VENV_PATH}"

6 changes: 3 additions & 3 deletions research/flamby/fed_isic2019/fenda_mkmmd/run_hp_sweep.sh
@@ -27,12 +27,12 @@ VENV_PATH=$4

# FedISIC LR Hyperparmeters from paper are not suitable for AdamW
LR_VALUES=( 0.001 )
-MU_VALUES=( 0.0 0.1 1 10)
+MU_VALUES=( 0.1 1 10)

SERVER_PORT=8100

# Create sweep folder
-SWEEP_DIRECTORY="${ARTIFACT_DIR}hp_sweep_results_mkmmd"
+SWEEP_DIRECTORY="${ARTIFACT_DIR}hp_sweep_results_mkmmd_new"
echo "Creating sweep folder at ${SWEEP_DIRECTORY}"
mkdir ${SWEEP_DIRECTORY}

@@ -45,7 +45,7 @@ for LR_VALUE in "${LR_VALUES[@]}"; do
mkdir "${EXPERIMENT_DIRECTORY}"
SERVER_ADDRESS="0.0.0.0:${SERVER_PORT}"
echo "Server Address: ${SERVER_ADDRESS}"
-SBATCH_COMMAND="research/flamby/fed_isic2019/fenda/run_fold_experiment.slrm \
+SBATCH_COMMAND="research/flamby/fed_isic2019/fenda_mkmmd/run_fold_experiment.slrm \
${SERVER_CONFIG_PATH} \
${EXPERIMENT_DIRECTORY} \
${DATASET_DIR} \
2 changes: 1 addition & 1 deletion research/flamby/fed_isic2019/fenda_mkmmd/server.py
@@ -11,7 +11,7 @@
from fl4health.utils.config import load_config
from fl4health.utils.functions import get_all_model_parameters
from fl4health.utils.random import set_all_random_seeds
-from research.flamby.fed_isic2019.fenda.fenda_model import FedIsic2019FendaModel
+from research.flamby.fed_isic2019.fenda_mkmmd.fenda_model import FedIsic2019FendaModel
from research.flamby.flamby_servers.personal_server import PersonalServer
from research.flamby.utils import (
evaluate_metrics_aggregation_fn,
24 changes: 24 additions & 0 deletions research/flamby/fed_isic2019/moon_mkmmd/README.md
@@ -0,0 +1,24 @@
### Running hyperparameter sweep

To run the hyperparameter sweep, you simply run the command

```bash
./research/flamby/fed_isic2019/moon_mkmmd/run_hp_sweep.sh \
path_to_config.yaml \
path_to_folder_for_artifacts/ \
path_to_folder_for_dataset/ \
path_to_desired_venv/
```

from the top-level directory of the repository.

An example invocation is:
```bash
./research/flamby/fed_isic2019/moon_mkmmd/run_hp_sweep.sh \
research/flamby/fed_isic2019/moon_mkmmd/config.yaml \
research/flamby/fed_isic2019/moon_mkmmd/ \
/Users/david/Desktop/FLambyDatasets/fedisic2019/ \
/h/demerson/vector_repositories/fl4health_env/
```

In order to manipulate the grid search being conducted, you need to change the parameters for `lr` or `mu` in the `run_hp_sweep.sh` script directly. These represent the client-side learning rate and the contrastive loss weight, respectively.
Empty file.
158 changes: 158 additions & 0 deletions research/flamby/fed_isic2019/moon_mkmmd/client.py
@@ -0,0 +1,158 @@
import argparse
import os
from logging import INFO
from pathlib import Path
from typing import Optional, Sequence, Tuple

import flwr as fl
import torch
import torch.nn as nn
from flamby.datasets.fed_isic2019 import BATCH_SIZE, LR, NUM_CLIENTS, BaselineLoss
from flwr.common.logger import log
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer
from fl4health.clients.moon_client import MoonClient
from fl4health.utils.losses import LossMeterType
from fl4health.utils.metrics import BalancedAccuracy, Metric
from fl4health.utils.random import set_all_random_seeds
from research.flamby.fed_isic2019.moon_mkmmd.moon_model import FedIsic2019MoonModel
from research.flamby.flamby_data_utils import construct_fedisic_train_val_datasets


class FedIsic2019MoonClient(MoonClient):
    def __init__(
        self,
        data_path: Path,
        metrics: Sequence[Metric],
        device: torch.device,
        client_number: int,
        learning_rate: float,
        loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
        mkmmd_loss_weights: Tuple[float, float] = (10, 10),
        checkpointer: Optional[TorchCheckpointer] = None,
    ) -> None:
        super().__init__(
            data_path=data_path,
            metrics=metrics,
            device=device,
            loss_meter_type=loss_meter_type,
            checkpointer=checkpointer,
            mkmmd_loss_weights=mkmmd_loss_weights,
        )
        self.client_number = client_number
        self.learning_rate: float = learning_rate

        assert 0 <= client_number < NUM_CLIENTS
        log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}")

    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
        train_dataset, validation_dataset = construct_fedisic_train_val_datasets(
            self.client_number, str(self.data_path)
        )
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
        return train_loader, val_loader

    def get_model(self, config: Config) -> nn.Module:
        model: nn.Module = FedIsic2019MoonModel().to(self.device)
        return model

    def get_optimizer(self, config: Config) -> Optimizer:
        return torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)

    def get_criterion(self, config: Config) -> _Loss:
        return BaselineLoss()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FL Client Main")
    parser.add_argument(
        "--artifact_dir",
        action="store",
        type=str,
        help="Path to save client artifacts such as logs and model checkpoints",
        required=True,
    )
    parser.add_argument(
        "--dataset_dir",
        action="store",
        type=str,
        help="Path to the preprocessed FedIsic2019 Dataset (ex. path/to/fedisic2019)",
        required=True,
    )
    parser.add_argument(
        "--run_name",
        action="store",
        help="Name of the run, model checkpoints will be saved under a subfolder with this name",
        required=True,
    )
    parser.add_argument(
        "--server_address",
        action="store",
        type=str,
        help="Server Address for the clients to communicate with the server through",
        default="0.0.0.0:8080",
    )
    parser.add_argument(
        "--client_number",
        action="store",
        type=int,
        help="Number of the client for dataset loading (should be 0-5 for FedIsic2019)",
        required=True,
    )
    parser.add_argument(
        "--learning_rate", action="store", type=float, help="Learning rate for local optimization", default=LR
    )
    parser.add_argument(
        "--seed",
        action="store",
        type=int,
        help="Seed for the random number generators across python, torch, and numpy",
        required=False,
    )
    parser.add_argument(
        "--mu",
        action="store",
        type=float,
        help="Weight for the auxiliary losses",
        required=False,
    )
    parser.add_argument(
        "--gamma",
        action="store",
        type=float,
        help="Weight for the auxiliary losses",
        required=False,
    )
    args = parser.parse_args()

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(INFO, f"Device to be used: {DEVICE}")
    log(INFO, f"Server Address: {args.server_address}")
    log(INFO, f"Learning Rate: {args.learning_rate}")

    # Set the random seed for reproducibility
    set_all_random_seeds(args.seed)

    checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
    checkpoint_name = f"client_{args.client_number}_best_model.pkl"
    checkpointer = BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name, maximize=False)

    client = FedIsic2019MoonClient(
        data_path=Path(args.dataset_dir),
        metrics=[BalancedAccuracy("FedIsic2019_balanced_accuracy")],
        device=DEVICE,
        client_number=args.client_number,
        learning_rate=args.learning_rate,
        checkpointer=checkpointer,
        mkmmd_loss_weights=(args.mu, args.gamma),
    )

    fl.client.start_numpy_client(server_address=args.server_address, client=client)

    # Shutdown the client gracefully
    client.shutdown()
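The client above only wires the `(mu, gamma)` arguments into `mkmmd_loss_weights`; the MK-MMD loss itself is defined elsewhere in this PR. As a hedged, illustrative sketch only — fixed kernel weights `betas` rather than the QP-optimized betas the commits mention, and plain Python instead of the repository's tensor-based implementation — a multi-kernel MMD estimate between two feature sets looks like:

```python
import math
from itertools import product


def multi_kernel(a, b, bandwidths, betas):
    # Mixture of RBF kernels: k(a, b) = sum_i beta_i * exp(-||a - b||^2 / (2 * sigma_i^2))
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return sum(beta * math.exp(-sq_dist / (2.0 * sigma * sigma)) for sigma, beta in zip(bandwidths, betas))


def mkmmd_squared(xs, ys, bandwidths=(0.5, 1.0, 2.0), betas=(1 / 3, 1 / 3, 1 / 3)):
    # Biased estimator of MMD^2 = E[k(x, x')] + E[k(y, y')] - 2 * E[k(x, y)]
    def mean_kernel(us, vs):
        return sum(multi_kernel(u, v, bandwidths, betas) for u, v in product(us, vs)) / (len(us) * len(vs))

    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)
```

Identical feature sets yield an MMD² of zero, and the value grows as the two feature distributions drift apart, which is what makes it usable as an alignment penalty between local and reference features in the clients above.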
8 changes: 8 additions & 0 deletions research/flamby/fed_isic2019/moon_mkmmd/config.yaml
@@ -0,0 +1,8 @@
# Parameters that describe server
n_server_rounds: 15 # The number of rounds to run FL

# Parameters that describe clients
n_clients: 6 # The number of clients in the FL experiment
local_epochs: 1 # The number of epochs to complete for client (NOT USED FOR FLAMBY)
batch_size: 64 # The batch size for client training (NOT USED FOR FLAMBY)
local_steps: 100 # The number of local training steps to perform.
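For illustration, the YAML above parses into a flat dictionary. A minimal stand-in using plain PyYAML (the repository's server code instead uses `fl4health.utils.config.load_config`, which may add validation on top):

```python
import yaml

# The same keys as the config.yaml added in this commit
CONFIG_TEXT = """
n_server_rounds: 15
n_clients: 6
local_epochs: 1
batch_size: 64
local_steps: 100
"""

config = yaml.safe_load(CONFIG_TEXT)
# config is a plain dict, e.g. config["n_clients"] is the client count
```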
83 changes: 83 additions & 0 deletions research/flamby/fed_isic2019/moon_mkmmd/moon_model.py
@@ -0,0 +1,83 @@
from typing import Optional

import torch
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import url_map
from torch.utils import model_zoo

from fl4health.model_bases.moon_base import MoonModel
from research.flamby.utils import shutoff_batch_norm_tracking


def from_pretrained(model_name: str, in_channels: int = 3, include_top: bool = False) -> EfficientNet:
    # There is a bug in the EfficientNet implementation if you want to strip off the top layer of the network, but
    # still load the pre-trained weights. So we do it ourselves here.
    model = EfficientNet.from_name(model_name, include_top=include_top)
    state_dict = model_zoo.load_url(url_map[model_name])
    state_dict.pop("_fc.weight")
    state_dict.pop("_fc.bias")
    model.load_state_dict(state_dict, strict=False)
    model._change_in_channels(in_channels)
    return model


class HeadClassifier(nn.Module):
    """Moon head module"""

    def __init__(self, stack_output_dimension: int):
        super().__init__()
        self.fc1 = nn.Linear(stack_output_dimension, 8)
        self.dropout = nn.Dropout(0.2)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        x = self.dropout(input_tensor)
        x = self.fc1(x)
        return x


class BaseEfficientNet(nn.Module):
    """Moon feature extractor module

    We use the EfficientNets architecture that many participants in the ISIC
    competition have identified to work best. See the
    [reference paper](https://arxiv.org/abs/1905.11946). Thank you to
    [Luke Melas-Kyriazi](https://github.com/lukemelas) for his
    [pytorch reimplementation of EfficientNets](https://github.com/lukemelas/EfficientNet-PyTorch).

    When loading the EfficientNet-B0 model, we strip off the FC layer to use the model as a feature extractor.
    There is an option to freeze a subset of the layers to reduce the number of trainable parameters. However,
    it is not used in the Moon experiments.
    """

    def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool = False):
        super().__init__()
        # include_top ensures that we just use feature extraction in the forward pass
        self.base_model = from_pretrained("efficientnet-b0", include_top=False)
        if frozen_blocks:
            self.freeze_layers(frozen_blocks)
        if turn_off_bn_tracking:
            shutoff_batch_norm_tracking(self.base_model)

    def freeze_layers(self, frozen_blocks: int) -> None:
        # We freeze the bottom layers of the network. We always freeze the _conv_stem module, the _bn0 module, and
        # then we iterate through the blocks, freezing the specified number up to 15 (all of them).

        # Freeze the first two layers
        self.base_model._modules["_conv_stem"].requires_grad_(False)
        self.base_model._modules["_bn0"].requires_grad_(False)
        # Now we iterate through the block modules and freeze a certain number of them.
        frozen_blocks = min(frozen_blocks, 15)
        for block_index in range(frozen_blocks):
            self.base_model._modules["_blocks"][block_index].requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.base_model(x)
        x = x.flatten(start_dim=1)
        return x


class FedIsic2019MoonModel(MoonModel):
    def __init__(self, frozen_blocks: Optional[int] = None, turn_off_bn_tracking: bool = False) -> None:
        base_module = BaseEfficientNet(frozen_blocks, turn_off_bn_tracking=turn_off_bn_tracking)
        head_module = HeadClassifier(1280)
        super().__init__(base_module, head_module)
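`FedIsic2019MoonModel` composes a base feature extractor with a classification head via `MoonModel`. As a rough illustration of that composition pattern only — `SimpleMoonModel` below is a hypothetical stand-in, not fl4health's actual `MoonModel`, whose output contract is richer — the point is that the forward pass exposes both predictions and intermediate features, so auxiliary losses such as the contrastive or MK-MMD penalties can operate on the features:

```python
import torch
import torch.nn as nn


class SimpleMoonModel(nn.Module):
    # Hypothetical simplified composition: base module produces features,
    # head module produces predictions, and both are returned so that
    # auxiliary (contrastive / MK-MMD) losses can consume the features.
    def __init__(self, base_module: nn.Module, head_module: nn.Module) -> None:
        super().__init__()
        self.base_module = base_module
        self.head_module = head_module

    def forward(self, x: torch.Tensor):
        features = self.base_module(x)
        predictions = self.head_module(features)
        return predictions, features


# Tiny linear stand-ins for the EfficientNet base and HeadClassifier head
model = SimpleMoonModel(nn.Linear(4, 8), nn.Linear(8, 3))
preds, feats = model(torch.randn(2, 4))
```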