Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR3: Add deep_mmd_loss files #170

Open
wants to merge 2 commits into
base: add_mkmmd_loss
Choose a base branch
from

Conversation

sanaAyrml
Copy link
Collaborator

PR Type

[Feature]

Short Description

This is a tentative implementation for deep mmd loss.

Tests Added

No tests added yet.

@sanaAyrml sanaAyrml requested a review from emersodb June 7, 2024 06:38
each batch. Defaults to LossMeterType.AVERAGE.
checkpointer (Optional[TorchCheckpointer], optional): Checkpointer to be used for client-side
checkpointing. Defaults to None.
metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is probably an issue with the DittoClient docstring as well, but the metrics_reporter is technically not an arg in this implementation.

size_feature_extraction_layers: Dict[str, int] = {},
) -> None:
"""
This client implements the DEEP-MMD loss function in the Ditto framework.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super minor, but is DEEP an acronym here? If not, I'd say we can use the capitalization scheme of Deep-MMD throughout?

deep_mmd_loss_weight (float, optional): weight applied to the DEEP-MMD loss. Defaults to 10.0.
flatten_feature_extraction_layers (Dict[str, bool], optional): Dictionary of layers to extract features
from them what is the flattened feature size. Defaults to {}. If it is -1 then the layer is not
flattened.
Copy link
Collaborator

@emersodb emersodb Jun 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like size_feature_extraction_layers is missing from the args documentation here or maybe it's sort of squashed together in the docs for flatten_feature_extraction_layers?

lam: float = 1.0,
deep_mmd_loss_weight: float = 10.0,
flatten_feature_extraction_layers: Dict[str, bool] = {},
size_feature_extraction_layers: Dict[str, int] = {},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We noted that we don't want mutable defaults in the previous PR. Just wanted to put this here as a reminder to change these around too 🙂

features = self.local_feature_extractor.get_extracted_features()
if self.deep_mmd_loss_weight != 0:
# Compute the features of the init_global_model
_ = self.init_global_model(input)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to catch this with _ you're not going to store it anyway.

EvaluationLosses: an instance of EvaluationLosses containing checkpoint loss and additional losses
indexed by name.
"""
for layer in self.flatten_feature_extraction_layers.keys():
Copy link
Collaborator

@emersodb emersodb Jun 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're going to be indexing into self.deep_mmd_losses anyway, could we simply do

for layer_loss_module in self.deep_mmd_losses.values():
    layer_loss_module.training = False

For Ditto, we do this process in validate and train_by_steps/train_by_epochs for the global model, maybe we can just do this there?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's still worth overriding compute_evaluation_loss and compute_training_loss and asserting that all layer_loss_module.training == False or vice versa though to be safe 🙂

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also might be missing this, but I don't see where we set layer_loss_module.training to True in the client. Based on the loss code, this would mean that we won't run training of the deep kernels after the first server round, which I think we want to keep doing?

if self.deep_mmd_loss_weight != 0:
# Compute DEEP-MMD loss
total_deep_mmd_loss = torch.tensor(0.0, device=self.device)
for layer in self.flatten_feature_extraction_layers.keys():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above, since we're not accessing components of flatten_feature_extraction_layers, it might be more straightforward to do:

for layer, layer_loss_module in self.deep_mmd_losses.items():
    layer_deep_mmd_loss = layer_loss_module(
        features[layer], features[" ".join(["init_global", layer])]
    )
    additional_losses["_".join(["deep_mmd_loss", layer])] = layer_deep_mmd_loss
    total_deep_mmd_loss += layer_deep_mmd_loss


def __init__(self, x_in: int, H: int, x_out: int):
"""Init latent features."""
super(ModelLatentF, self).__init__()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this isn't all your code, but this can be changed to super().__init__()

def __init__(self, x_in: int, H: int, x_out: int):
"""Init latent features."""
super(ModelLatentF, self).__init__()
self.restored = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is restored used anywhere else?

class ModelLatentF(torch.nn.Module):
"""Latent space for both domains."""

