Skip to content

Commit

Permalink
uh. still need netcdf for xarray to read netcdf files.
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Apr 25, 2024
1 parent ff057d6 commit 432fe55
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 5 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"NREL-farms>=1.0.4",
"dask>=2022.0",
"google-auth-oauthlib==0.5.3",
"h5netcdf",
"matplotlib>=3.1",
"numpy>=1.7.0",
"pandas>=2.0",
Expand Down
6 changes: 4 additions & 2 deletions sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,8 @@ def train_epoch(self,

if only_gen or (train_gen and not gen_too_good):
trained_gen = True
b_loss_details = self.run_gradient_descent(
b_loss_details = self.timer(
self.run_gradient_descent,
batch.low_res,
batch.high_res,
self.generator_weights,
Expand All @@ -700,7 +701,8 @@ def train_epoch(self,

if only_disc or (train_disc and not disc_too_good):
trained_disc = True
b_loss_details = self.run_gradient_descent(
b_loss_details = self.timer(
self.run_gradient_descent,
batch.low_res,
batch.high_res,
self.discriminator_weights,
Expand Down
2 changes: 2 additions & 0 deletions sup3r/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from enum import Enum

import dask
import h5netcdf
import numpy as np
import pandas as pd
import phygnn
Expand All @@ -22,6 +23,7 @@
'nrel-rex': rex.__version__,
'python': sys.version,
'xarray': xarray.__version__,
'h5netcdf': h5netcdf.__version__,
'dask': dask.__version__,
}

Expand Down
10 changes: 7 additions & 3 deletions tests/bias/test_bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
from scipy import stats

from sup3r import CONFIG_DIR, TEST_DATA_DIR
from sup3r.bias.bias_calc import (LinearCorrection, MonthlyLinearCorrection,
SkillAssessment)
from sup3r.bias.bias_calc import (
LinearCorrection,
MonthlyLinearCorrection,
SkillAssessment,
)
from sup3r.bias.bias_transforms import local_linear_bc, monthly_local_linear_bc
from sup3r.models import Sup3rGan
from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy
Expand Down Expand Up @@ -327,7 +330,8 @@ def test_fwp_integration():
os.path.join(TEST_DATA_DIR, 'zg_test.nc')]

lat_lon = DataHandlerNCforCC(input_files, features=[], target=target,
shape=shape).lat_lon
shape=shape,
worker_kwargs={'max_workers': 1}).lat_lon

Sup3rGan.seed()
model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4)
Expand Down

0 comments on commit 432fe55

Please sign in to comment.