forked from mllam/neural-lam
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce metrics module with new loss options
- Move all data dimensionality constants to constants.py
- Add note about ccai_paper_2023 branch
- Create metrics module with common interface
- Introduce nll and crps_gauss metrics, also available as losses
- Introduce output_std option to let models output std.-devs.
- Change validation metric from mae to rmse
- Change definition of rmse to take sqrt after spatial averaging only
- Add possibility to watch specific metrics and log as scalars, specified in constants.py
- Loading branch information
1 parent
2378ed7
commit c14b6b4
Showing
7 changed files
with
481 additions
and
190 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
import torch | ||
|
||
def get_metric(metric_name):
    """
    Get a defined metric with given name

    The lookup is case-insensitive: the name is lower-cased before it is
    matched against the keys of DEFINED_METRICS.

    metric_name: str, name of the metric

    Returns:
    metric: function implementing the metric

    Raises:
    AssertionError, if the lower-cased name is not a key of DEFINED_METRICS
    """
    metric_name_lower = metric_name.lower()
    assert metric_name_lower in DEFINED_METRICS, f"Unknown metric: {metric_name}"
    return DEFINED_METRICS[metric_name_lower]
|
||
def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
    """
    Masks and (optionally) reduces entry-wise metric values

    (...,) is any number of batch dimensions, potentially different but
    broadcastable
    metric_entry_vals: (..., N, d_state), entry-wise metric values
    mask: (N,), boolean mask describing which grid nodes to use in metric,
        or None to keep all grid nodes
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced
        (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
        depending on reduction arguments.
    """
    # Only keep grid nodes in mask
    if mask is not None:
        metric_entry_vals = metric_entry_vals[..., mask, :]  # (..., N', d_state)

    # Optionally reduce last two dimensions. The grid is reduced first, so
    # that dim=-1 is still the variable dimension when the sum is taken.
    if average_grid:  # Reduce grid first
        metric_entry_vals = torch.mean(metric_entry_vals, dim=-2)  # (..., d_state)
    if sum_vars:  # Reduce vars second
        metric_entry_vals = torch.sum(metric_entry_vals, dim=-1)  # (..., N) or (...,)

    return metric_entry_vals
|
||
def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    Weighted Mean Squared Error

    Each squared error entry is divided by the squared predicted std.-dev.
    before masking and reduction.

    (...,) is any number of batch dimensions, potentially different but
    broadcastable
    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced
        (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
        depending on reduction arguments.
    """
    entry_mse = torch.nn.functional.mse_loss(
        pred, target, reduction='none')  # (..., N, d_state)
    # Weight by inverse variance; pred_std is broadcast if shaped (d_state,)
    entry_mse_weighted = entry_mse / (pred_std**2)  # (..., N, d_state)

    return mask_and_reduce_metric(entry_mse_weighted, mask=mask,
                                  average_grid=average_grid, sum_vars=sum_vars)
|
||
def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    (Unweighted) Mean Squared Error

    The values in pred_std are ignored; the argument is only used as a
    template for a tensor of ones, keeping the signature identical to the
    other metrics.

    (...,) is any number of batch dimensions, potentially different but
    broadcastable
    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced
        (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
        depending on reduction arguments.
    """
    # Replace pred_std with constant ones
    return wmse(pred, target, torch.ones_like(pred_std), mask, average_grid,
                sum_vars)
|
||
def rmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    Root Mean Squared Error

    Note: here take sqrt only after spatial averaging, averaging the RMSE of
    forecasts. This is consistent with Weatherbench and others.
    Because of this, averaging over grid must be set to true.

    (...,) is any number of batch dimensions, potentially different but
    broadcastable
    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N),
        must be True
    sum_vars: boolean, if variable dimension -1 should be reduced
        (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), depending on reduction arguments.

    Raises:
    AssertionError, if average_grid is False
    """
    assert average_grid, "Can not compute RMSE without averaging grid"

    # Spatially averaged mse, masking is also performed here
    averaged_mse = mse(pred, target, pred_std, mask, average_grid=True,
                       sum_vars=False)  # (..., d_state)
    entry_rmse = torch.sqrt(averaged_mse)  # (..., d_state)

    # Optionally sum over variables here manually, since the sum has to
    # happen after the sqrt
    if sum_vars:
        return torch.sum(entry_rmse, dim=-1)  # (...,)

    return entry_rmse  # (..., d_state)
|
||
def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    Weighted Mean Absolute Error

    Each absolute error entry is divided by the predicted std.-dev. before
    masking and reduction.

    (...,) is any number of batch dimensions, potentially different but
    broadcastable
    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced
        (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
        depending on reduction arguments.
    """
    entry_mae = torch.nn.functional.l1_loss(
        pred, target, reduction='none')  # (..., N, d_state)
    # Weight by inverse std.-dev.; pred_std is broadcast if shaped (d_state,)
    entry_mae_weighted = entry_mae / pred_std  # (..., N, d_state)

    return mask_and_reduce_metric(entry_mae_weighted, mask=mask,
                                  average_grid=average_grid, sum_vars=sum_vars)
|
||
def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    (Unweighted) Mean Absolute Error

    The values in pred_std are ignored; the argument is only used as a
    template for a tensor of ones, keeping the signature identical to the
    other metrics.

    (...,) is any number of batch dimensions, potentially different but
    broadcastable
    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced
        (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
        depending on reduction arguments.
    """
    # Replace pred_std with constant ones
    return wmae(pred, target, torch.ones_like(pred_std), mask, average_grid,
                sum_vars)
|
||
def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    Negative Log Likelihood loss, for a Gaussian likelihood applied
    independently to each entry, with mean pred and std.-dev. pred_std

    (...,) is any number of batch dimensions, potentially different but
    broadcastable
    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced
        (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
        depending on reduction arguments.
    """
    # Broadcast pred_std if shaped (d_state,), done internally in Normal class
    dist = torch.distributions.Normal(pred, pred_std)  # (..., N, d_state)
    entry_nll = -dist.log_prob(target)  # (..., N, d_state)

    return mask_and_reduce_metric(entry_nll, mask=mask,
                                  average_grid=average_grid, sum_vars=sum_vars)
|
||
def crps_gauss(pred, target, pred_std, mask=None, average_grid=True,
               sum_vars=True):
    """
    (Negative) Continuous Ranked Probability Score (CRPS)
    Closed-form expression based on Gaussian predictive distribution

    With z = (target - pred) / pred_std, and pdf / cdf those of the standard
    normal distribution, each entry is computed as
        pred_std * (z * (2 * cdf(z) - 1) + 2 * pdf(z) - 1 / sqrt(pi))

    (...,) is any number of batch dimensions, potentially different but
    broadcastable
    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced
        (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state),
        depending on reduction arguments.
    """
    # Standard normal with scalar parameters created on the same device as pred
    std_normal = torch.distributions.Normal(
        torch.zeros((), device=pred.device),
        torch.ones((), device=pred.device))
    target_standard = (target - pred) / pred_std  # (..., N, d_state)

    # exp(log_prob) gives the standard normal pdf at the standardized target
    entry_crps = -pred_std * (
        torch.pi**(-0.5)
        - 2 * torch.exp(std_normal.log_prob(target_standard))
        - target_standard * (2 * std_normal.cdf(target_standard) - 1)
    )  # (..., N, d_state)

    return mask_and_reduce_metric(entry_crps, mask=mask,
                                  average_grid=average_grid, sum_vars=sum_vars)
|
||
# Registry mapping lower-case metric names to the functions implementing
# them. All entries share the signature
# (pred, target, pred_std, mask=None, average_grid=True, sum_vars=True).
DEFINED_METRICS = {
    "mse": mse,
    "mae": mae,
    "rmse": rmse,
    "wmse": wmse,
    "wmae": wmae,
    "nll": nll,
    "crps_gauss": crps_gauss,
}
Oops, something went wrong.