def __init__(self, x_in: int, H: int, x_out: int):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if H here has a specific meaning elsewhere, but it essentially looks like the hidden dimension of this kernel network. So maybe we can just call it hidden_dimension?

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward the LeNet."""
fealant = self.latent(input)
return fealant
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol...weird word

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe feature_latent_map?

super().__init__()
self.device = device
self.lr = lr
self.layer_name = layer_name
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any utility to the layer_name here? It seems like it's not used, but I may be missing something.

self.sigma0OPT.requires_grad = self.training

# Initialize optimizers
self.optimizer_F = torch.optim.Adam(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can use AdamW here, just for the decay correction?

self.featurizer = ModelLatentF(input_size, hidden_size, output_size).to(self.device)

# Initialize parameters
self.epsilonOPT: torch.Tensor = torch.log(torch.from_numpy(np.random.rand(1) * 10 ** (-10)).to(self.device))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to go with our naming conventions, maybe we can do *_opt for all of these variables with OPT suffixes?

list(self.featurizer.parameters()) + [self.epsilonOPT] + [self.sigmaOPT] + [self.sigma0OPT], lr=self.lr
)

def Pdist2(self, x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe expand this to pairwise_distiance_squared?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we don't leverage the fact that y can be none to get the distances of x with itself. Maybe we just drop that option and require y to be passed to simplify this function.

else:
y = x
y_norm = x_norm.view(1, -1)
Pdist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same with this, maybe just paired_distance?

return

def compute_kernel(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
"""Train the kernel."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually not training the kernel, but rather doing pure inference I think?


return mmd_value_temp

def forward(self, Xs: torch.Tensor, Xt: torch.Tensor) -> torch.Tensor:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So based on our implementation, am I correct in thinking that during training, we'll also be updating the kernel values at each step based on the batches? If this doesn't quite work, another thing we could do is update the deep kernels periodically, like we do with MKMMD where we "optimize" the kernels for a frozen set of models.


