Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Driver cleanup #483

Merged
merged 3 commits into from
Mar 15, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions src/python/pcmdi/scripts/driver/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import logging
import cdutil
import cdms2
import pcmdi_metrics.io.base
from pcmdi_metrics.io.base import Base


class DataSet(object):
''' Abstract parent of the Observation of Model classes. '''
__metaclass__ = abc.ABCMeta

def __init__(self, parameter, var_name_long, region,
Expand All @@ -24,13 +25,15 @@ def __init__(self, parameter, var_name_long, region,
self.sftlf = sftlf

def get_sftlf(self):
''' Returns the sftlf attribute. '''
return self.sftlf

def __call__(self):
return self.get()

@staticmethod
def calculate_level_from_var(var):
''' Get the level from the var string, where it's var_LEVEL '''
var_split_name = var.split('_')
if len(var_split_name) > 1:
level = float(var_split_name[-1]) * 100
Expand All @@ -39,6 +42,8 @@ def calculate_level_from_var(var):
return level

def setup_target_grid(self, obs_or_model_file):
''' Call the set_target_grid function for
obs_or_model_file, which is of type Base. '''
if self.use_omon(self.obs_dict, self.var):
regrid_method = self.parameter.regrid_method_ocn
regrid_tool = self.parameter.regrid_tool_ocn
Expand All @@ -55,17 +60,19 @@ def setup_target_grid(self, obs_or_model_file):

@staticmethod
def use_omon(obs_dict, var):
''' For the given variable and obs_dict, do we use Omon? '''
obs_default = obs_dict[var][obs_dict[var]["default"]]
return obs_default["CMIP_CMOR_TABLE"] == 'Omon'

@staticmethod
def create_sftlf(parameter):
''' Create the sftlf file from the parameter. '''
sftlf = {}

for test in parameter.test_data_set:
sft = pcmdi_metrics.io.base.Base(parameter.test_data_path,
getattr(parameter, "sftlf_filename_template",
parameter.filename_template))
sft = Base(parameter.test_data_path,
getattr(parameter, "sftlf_filename_template",
parameter.filename_template))
sft.model_version = test
sft.table = "fx"
sft.realm = "atmos"
Expand Down Expand Up @@ -96,17 +103,19 @@ def create_sftlf(parameter):

@staticmethod
def apply_custom_keys(obj, custom_dict, var):
''' Apply the all of the keys in custom_dict that are var to obj. '''
for k, v in custom_dict.iteritems():
key = custom_dict[k]
setattr(obj, k, key.get(var, key.get(None, "")))

@abc.abstractmethod
def get(self):
"""Calls the get function on the Base object."""
''' Calls the get function on the Base object. '''
raise NotImplementedError()

@staticmethod
def load_path_as_file_obj(name):
''' Returns a File object for the file named name. '''
file_path = sys.prefix + '/share/pmp/' + name
opened_file = None
try:
Expand All @@ -120,9 +129,9 @@ def load_path_as_file_obj(name):

@abc.abstractmethod
def hash(self):
"""Calls the hash function on the Base object."""
''' Calls the hash function on the Base object. '''
raise NotImplementedError()

def file_path(self):
"""Calls the __call__() function on the Base object."""
''' Calls the __call__() function on the Base object. '''
raise NotImplementedError()
18 changes: 14 additions & 4 deletions src/python/pcmdi/scripts/driver/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import MV2
import cdutil
import cdms2
import pcmdi_metrics.io.base
from pcmdi_metrics.io.base import Base
import pcmdi_metrics.driver.dataset


class Model(pcmdi_metrics.driver.dataset.DataSet):
''' Handles all the computation (setting masking, target grid, etc)
and some file I/O related to models. '''
def __init__(self, parameter, var_name_long, region,
model, obs_dict, data_path, sftlf):
super(Model, self).__init__(parameter, var_name_long, region,
Expand All @@ -22,8 +24,8 @@ def __init__(self, parameter, var_name_long, region,
self.setup_target_mask()

def create_model_file(self):
self._model_file = pcmdi_metrics.io.base.Base(self.data_path,
self.parameter.filename_template)
''' Creates an object that will eventually output the netCDF file. '''
self._model_file = Base(self.data_path, self.parameter.filename_template)
self._model_file.variable = self.var
self._model_file.model_version = self.obs_or_model
self._model_file.period = self.parameter.period
Expand All @@ -34,6 +36,7 @@ def create_model_file(self):
self.parameter.custom_keys, self.var)

def setup_target_mask(self):
''' Sets the mask and target_mask attribute of self._model_file '''
self.var_in_file = self.get_var_in_file()

if self.region is not None:
Expand All @@ -47,6 +50,8 @@ def setup_target_mask(self):
MV2.not_equal(self.sftlf['target_grid'], region_value)

def get(self):
''' Gets the variable based on the region and level (if given) for
the file from data_path, which is defined in the initalizer. '''
try:
if self.level is None:
data_model = self._model_file.get(
Expand All @@ -64,6 +69,7 @@ def get(self):
raise RuntimeError('Need to skip model: %s' % self.obs_or_model)

def get_var_in_file(self):
''' Based off the model_tweaks parameter, get the variable mapping. '''
tweaks = {}
tweaks_all = {}
if hasattr(self.parameter, 'model_tweaks'):
Expand All @@ -80,8 +86,10 @@ def get_var_in_file(self):
return var_in_file

def create_sftlf_model_raw(self, var_in_file):
''' For the self.obs_or_model from the initializer, create a landSeaMask
from cdutil for self.sftlf[self.obs_or_model]['raw'] value. '''
if not hasattr(self.parameter, 'generate_sftlf') or \
self.parameter.generate_sftlf is False:
self.parameter.generate_sftlf is False:
logging.info('Model %s does not have sftlf, skipping region: %s' % (self.obs_or_model, self.region))
raise RuntimeError('Model %s does not have sftlf, skipping region: %s' % (self.obs_or_model, self.region))

Expand All @@ -98,7 +106,9 @@ def create_sftlf_model_raw(self, var_in_file):
logging.info('Auto generated sftlf for model %s' % self.obs_or_model)

def hash(self):
''' Return a hash of the file. '''
return self._model_file.hash()

def file_path(self):
''' Return the path of the file. '''
return self._model_file()
26 changes: 21 additions & 5 deletions src/python/pcmdi/scripts/driver/observation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
import MV2
import pcmdi_metrics.io.base
import pcmdi_metrics.driver.dataset
from pcmdi_metrics.io.base import Base
from pcmdi_metrics.driver.dataset import DataSet


class OBS(pcmdi_metrics.io.base.Base):
class OBS(Base):
''' Creates an output the netCDF file for an observation. '''
def __init__(self, root, var, obs_dict, obs='default',
file_mask_template=None):
template = "%(realm)/%(frequency)/%(variable)/" +\
Expand Down Expand Up @@ -32,6 +33,8 @@ def __init__(self, root, var, obs_dict, obs='default',
self.variable = var

def setup_based_on_obs_table(self, obs_table):
''' Set the realm, frequency, ac based on the
CMIP_CMOR_TABLE value in the obs dict.'''
if obs_table == u'Omon':
self.realm = 'ocn'
self.frequency = 'mo'
Expand All @@ -46,7 +49,9 @@ def setup_based_on_obs_table(self, obs_table):
self.ac = 'ac'


class Observation(pcmdi_metrics.driver.dataset.DataSet):
class Observation(DataSet):
''' Handles all the computation (setting masking, target grid, etc)
and some file I/O related to observations. '''
def __init__(self, parameter, var_name_long, region,
obs, obs_dict, data_path, sftlf):
super(Observation, self).__init__(parameter, var_name_long, region,
Expand All @@ -58,6 +63,7 @@ def __init__(self, parameter, var_name_long, region,
self.setup_target_mask()

def create_obs_file(self):
''' Creates an object that will eventually output the netCDF file. '''
obs_mask_name = self.create_obs_mask_name()
self._obs_file = OBS(self.data_path, self.var,
self.obs_dict, self.obs_or_model,
Expand All @@ -67,6 +73,7 @@ def create_obs_file(self):
self._obs_file.case_id = self.parameter.case_id

def create_obs_mask_name(self):
''' Gets the name from the obs_mask, which is obtained from a netCDF file. '''
try:
obs_from_obs_dict = self.get_obs_from_obs_dict()
obs_mask = OBS(self.data_path, 'sftlf',
Expand All @@ -80,6 +87,8 @@ def create_obs_mask_name(self):
return obs_mask_name

def get_obs_from_obs_dict(self):
''' Returns the obsercation from the obsercation
dictionary for self.var and self.obs_or_model. '''
if isinstance(self.obs_dict[self.var][self.obs_or_model], (str, unicode)):
obs_from_obs_dict = \
self.obs_dict[self.var][self.obs_dict[self.var][self.obs_or_model]]
Expand All @@ -88,6 +97,7 @@ def get_obs_from_obs_dict(self):
return obs_from_obs_dict

def setup_target_mask(self):
''' Sets the attribute target_mask of self._obs_file. '''
if self.region is not None:
region_value = self.region.get('value', None)
if region_value is not None:
Expand All @@ -97,6 +107,8 @@ def setup_target_mask(self):
)

def get(self):
''' Gets the variable based on the region and level (if given) for
the file from data_path, which is defined in the initializer. '''
try:
if self.level is not None:
data_obs = self._obs_file.get(self.var,
Expand All @@ -115,14 +127,18 @@ def get(self):
self.var, self.obs_or_model, e)

def hash(self):
''' Return a hash of the file. '''
return self._obs_file.hash()

def file_path(self):
''' Return the path of the file. '''
return self._obs_file()

@staticmethod
# This must remain static b/c used before an Observation obj is created.
# This must remain static b/c used before an Observation object is created.
def setup_obs_list_from_parameter(parameter_obs_list, obs_dict, var):
''' If the data_set list from the parameter is
for observations, apply these special cases. '''
obs_list = parameter_obs_list
if 'all' in [x.lower() for x in obs_list]:
obs_list = 'all'
Expand Down
Loading