PR3: Add deep_mmd_loss files #170
base: add_mkmmd_loss
        each batch. Defaults to LossMeterType.AVERAGE.
    checkpointer (Optional[TorchCheckpointer], optional): Checkpointer to be used for client-side
        checkpointing. Defaults to None.
    metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics

I think this is probably an issue with the DittoClient docstring as well, but the `metrics_reporter` is technically not an arg in this implementation.

        size_feature_extraction_layers: Dict[str, int] = {},
    ) -> None:
        """
        This client implements the DEEP-MMD loss function in the Ditto framework.

Super minor, but is DEEP an acronym here? If not, I'd say we can use the capitalization scheme of Deep-MMD throughout?

    deep_mmd_loss_weight (float, optional): weight applied to the DEEP-MMD loss. Defaults to 10.0.
    flatten_feature_extraction_layers (Dict[str, bool], optional): Dictionary of layers to extract features
        from them what is the flattened feature size. Defaults to {}. If it is -1 then the layer is not
        flattened.

It looks like `size_feature_extraction_layers` is missing from the args documentation here, or maybe it's sort of squashed together in the docs for `flatten_feature_extraction_layers`?

    lam: float = 1.0,
    deep_mmd_loss_weight: float = 10.0,
    flatten_feature_extraction_layers: Dict[str, bool] = {},
    size_feature_extraction_layers: Dict[str, int] = {},

We noted that we don't want mutable defaults in the previous PR. Just wanted to put this here as a reminder to change these around too 🙂
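
For reference, a minimal sketch of the usual way around mutable defaults: default to `None` and build a fresh dictionary inside the constructor. `ExampleClient` is a hypothetical stand-in, not the client class from this PR; only the two dictionary arguments are taken from the diff.

```python
from typing import Dict, Optional


class ExampleClient:
    """Hypothetical stand-in; only the default-handling pattern matters here."""

    def __init__(
        self,
        flatten_feature_extraction_layers: Optional[Dict[str, bool]] = None,
        size_feature_extraction_layers: Optional[Dict[str, int]] = None,
    ) -> None:
        # A `{}` default is created once at function definition time and then
        # shared by every call, so each instance gets its own copy instead.
        self.flatten_feature_extraction_layers: Dict[str, bool] = (
            dict(flatten_feature_extraction_layers)
            if flatten_feature_extraction_layers is not None
            else {}
        )
        self.size_feature_extraction_layers: Dict[str, int] = (
            dict(size_feature_extraction_layers)
            if size_feature_extraction_layers is not None
            else {}
        )
```

With this pattern, mutating the dictionary on one instance cannot leak into another instance that also relied on the default.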

    features = self.local_feature_extractor.get_extracted_features()
    if self.deep_mmd_loss_weight != 0:
        # Compute the features of the init_global_model
        _ = self.init_global_model(input)

I don't think you need to catch this with `_`; you're not going to store it anyway.

        EvaluationLosses: an instance of EvaluationLosses containing checkpoint loss and additional losses
            indexed by name.
    """
    for layer in self.flatten_feature_extraction_layers.keys():

If you're going to be indexing into `self.deep_mmd_losses` anyway, could we simply do

    for layer_loss_module in self.deep_mmd_losses.values():
        layer_loss_module.training = False

For Ditto, we do this process in `validate` and `train_by_steps`/`train_by_epochs` for the global model, maybe we can just do this there?

I think it's still worth overriding `compute_evaluation_loss` and `compute_training_loss` and asserting that all `layer_loss_module.training == False` or vice versa though to be safe 🙂

I also might be missing this, but I don't see where we set `layer_loss_module.training` to `True` in the client. Based on the loss code, this would mean that we won't run training of the deep kernels after the first server round, which I think we want to keep doing?

    if self.deep_mmd_loss_weight != 0:
        # Compute DEEP-MMD loss
        total_deep_mmd_loss = torch.tensor(0.0, device=self.device)
        for layer in self.flatten_feature_extraction_layers.keys():

As above, since we're not accessing components of `flatten_feature_extraction_layers`, it might be more straightforward to do:

    for layer, layer_loss_module in self.deep_mmd_losses.items():
        layer_deep_mmd_loss = layer_loss_module(
            features[layer], features[" ".join(["init_global", layer])]
        )
        additional_losses["_".join(["deep_mmd_loss", layer])] = layer_deep_mmd_loss
        total_deep_mmd_loss += layer_deep_mmd_loss

    def __init__(self, x_in: int, H: int, x_out: int):
        """Init latent features."""
        super(ModelLatentF, self).__init__()

I know this isn't all your code, but this can be changed to super().__init__()

    def __init__(self, x_in: int, H: int, x_out: int):
        """Init latent features."""
        super(ModelLatentF, self).__init__()
        self.restored = False

Is restored used anywhere else?

    class ModelLatentF(torch.nn.Module):
        """Latent space for both domains."""

        def __init__(self, x_in: int, H: int, x_out: int):

I'm not sure if `H` here has a specific meaning elsewhere, but it essentially looks like the hidden dimension of this kernel network. So maybe we can just call it `hidden_dimension`?

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Forward the LeNet."""
        fealant = self.latent(input)
        return fealant

lol...weird word

maybe `feature_latent_map`?

    super().__init__()
    self.device = device
    self.lr = lr
    self.layer_name = layer_name

Any utility to the `layer_name` here? It seems like it's not used, but I may be missing something.

    self.sigma0OPT.requires_grad = self.training

    # Initialize optimizers
    self.optimizer_F = torch.optim.Adam(

Maybe we can use AdamW here, just for the decay correction?

    self.featurizer = ModelLatentF(input_size, hidden_size, output_size).to(self.device)

    # Initialize parameters
    self.epsilonOPT: torch.Tensor = torch.log(torch.from_numpy(np.random.rand(1) * 10 ** (-10)).to(self.device))

Just to go with our naming conventions, maybe we can do *_opt for all of these variables with OPT suffixes?

            list(self.featurizer.parameters()) + [self.epsilonOPT] + [self.sigmaOPT] + [self.sigma0OPT], lr=self.lr
        )

    def Pdist2(self, x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:

maybe expand this to `pairwise_distance_squared`?

It looks like we don't leverage the fact that y can be None to get the distances of x with itself. Maybe we just drop that option and require y to be passed to simplify this function.
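
To make the suggestion concrete, here is a dependency-free sketch of a pairwise squared distance that always takes both inputs. It uses the same expansion as the diff (squared norms minus twice the dot product) but on plain Python lists rather than tensors; the clamp to zero is an added assumption to guard against small negative round-off and is not something shown in the PR.

```python
from typing import List, Sequence


def pairwise_distance_squared(
    x: Sequence[Sequence[float]], y: Sequence[Sequence[float]]
) -> List[List[float]]:
    """Return D with D[i][j] = ||x[i] - y[j]||^2; y is required."""
    x_norm = [sum(v * v for v in row) for row in x]
    y_norm = [sum(v * v for v in row) for row in y]
    distances: List[List[float]] = []
    for i, row_x in enumerate(x):
        row: List[float] = []
        for j, row_y in enumerate(y):
            dot = sum(a * b for a, b in zip(row_x, row_y))
            # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
            # Clamp at zero (assumption) so round-off never yields a negative distance.
            row.append(max(x_norm[i] + y_norm[j] - 2.0 * dot, 0.0))
        distances.append(row)
    return distances
```

Callers that want the distances of a sample with itself simply pass it twice, e.g. `pairwise_distance_squared(x, x)`, which is how the diff already calls `Pdist2`.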

    else:
        y = x
        y_norm = x_norm.view(1, -1)
    Pdist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))

same with this, maybe just `paired_distance`?

        return

    def compute_kernel(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Train the kernel."""

This is actually not training the kernel, but rather doing pure inference I think?

        return mmd_value_temp

    def forward(self, Xs: torch.Tensor, Xt: torch.Tensor) -> torch.Tensor:

So based on our implementation, am I correct in thinking that during training, we'll also be updating the kernel values at each step based on the batches? If this doesn't quite work, another thing we could do is update the deep kernels periodically, like we do with MKMMD where we "optimize" the kernels for a frozen set of models.

    def MMDu(
        self,
        Fea: torch.Tensor,

maybe rename to `features`?

    # Compute output of deep network
    model_output = self.featurizer(features)
    # Compute epsilon, sigma and sigma_0
    ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))

rename `epsilon` and note that it is the epsilon in

    self,
    Fea: torch.Tensor,
    len_s: int,
    Fea_org: torch.Tensor,

rename `original_features` for clarity.

    # Compute epsilon, sigma and sigma_0
    ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
    sigma = self.sigmaOPT**2
    sigma0_u = self.sigma0OPT**2

Based on the implementation of `MMDu`, I would suggest renaming `sigma0` to `sigma_phi`, `sigma0OPT` to `sigma_phi_opt` and `sigma0_u` to `sigma_phi` (since there doesn't seem to be any reason to have `_u` in there anyway). Similarly, anything that is sigma or sigmaOPT can be `sigma_q` or `sigma_q_opt` to match the notation of the paper.

    V1 = torch.dot(hh.sum(1) / ny, hh.sum(1) / ny) / ny
    V2 = (hh).sum() / (nx) / nx
    varEst = 4 * (V1 - V2**2)
    return mmd2, varEst, Kxyxy

Correct me if I'm wrong, but in all instances where this function is called, we don't make use of `Kxyxy`. Can we just drop it from the return values?

    Kxyxy = torch.cat((Kxxy, Kyxy), 0)
    nx = Kx.shape[0]
    ny = Ky.shape[0]
    is_unbiased = True

Perhaps, rather than hardcoding this here, we can turn it into a class variable and default it to True (i.e. `self.is_unbiased`)? That way it can be pulled out of this function.

    Fea_org=features.view(features.shape[0], -1),
    sigma=sigma,
    sigma0=sigma0_u,
    epsilon=ep,

I think we want to set `is_var_computed = False` in this call to save some calculation

    Fea_org=features.view(features.shape[0], -1),
    sigma=sigma,
    sigma0=sigma0_u,
    epsilon=ep,

Maybe explicitly set `is_var_computed=True` in this call rather than leaving it to the default?

    Dxx = self.Pdist2(X, X)
    Dyy = self.Pdist2(Y, Y)
    Dxy = self.Pdist2(X, Y)
    Dxx_org = self.Pdist2(X_org, X_org)

I think we can put these inside the `if is_smooth` block, since they only need to be computed if that is true. Otherwise they aren't used.

Maybe also rename these to expand `org` to `original` just to be explicit

    X = Fea[0:len_s, :]  # fetch the sample 1 (features of deep networks)
    Y = Fea[len_s:, :]  # fetch the sample 2 (features of deep networks)
    X_org = Fea_org[0:len_s, :]  # fetch the original sample 1
    Y_org = Fea_org[len_s:, :]  # fetch the original sample 2

Maybe rename these to `X_original` and `Y_original`?

    Y = Fea[len_s:, :]  # fetch the sample 2 (features of deep networks)
    X_org = Fea_org[0:len_s, :]  # fetch the original sample 1
    Y_org = Fea_org[len_s:, :]  # fetch the original sample 2
    L = 1  # generalized Gaussian (if L>1)

Rather than hardcoding here, we can make this a class variable that defaults to 1?

    @@ -0,0 +1,224 @@
    from typing import Optional, Tuple

I think it's worth creating some tests for this loss function before doing any of the refactors I suggest below just to make sure we don't inadvertently break anything in the calculations. It could be as simple as generating a random X, Y and recording the outputs of each function. That way we know we did something if those outputs change with our refactors.
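
A rough sketch of what such a "record the outputs" test could look like, using only the standard library. `reference_statistic` is a hypothetical placeholder for whichever loss function is being pinned down (it is not code from this PR), and the seeds, sizes and tolerance are arbitrary choices.

```python
import math
import random
from typing import List


def make_samples(seed: int, n: int, dim: int) -> List[List[float]]:
    """Deterministic pseudo-random samples, so the test inputs never change."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]


def reference_statistic(x: List[List[float]], y: List[List[float]]) -> float:
    """Placeholder for the function under test: mean Gaussian kernel value."""
    total = 0.0
    for a in x:
        for b in y:
            squared_distance = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
            total += math.exp(-squared_distance)
    return total / (len(x) * len(y))


def matches_golden(value: float, golden: float, tol: float = 1e-9) -> bool:
    """True when a freshly computed value agrees with the recorded one."""
    return math.isclose(value, golden, rel_tol=tol, abs_tol=tol)


# Step 1 (before refactoring): record the output for fixed inputs.
golden_value = reference_statistic(make_samples(0, 8, 3), make_samples(1, 8, 3))

# Step 2 (after refactoring): regenerate the same inputs and compare.
recomputed = reference_statistic(make_samples(0, 8, 3), make_samples(1, 8, 3))
assert matches_golden(recomputed, golden_value)
```

In a real test the recorded value would be stored as a literal or fixture rather than recomputed in the same run, so that a refactor which changes the arithmetic actually fails the comparison.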

    Kx = torch.exp(-Dxx / sigma0)
    Ky = torch.exp(-Dyy / sigma0)
    Kxy = torch.exp(-Dxy / sigma0)

Maybe here just add a comment noting what Kx, Ky, Kxy represent. That is, they are

Perhaps we can also name these to be a bit more verbose. That is, `Kx` -> `kernel_xx`, `Ky` -> `kernel_yy` and `Kxy` -> `kernel_xy`. We could do the same for `Dxx`, `Dyy`, `Dxy` -> `distance_xx`, `distance_yy`, `distance_xy`

    """compute value of MMD and std of MMD using kernel matrix."""
    Kxxy = torch.cat((Kx, Kxy), 1)
    Kyxy = torch.cat((Kxy.transpose(0, 1), Ky), 1)
    Kxyxy = torch.cat((Kxxy, Kyxy), 0)

It's worth a comment that this final matrix (`Kxyxy`) is composed of the blocks [[`Kx`, `Kxy`], [`Kxy^T`, `Ky`]]. Just to help any reader understand what's actually being constructed here.

Actually, it seems like this tensor isn't used anywhere, which would make its formation unnecessary. Maybe I'm missing something though?

    xx = torch.div((torch.sum(Kx)), (nx * nx))
    yy = torch.div((torch.sum(Ky)), (ny * ny))
    # one-sample U-statistic.
    if use_1sample_U:

This seems like a no-op? That is, regardless of whether this boolean is true or false, we do the same calculation?

    ny = Ky.shape[0]
    is_unbiased = True
    if is_unbiased:
        xx = torch.div((torch.sum(Kx) - torch.sum(torch.diag(Kx))), (nx * (nx - 1)))

Maybe note that these calculations correspond to the calculation of

        return fealant


    class DeepMmdLoss(torch.nn.Module):

I think it's worth recording the paper that this is being taken from somewhere here 🙂

    # one-sample U-statistic.
    if use_1sample_U:
        xy = torch.div((torch.sum(Kxy) - torch.sum(torch.diag(Kxy))), (nx * (ny - 1)))
    else:

I'm having trouble figuring out where the calculation below is described in their paper. It doesn't seem like it's a well defined calculation to be honest. I'm tempted to suggest we remove it and only offer the `is_unbiased` boolean to switch between the estimates

        else:
            xy = torch.div(torch.sum(Kxy), (nx * ny))
        mmd2 = xx - 2 * xy + yy
    else:

I think it's worth noting that we calculate

        xy = torch.div((torch.sum(Kxy) - torch.sum(torch.diag(Kxy))), (nx * (ny - 1)))
    else:
        xy = torch.div(torch.sum(Kxy), (nx * ny))
    mmd2 = xx - 2 * xy + yy

This calculation can be moved outside the if and else blocks rather than being repeated in both blocks.
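
As a sketch of how the two suggestions above could combine (a single `is_unbiased` switch and one shared `mmd2` line), here is a plain-Python version that works on kernel matrices given as nested lists. The unbiased `xx` and `xy` terms follow the expressions visible in the diff; the unbiased `yy` term is assumed to mirror `xx`, and the function and helper names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def _total(kernel: Matrix) -> float:
    return sum(sum(row) for row in kernel)


def _trace(kernel: Matrix) -> float:
    return sum(kernel[i][i] for i in range(min(len(kernel), len(kernel[0]))))


def mmd_squared(
    kernel_xx: Matrix, kernel_yy: Matrix, kernel_xy: Matrix, is_unbiased: bool = True
) -> float:
    """Estimate MMD^2 from precomputed kernel matrices."""
    nx, ny = len(kernel_xx), len(kernel_yy)
    if is_unbiased:
        # Drop the diagonal terms, as in the unbiased branch of the diff.
        xx = (_total(kernel_xx) - _trace(kernel_xx)) / (nx * (nx - 1))
        yy = (_total(kernel_yy) - _trace(kernel_yy)) / (ny * (ny - 1))  # assumed
        xy = (_total(kernel_xy) - _trace(kernel_xy)) / (nx * (ny - 1))
    else:
        xx = _total(kernel_xx) / (nx * nx)
        yy = _total(kernel_yy) / (ny * ny)
        xy = _total(kernel_xy) / (nx * ny)
    # Computed once, after the branch, instead of once per branch.
    return xx - 2.0 * xy + yy
```

When both samples are identical, all three kernel matrices coincide and either estimate evaluates to zero, which makes for a cheap sanity check.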

    mmd2 = xx - 2 * xy + yy
    if not is_var_computed:
        return mmd2, None, Kxyxy
    hh = Kx + Ky - Kxy - Kxy.transpose(0, 1)

To be more consistent with the paper notation, I'd rename `hh` to `h_ij`?

I'd add a comment above that we're computing the estimate in Equation (5) from the paper (without the lambda shift).

    if not is_var_computed:
        return mmd2, None, Kxyxy
    hh = Kx + Ky - Kxy - Kxy.transpose(0, 1)
    V1 = torch.dot(hh.sum(1) / ny, hh.sum(1) / ny) / ny

This is a weird way to write this. My suggestion, to make it look more like Equation (5) from the paper, is to do

    V1 = (4.0/ny**3)*(torch.dot(hh.sum(1), hh.sum(1)))
    V2 = (4.0/nx**4)*(hh.sum()**2)
    variance_estimate = V1 - V2

Discussed in a comment below, but I would move the addition of the lambda shift into this calculation to better align it with Equation (5) from the paper. That is

    variance_estimate = V1 - V2 + (10 ** (-8))
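
For what it's worth, the original and suggested forms are algebraically the same: 4 * (dot / ny**3 - (sum / nx**2)**2) equals (4 / ny**3) * dot - (4 / nx**4) * sum**2. The sketch below checks that numerically on a small random matrix, using plain lists instead of tensors and leaving out the lambda shift; the function names are made up for the comparison.

```python
import random
from typing import List

Matrix = List[List[float]]


def variance_original(hh: Matrix, nx: int, ny: int) -> float:
    """Variance estimate written the way the diff writes it."""
    row_sums = [sum(row) for row in hh]
    v1 = sum((s / ny) * (s / ny) for s in row_sums) / ny
    v2 = sum(row_sums) / nx / nx
    return 4.0 * (v1 - v2**2)


def variance_refactored(hh: Matrix, nx: int, ny: int) -> float:
    """Variance estimate written the way the review suggests."""
    row_sums = [sum(row) for row in hh]
    v1 = (4.0 / ny**3) * sum(s * s for s in row_sums)
    v2 = (4.0 / nx**4) * (sum(row_sums) ** 2)
    return v1 - v2


rng = random.Random(0)
n = 5
h_matrix = [[rng.uniform(-1.0, 1.0) for _ in range(n)] for _ in range(n)]
difference = abs(variance_original(h_matrix, n, n) - variance_refactored(h_matrix, n, n))
assert difference < 1e-9  # equal up to floating-point round-off
```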

    # ------------------------------
    # Train deep network for MMD-D
    # ------------------------------
    # Initialize optimizer

This isn't really initializing the optimizer. It's just clearing the gradients.

    if mmd_var_temp is None:
        raise AssertionError("Error: Variance of MMD is not computed. Please set is_var_computed=True.")
    mmd_std_temp = torch.sqrt(mmd_var_temp + 10 ** (-8))
    STAT_u = torch.div(-1 * mmd_value_temp, mmd_std_temp)

I'm okay with calling this `stat_u`, but I would kill the all caps 🙂.

I would add a comment that we're trying to maximize the ratio (mmd_value_estimate/mmd_std_estimate), so we multiply by -1 to be compatible with gradient descent.
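
A scalar sketch of the objective with the suggested comment and lowercase naming in place. It uses `math` instead of torch, and the function name and the `lambda_shift` parameter are illustrative; the `1e-8` default mirrors the `10 ** (-8)` constant in the diff.

```python
import math
from typing import Optional


def negative_power_ratio(
    mmd_value_estimate: float,
    mmd_var_estimate: Optional[float],
    lambda_shift: float = 1e-8,
) -> float:
    """Return -(MMD estimate / std estimate), suitable for minimization."""
    if mmd_var_estimate is None:
        raise ValueError("Variance of MMD is not computed. Please set is_var_computed=True.")
    mmd_std_estimate = math.sqrt(mmd_var_estimate + lambda_shift)
    # We want to maximize the ratio mmd_value_estimate / mmd_std_estimate, so we
    # negate it to make it compatible with gradient descent, which minimizes.
    stat_u = -1.0 * mmd_value_estimate / mmd_std_estimate
    return stat_u
```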

    ep = torch.exp(self.epsilonOPT) / (1 + torch.exp(self.epsilonOPT))
    sigma = self.sigmaOPT**2
    sigma0_u = self.sigma0OPT**2
    # Compute Compute J (STAT_u)

I'd add to this comment that what we're forming is the ratio in Equation (4) from the paper in `stat_u` which is referred to as

    sigma = self.sigmaOPT**2
    sigma0_u = self.sigma0OPT**2
    # Compute Compute J (STAT_u)
    mmd_value_temp, mmd_var_temp, _ = self.MMDu(

I'm not sure why they are suffixing with `_temp` here. Maybe we just call it `..._estimate`?

    )
    if mmd_var_temp is None:
        raise AssertionError("Error: Variance of MMD is not computed. Please set is_var_computed=True.")
    mmd_std_temp = torch.sqrt(mmd_var_temp + 10 ** (-8))

I'm good with doing the sqrt here, since it's how we form

    mmd_std_estimate = torch.sqrt(mmd_var_estimate)

    STAT_u.backward()
    # Update weights using gradient descent
    self.optimizer_F.step()
    return

Return is unnecessary here.

    sigma = self.sigmaOPT**2
    sigma0_u = self.sigma0OPT**2
    # Compute Compute J (STAT_u)
    mmd_value_temp, _, _ = self.MMDu(

As with the training function, I'd suggest renaming this to `mmd_value_estimate` from `_temp`

    checkpointer=checkpointer,
    lam=lam,
    deep_mmd_loss_weight=deep_mmd_loss_weight,
    flatten_feature_extraction_layers={key: True for key in size_feature_extraction_layers},

Maybe also form this dictionary above the super call as you're doing with `size_feature_extraction_layers`?

        "--mu",
        action="store",
        type=float,
        help="Weight for the mkmmd losses",

Deep MMD losses 🙂
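
A small sketch of the argument block with the help text corrected. Only `--mu` and its `action`, `type` and `required` settings come from the diff; the parser description and the `--deep_mmd_loss_depth` flag (name inferred from the usage comment in the run script, with a default of 1) are assumptions for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Example client arguments (illustrative)")
    parser.add_argument(
        "--mu",
        action="store",
        type=float,
        help="Weight for the Deep MMD losses",
        required=False,
    )
    # Hypothetical flag: name and default are assumptions, not confirmed by the diff.
    parser.add_argument(
        "--deep_mmd_loss_depth",
        action="store",
        type=int,
        help="Number of feature extraction layers to apply the Deep MMD loss to",
        required=False,
        default=1,
    )
    return parser


example_args = build_parser().parse_args(["--mu", "0.1"])
```

Any flag kept in the parser should also be read somewhere in the client setup, otherwise it is better removed, as noted in the comments on the unused arguments.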

        help="Weight for the mkmmd losses",
        required=False,
    )
    parser.add_argument(

I don't think this argument is being used anywhere?

        required=False,
        default=1,
    )
    parser.add_argument(

Similarly, I don't think this arg is used anywhere either

    # client_side_learning_rate_value \
    # lambda value \
    # mu value \
    # depp_mmd_loss_depth value \

Johnny Depp MMD Loss!

    # /h/demerson/vector_repositories/fl4health_env/ \
    # 0.0001 \
    # 0.01 \
    # 0.1 \

I think the depth argument is missing in this example?

PR Type
[Feature]
Short Description
This is a tentative implementation for deep mmd loss.
Tests Added
No tests added yet.