def MMDu(
self,
Fea: torch.Tensor,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe rename to features?

# Compute output of deep network
model_output = self.featurizer(features)
# Compute epsilon, sigma and sigma_0
ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename epsilon and note that it is the epsilon in $\kappa_w(x, y)$ in the paper

self,
Fea: torch.Tensor,
len_s: int,
Fea_org: torch.Tensor,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename original_features for clarity.

# Compute epsilon, sigma and sigma_0
ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
sigma = self.sigmaOPT**2
sigma0_u = self.sigma0OPT**2
Copy link
Collaborator

@emersodb emersodb Jun 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

based on the implementation of MMDu I would suggest renaming sigma0 to sigma_phi, sigma0OPT to sigma_phi_opt and sigma0_u to sigma_phi (since there doesn't seem to be any reason to have _u in there anyway. Similarly, anything that is sigma or sigmaOPT can be sigma_q or sigma_q_opt to match the notation of the paper.

V1 = torch.dot(hh.sum(1) / ny, hh.sum(1) / ny) / ny
V2 = (hh).sum() / (nx) / nx
varEst = 4 * (V1 - V2**2)
return mmd2, varEst, Kxyxy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct me if I'm wrong, but in all instances where this function is called, we don't make use of Kxyxy. Can we just drop it from the return values?

Kxyxy = torch.cat((Kxxy, Kyxy), 0)
nx = Kx.shape[0]
ny = Ky.shape[0]
is_unbiased = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps, rather than hardcoding this here, we can turn it into a class variable and default it to True (i.e. self.is_unbiased? That way it can be pulled out of this function.

Fea_org=features.view(features.shape[0], -1),
sigma=sigma,
sigma0=sigma0_u,
epsilon=ep,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want to set is_var_computed = False in this call to save some calculation

Fea_org=features.view(features.shape[0], -1),
sigma=sigma,
sigma0=sigma0_u,
epsilon=ep,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe explicitly set is_var_computed=True in this call rather than leaving it to the default?

Dxx = self.Pdist2(X, X)
Dyy = self.Pdist2(Y, Y)
Dxy = self.Pdist2(X, Y)
Dxx_org = self.Pdist2(X_org, X_org)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can put these inside the if is_smooth block, since they only need to be computed if that is true. Otherwise they aren't used.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also rename these to expand org to original just to be explicit

X = Fea[0:len_s, :] # fetch the sample 1 (features of deep networks)
Y = Fea[len_s:, :] # fetch the sample 2 (features of deep networks)
X_org = Fea_org[0:len_s, :] # fetch the original sample 1
Y_org = Fea_org[len_s:, :] # fetch the original sample 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename these to X_original and Y_original?

Y = Fea[len_s:, :] # fetch the sample 2 (features of deep networks)
X_org = Fea_org[0:len_s, :] # fetch the original sample 1
Y_org = Fea_org[len_s:, :] # fetch the original sample 2
L = 1 # generalized Gaussian (if L>1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than hardcoding here, we can make this a class variable that defaults to 1?

@@ -0,0 +1,224 @@
from typing import Optional, Tuple
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's worth creating some tests for this loss function before doing any of the refactors I suggest below just to make sure we don't inadvertently break anything in the calculations. It could be as simple as generating a random X, Y and recording the outputs of each function. That way we know we did something if those outputs change with our refactors.

Kx = torch.exp(-Dxx / sigma0)
Ky = torch.exp(-Dyy / sigma0)
Kxy = torch.exp(-Dxy / sigma0)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe here just add a comment noting what Kx, Ky, Kxy represent. That is, they are $k_w(x_i, x_j)$, $k_w(y_i, y_j)$ and $k_w(x_i,y_j)$ for all i, j in the samples of X and Y

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we can also name these to be a bit more verbose. That is, Kx -> kernel_xx K_y -> kernel_yy and K_xy -> kernel_xy. We could do the same for Dx, Dy, Dxy -> distance_xx, distance_yy, distance_xy

"""compute value of MMD and std of MMD using kernel matrix."""
Kxxy = torch.cat((Kx, Kxy), 1)
Kyxy = torch.cat((Kxy.transpose(0, 1), Ky), 1)
Kxyxy = torch.cat((Kxxy, Kyxy), 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's worth a comment that this final matrix (Kxyxy) is composed of the blocks [[Kx, Kxy], [Kxy^T, Ky]]. Just to help any reader understand what's actually being constructed here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it seems like this tensor isn't used anywhere, which would make its formation unnecessary. Maybe I'm missing something though?

xx = torch.div((torch.sum(Kx)), (nx * nx))
yy = torch.div((torch.sum(Ky)), (ny * ny))
# one-sample U-statistic.
if use_1sample_U:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a no-op? That is, regardless of whether this boolean is true or false, we do the same calculation?

ny = Ky.shape[0]
is_unbiased = True
if is_unbiased:
xx = torch.div((torch.sum(Kx) - torch.sum(torch.diag(Kx))), (nx * (nx - 1)))
Copy link
Collaborator

@emersodb emersodb Jun 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe note that these calculations correspond to the calculation of $\hat{\text{MMD}}_u^2$ in equation 2 of the paper?

return fealant


class DeepMmdLoss(torch.nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's worth recording the paper that this is being taken from somewhere here 🙂

# one-sample U-statistic.
if use_1sample_U:
xy = torch.div((torch.sum(Kxy) - torch.sum(torch.diag(Kxy))), (nx * (ny - 1)))
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having trouble figuring out where the calculation below is described in their paper. It doesn't seem like it's a well defined calculation to be honest. I'm tempted to suggest we remove it and only offer the is_unbiased boolean to switch between the estimates $\hat{\text{MMD}}_u^2$ and $\hat{\text{MMD}}_b^2$ (which is what is computed below)

else:
xy = torch.div(torch.sum(Kxy), (nx * ny))
mmd2 = xx - 2 * xy + yy
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's worth noting that we calculate $\hat{\text{MMD}}_b^2$ from the paper (defined below Equation (2)) here for clarity

xy = torch.div((torch.sum(Kxy) - torch.sum(torch.diag(Kxy))), (nx * (ny - 1)))
else:
xy = torch.div(torch.sum(Kxy), (nx * ny))
mmd2 = xx - 2 * xy + yy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This calculation can be moved outside the if and else blocks rather than being repeated in both blocks.

mmd2 = xx - 2 * xy + yy
if not is_var_computed:
return mmd2, None, Kxyxy
hh = Kx + Ky - Kxy - Kxy.transpose(0, 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be more consistent with the paper notation, I'd rename hh to h_ij?

Copy link
Collaborator

@emersodb emersodb Jun 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add a comment above that we're computing the estimate in Equation (5) from the paper (without the lambda shift).

if not is_var_computed:
return mmd2, None, Kxyxy
hh = Kx + Ky - Kxy - Kxy.transpose(0, 1)
V1 = torch.dot(hh.sum(1) / ny, hh.sum(1) / ny) / ny
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a weird way to write this. My suggestion, to make it look more like Equation (5) from the paper is to do

V1 = (4.0/ny**3)*(torch.dot(hh.sum(1), hh.sum(1)))
V2 = (4.0/nx**4)*(hh.sum()**2)
variance_estimate = V1 - V2

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed in a comment below, but I would move the addition of the lambda shift into this calculation to better align it with Equation (5) from the paper. That is

variance_estimate = V1 - V2 + (10 ** (-8))

# ------------------------------
# Train deep network for MMD-D
# ------------------------------
# Initialize optimizer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't really initializing the optimizer. It's just clearing the gradients.

if mmd_var_temp is None:
raise AssertionError("Error: Variance of MMD is not computed. Please set is_var_computed=True.")
mmd_std_temp = torch.sqrt(mmd_var_temp + 10 ** (-8))
STAT_u = torch.div(-1 * mmd_value_temp, mmd_std_temp)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with calling this stat_u, but I would kill the all caps 🙂.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a comment that we're trying to maximize the ratio (mmd_value_stimate/mmd_std_estimate), so we multiply by -1 to be compatible with gradient descent.

ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
sigma = self.sigmaOPT**2
sigma0_u = self.sigma0OPT**2
# Compute Compute J (STAT_u)
Copy link
Collaborator

@emersodb emersodb Jun 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add to this comment that what we're forming is the ratio in Equation (4) from the paper in stat_u which is referred to as $\hat{J}_{\lambda}$ therein.

sigma = self.sigmaOPT**2
sigma0_u = self.sigma0OPT**2
# Compute Compute J (STAT_u)
mmd_value_temp, mmd_var_temp, _ = self.MMDu(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why they are suffixing with _temp here. Maybe we just call it ..._estimate?

)
if mmd_var_temp is None:
raise AssertionError("Error: Variance of MMD is not computed. Please set is_var_computed=True.")
mmd_std_temp = torch.sqrt(mmd_var_temp + 10 ** (-8))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm good with doing the sqrt here, since its how we form $\hat{J}_{\lambda}$. However, I'd move the lambda shift here into the formation of the estimate to collocate it better with the estimate calculation. So this would become something like

mmd_std_estimate = torch.sqrt(mmd_var_estimate)

STAT_u.backward()
# Update weights using gradient descent
self.optimizer_F.step()
return
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Return is unnecessary here.

sigma = self.sigmaOPT**2
sigma0_u = self.sigma0OPT**2
# Compute Compute J (STAT_u)
mmd_value_temp, _, _ = self.MMDu(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As with the training function, I'd suggest renaming this to mmd_value_estimate from _temp

checkpointer=checkpointer,
lam=lam,
deep_mmd_loss_weight=deep_mmd_loss_weight,
flatten_feature_extraction_layers={key: True for key in size_feature_extraction_layers},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also form this dictionary above the super call as you're doing with size_feature_extraction_layers?

"--mu",
action="store",
type=float,
help="Weight for the mkmmd losses",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deep MMD losses 🙂

help="Weight for the mkmmd losses",
required=False,
)
parser.add_argument(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this argument is being used anywhere?

required=False,
default=1,
)
parser.add_argument(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, I don't think this arg is used anywhere either

# client_side_learning_rate_value \
# lambda value \
# mu value \
# depp_mmd_loss_depth value \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Johnny Depp MMD Loss!

# /h/demerson/vector_repositories/fl4health_env/ \
# 0.0001 \
# 0.01 \
# 0.1 \
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the depth argument is missing in this example?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants