Skip to content

Commit

Permalink
Introduce metrics module with new loss options
Browse files Browse the repository at this point in the history
- Move all data dimensionality constants to constants.py
- Add note about ccai_paper_2023 branch
- Create metrics module with common interface
- Introduce nll and crps_gauss metrics, also available as losses
- Introduce output_std option to let models output std.-devs.
- Change validation metric from mae to rmse
- Change definition of rmse to take sqrt after spatial averaging only
- Add possibility to watch specific metrics and log as scalars, specified in constants.py
  • Loading branch information
joeloskarsson committed Jan 9, 2024
1 parent 2378ed7 commit c14b6b4
Show file tree
Hide file tree
Showing 7 changed files with 481 additions and 190 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ If you use Neural-LAM in your work, please cite:
year={2023}
}
```
As the code in the repository is continuously evolving, the latest version might differ slightly from what was used in the paper.
See the branch [`ccai_paper_2023`](https://github.com/joeloskarsson/neural-lam/tree/ccai_paper_2023) for a revision of the code that reproduces the workshop paper.

We plan to continue updating this repository as we improve existing models and develop new ones.
Collaborations around this implementation are very welcome.
Expand Down
17 changes: 17 additions & 0 deletions neural_lam/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@
# Log prediction error for these lead times
val_step_log_errors = np.array([1, 2, 3, 5, 10, 15, 19])

# Log these metrics to wandb as scalar values for specific variables and lead times
# List of metrics to watch, including any prefix (e.g. val_rmse)
# The list is empty here; add metric names to start watching them
metrics_watch = [
]
# Dict with variables and lead times to log watched metrics for
# Format is a dictionary that maps from a variable index to a list of lead time steps
# NOTE(review): variable indices presumably index into param_names -- confirm
var_leads_metrics_watch = {
6: [2, 19], # t_2
14: [2, 19], # wvint_0
15: [2, 19], # z_1000
}

# Variable names
param_names = [
'pres_heightAboveGround_0_instant',
Expand Down Expand Up @@ -95,3 +107,8 @@
central_latitude=lambert_proj_params['lat_0'],
standard_parallels=(lambert_proj_params['lat_1'],
lambert_proj_params['lat_2']))

# Data dimensions
# Dimensionality constants collected in this module
batch_static_feature_dim = 1 # Only open water
grid_forcing_dim = 5*3 # 5 features for 3 time-step window
# NOTE(review): presumably the number of state variables (d_state) -- confirm
grid_state_dim = 17
2 changes: 1 addition & 1 deletion neural_lam/interaction_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def forward(self, send_rep, rec_rep, edge_rep):
"""
# Always concatenate to [rec_nodes, send_nodes] for propagation, but only
# aggregate to rec_nodes
node_reps = torch.cat((rec_rep, send_rep), dim=1)
node_reps = torch.cat((rec_rep, send_rep), dim=-2)
edge_rep_aggr, edge_diff = self.propagate(self.edge_index, x=node_reps,
edge_attr=edge_rep)
rec_diff = self.aggr_mlp(torch.cat((rec_rep, edge_rep_aggr), dim=-1))
Expand Down
218 changes: 218 additions & 0 deletions neural_lam/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
import torch

def get_metric(metric_name):
    """
    Look up a metric function by name (case-insensitive)

    metric_name: str, name of the metric

    Returns:
    metric: function implementing the metric

    Fails with an AssertionError if the lower-cased name is not a key of
    DEFINED_METRICS.
    """
    # Keys in DEFINED_METRICS are lower-case, so normalise before the lookup
    lookup_key = metric_name.lower()
    assert lookup_key in DEFINED_METRICS, f"Unknown metric: {metric_name}"
    return DEFINED_METRICS[lookup_key]

def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
    """
    Apply a grid mask to entry-wise metric values and optionally reduce them

    (...,) is any number of batch dimensions, potentially different but broadcastable

    metric_entry_vals: (..., N, d_state), entry-wise metric values
    mask: (N,), boolean mask describing which grid nodes to use in metric,
        or None to keep all grid nodes
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending
        on reduction arguments.
    """
    # Restrict to the masked grid nodes, if a mask was given
    if mask is None:
        reduced = metric_entry_vals  # (..., N, d_state)
    else:
        reduced = metric_entry_vals[..., mask, :]  # (..., N', d_state)

    # Grid dimension is averaged first, while it still sits at position -2
    if average_grid:
        reduced = reduced.mean(dim=-2)  # (..., d_state)

    # Variable dimension is always the last one, whether or not grid was reduced
    if sum_vars:
        reduced = reduced.sum(dim=-1)  # (..., N) or (...,)

    return reduced

def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    Weighted Mean Squared Error

    Squared errors are divided entry-wise by pred_std**2 before masking and
    reduction.
    (...,) is any number of batch dimensions, potentially different but broadcastable

    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending
        on reduction arguments.
    """
    # Explicit squared difference rather than torch.nn.functional.mse_loss:
    # mse_loss emits a UserWarning whenever pred and target differ in shape,
    # although broadcasting over batch dimensions is supported usage here.
    entry_mse = (pred - target) ** 2  # (..., N, d_state)
    entry_mse_weighted = entry_mse / (pred_std ** 2)  # (..., N, d_state)

    return mask_and_reduce_metric(entry_mse_weighted, mask=mask,
                                  average_grid=average_grid, sum_vars=sum_vars)

def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    (Unweighted) Mean Squared Error

    Delegates to wmse with all weights set to one; only the shape, dtype and
    device of pred_std are used.
    (...,) is any number of batch dimensions, potentially different but broadcastable

    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending
        on reduction arguments.
    """
    # Unit std.-dev. turns the weighted error into the plain one
    unit_std = torch.ones_like(pred_std)
    return wmse(pred, target, unit_std, mask=mask, average_grid=average_grid,
                sum_vars=sum_vars)

def rmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    Root Mean Squared Error

    Note: the square root is taken only after spatial averaging, so what is
    averaged elsewhere is the RMSE of individual forecasts.
    This is consistent with Weatherbench and others.
    Because of this, average_grid must be set to true.
    (...,) is any number of batch dimensions, potentially different but broadcastable

    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), depending on reduction arguments.
    """
    assert average_grid, "Can not compute RMSE without averaging grid"

    # Masking and the mean over grid nodes both happen inside mse; the variable
    # dimension is kept so the sqrt is applied per variable
    grid_mse = mse(pred, target, pred_std, mask, average_grid=True,
                   sum_vars=False)  # (..., d_state)
    per_var_rmse = torch.sqrt(grid_mse)  # (..., d_state)

    # Summing over variables has to be done here, after the sqrt
    return torch.sum(per_var_rmse, dim=-1) if sum_vars else per_var_rmse

def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    Weighted Mean Absolute Error

    Absolute errors are divided entry-wise by pred_std before masking and
    reduction.
    (...,) is any number of batch dimensions, potentially different but broadcastable

    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending
        on reduction arguments.
    """
    # Explicit absolute difference rather than torch.nn.functional.l1_loss:
    # l1_loss emits a UserWarning whenever pred and target differ in shape,
    # although broadcasting over batch dimensions is supported usage here.
    entry_mae = torch.abs(pred - target)  # (..., N, d_state)
    entry_mae_weighted = entry_mae / pred_std  # (..., N, d_state)

    return mask_and_reduce_metric(entry_mae_weighted, mask=mask,
                                  average_grid=average_grid, sum_vars=sum_vars)

def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    (Unweighted) Mean Absolute Error

    Delegates to wmae with all weights set to one; only the shape, dtype and
    device of pred_std are used.
    (...,) is any number of batch dimensions, potentially different but broadcastable

    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending
        on reduction arguments.
    """
    # Unit std.-dev. turns the weighted error into the plain one
    unit_std = torch.ones_like(pred_std)
    return wmae(pred, target, unit_std, mask=mask, average_grid=average_grid,
                sum_vars=sum_vars)

def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    Negative Log Likelihood loss, for a Gaussian likelihood evaluated
    independently for each entry

    (...,) is any number of batch dimensions, potentially different but broadcastable

    pred: (..., N, d_state), prediction (used as the Gaussian mean)
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending
        on reduction arguments.
    """
    # Normal broadcasts loc and scale internally, which covers a pred_std
    # shaped (d_state,)
    gaussian = torch.distributions.Normal(pred, pred_std)  # (..., N, d_state)
    neg_log_lik = -gaussian.log_prob(target)  # (..., N, d_state)

    return mask_and_reduce_metric(neg_log_lik, mask=mask,
                                  average_grid=average_grid, sum_vars=sum_vars)

def crps_gauss(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
    """
    (Negative) Continuous Ranked Probability Score (CRPS)
    Closed-form expression based on Gaussian predictive distribution

    With z = (target - pred)/pred_std, and phi/Phi the standard normal
    pdf/cdf, each entry is
        pred_std * (z*(2*Phi(z) - 1) + 2*phi(z) - 1/sqrt(pi))
    (...,) is any number of batch dimensions, potentially different but broadcastable

    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev.
    mask: (N,), boolean mask describing which grid nodes to use in metric
    average_grid: boolean, if grid dimension -2 should be reduced (mean over N)
    sum_vars: boolean, if variable dimension -1 should be reduced (sum over d_state)

    Returns:
    metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), depending
        on reduction arguments.
    """
    # Scalar standard normal, created on the same device as the prediction
    zero = torch.zeros((), device=pred.device)
    one = torch.ones((), device=pred.device)
    std_normal = torch.distributions.Normal(zero, one)

    # Standardised target
    z = (target - pred)/pred_std  # (..., N, d_state)
    pdf_z = torch.exp(std_normal.log_prob(z))  # (..., N, d_state)
    cdf_z = std_normal.cdf(z)  # (..., N, d_state)

    entry_crps = -pred_std*(
        torch.pi**(-0.5)
        - 2*pdf_z
        - z*(2*cdf_z - 1)
    )  # (..., N, d_state)

    return mask_and_reduce_metric(entry_crps, mask=mask,
                                  average_grid=average_grid, sum_vars=sum_vars)

# Registry of available metrics, mapping lower-case metric name to the
# function implementing it. get_metric lower-cases the requested name and
# looks it up here. All registered functions share the signature
# (pred, target, pred_std, mask=None, average_grid=True, sum_vars=True).
DEFINED_METRICS = {
"mse": mse,
"mae": mae,
"rmse": rmse,
"wmse": wmse,
"wmae": wmae,
"nll": nll,
"crps_gauss": crps_gauss,
}
Loading

0 comments on commit c14b6b4

Please sign in to comment.