Skip to content

Commit

Permalink
Removed two classes and added lots of docstrings.
Browse files Browse the repository at this point in the history
  • Loading branch information
zshaheen committed Mar 14, 2017
1 parent 9f89eb4 commit b1f394a
Show file tree
Hide file tree
Showing 7 changed files with 275 additions and 277 deletions.
23 changes: 16 additions & 7 deletions src/python/pcmdi/scripts/driver/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import logging
import cdutil
import cdms2
import pcmdi_metrics.io.base
from pcmdi_metrics.io.base import Base


class DataSet(object):
''' Abstract parent of the Observation of Model classes. '''
__metaclass__ = abc.ABCMeta

def __init__(self, parameter, var_name_long, region,
Expand All @@ -24,13 +25,15 @@ def __init__(self, parameter, var_name_long, region,
self.sftlf = sftlf

def get_sftlf(self):
''' Returns the sftlf attribute. '''
return self.sftlf

def __call__(self):
return self.get()

@staticmethod
def calculate_level_from_var(var):
''' Get the level from the var string, where it's var_LEVEL '''
var_split_name = var.split('_')
if len(var_split_name) > 1:
level = float(var_split_name[-1]) * 100
Expand All @@ -39,6 +42,8 @@ def calculate_level_from_var(var):
return level

def setup_target_grid(self, obs_or_model_file):
''' Call the set_target_grid function for
obs_or_model_file, which is of type Base. '''
if self.use_omon(self.obs_dict, self.var):
regrid_method = self.parameter.regrid_method_ocn
regrid_tool = self.parameter.regrid_tool_ocn
Expand All @@ -55,17 +60,19 @@ def setup_target_grid(self, obs_or_model_file):

@staticmethod
def use_omon(obs_dict, var):
''' For the given variable and obs_dict, do we use Omon? '''
obs_default = obs_dict[var][obs_dict[var]["default"]]
return obs_default["CMIP_CMOR_TABLE"] == 'Omon'

@staticmethod
def create_sftlf(parameter):
''' Create the sftlf file from the parameter. '''
sftlf = {}

for test in parameter.test_data_set:
sft = pcmdi_metrics.io.base.Base(parameter.test_data_path,
getattr(parameter, "sftlf_filename_template",
parameter.filename_template))
sft = Base(parameter.test_data_path,
getattr(parameter, "sftlf_filename_template",
parameter.filename_template))
sft.model_version = test
sft.table = "fx"
sft.realm = "atmos"
Expand Down Expand Up @@ -96,17 +103,19 @@ def create_sftlf(parameter):

@staticmethod
def apply_custom_keys(obj, custom_dict, var):
''' Apply the all of the keys in custom_dict that are var to obj. '''
for k, v in custom_dict.iteritems():
key = custom_dict[k]
setattr(obj, k, key.get(var, key.get(None, "")))

@abc.abstractmethod
def get(self):
"""Calls the get function on the Base object."""
''' Calls the get function on the Base object. '''
raise NotImplementedError()

@staticmethod
def load_path_as_file_obj(name):
''' Returns a File object for the file named name. '''
file_path = sys.prefix + '/share/pmp/' + name
opened_file = None
try:
Expand All @@ -120,9 +129,9 @@ def load_path_as_file_obj(name):

@abc.abstractmethod
def hash(self):
"""Calls the hash function on the Base object."""
''' Calls the hash function on the Base object. '''
raise NotImplementedError()

def file_path(self):
"""Calls the __call__() function on the Base object."""
''' Calls the __call__() function on the Base object. '''
raise NotImplementedError()
18 changes: 14 additions & 4 deletions src/python/pcmdi/scripts/driver/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import MV2
import cdutil
import cdms2
import pcmdi_metrics.io.base
from pcmdi_metrics.io.base import Base
import pcmdi_metrics.driver.dataset


class Model(pcmdi_metrics.driver.dataset.DataSet):
''' Handles all the computation (setting masking, target grid, etc)
and some file I/O related to models. '''
def __init__(self, parameter, var_name_long, region,
model, obs_dict, data_path, sftlf):
super(Model, self).__init__(parameter, var_name_long, region,
Expand All @@ -22,8 +24,8 @@ def __init__(self, parameter, var_name_long, region,
self.setup_target_mask()

def create_model_file(self):
self._model_file = pcmdi_metrics.io.base.Base(self.data_path,
self.parameter.filename_template)
''' Creates an object that will eventually output the netCDF file. '''
self._model_file = Base(self.data_path, self.parameter.filename_template)
self._model_file.variable = self.var
self._model_file.model_version = self.obs_or_model
self._model_file.period = self.parameter.period
Expand All @@ -34,6 +36,7 @@ def create_model_file(self):
self.parameter.custom_keys, self.var)

def setup_target_mask(self):
''' Sets the mask and target_mask attribute of self._model_file '''
self.var_in_file = self.get_var_in_file()

if self.region is not None:
Expand All @@ -47,6 +50,8 @@ def setup_target_mask(self):
MV2.not_equal(self.sftlf['target_grid'], region_value)

def get(self):
''' Gets the variable based on the region and level (if given) for
the file from data_path, which is defined in the initalizer. '''
try:
if self.level is None:
data_model = self._model_file.get(
Expand All @@ -64,6 +69,7 @@ def get(self):
raise RuntimeError('Need to skip model: %s' % self.obs_or_model)

def get_var_in_file(self):
''' Based off the model_tweaks parameter, get the variable mapping. '''
tweaks = {}
tweaks_all = {}
if hasattr(self.parameter, 'model_tweaks'):
Expand All @@ -80,8 +86,10 @@ def get_var_in_file(self):
return var_in_file

def create_sftlf_model_raw(self, var_in_file):
''' For the self.obs_or_model from the initializer, create a landSeaMask
from cdutil for self.sftlf[self.obs_or_model]['raw'] value. '''
if not hasattr(self.parameter, 'generate_sftlf') or \
self.parameter.generate_sftlf is False:
self.parameter.generate_sftlf is False:
logging.info('Model %s does not have sftlf, skipping region: %s' % (self.obs_or_model, self.region))
raise RuntimeError('Model %s does not have sftlf, skipping region: %s' % (self.obs_or_model, self.region))

Expand All @@ -98,7 +106,9 @@ def create_sftlf_model_raw(self, var_in_file):
logging.info('Auto generated sftlf for model %s' % self.obs_or_model)

def hash(self):
''' Return a hash of the file. '''
return self._model_file.hash()

def file_path(self):
''' Return the path of the file. '''
return self._model_file()
26 changes: 21 additions & 5 deletions src/python/pcmdi/scripts/driver/observation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
import MV2
import pcmdi_metrics.io.base
import pcmdi_metrics.driver.dataset
from pcmdi_metrics.io.base import Base
from pcmdi_metrics.driver.dataset import Dataset


class OBS(pcmdi_metrics.io.base.Base):
class OBS(Base):
''' Creates an output the netCDF file for an observation. '''
def __init__(self, root, var, obs_dict, obs='default',
file_mask_template=None):
template = "%(realm)/%(frequency)/%(variable)/" +\
Expand Down Expand Up @@ -32,6 +33,8 @@ def __init__(self, root, var, obs_dict, obs='default',
self.variable = var

def setup_based_on_obs_table(self, obs_table):
''' Set the realm, frequency, ac based on the
CMIP_CMOR_TABLE value in the obs dict.'''
if obs_table == u'Omon':
self.realm = 'ocn'
self.frequency = 'mo'
Expand All @@ -46,7 +49,9 @@ def setup_based_on_obs_table(self, obs_table):
self.ac = 'ac'


class Observation(pcmdi_metrics.driver.dataset.DataSet):
class Observation(Dataset):
''' Handles all the computation (setting masking, target grid, etc)
and some file I/O related to observations. '''
def __init__(self, parameter, var_name_long, region,
obs, obs_dict, data_path, sftlf):
super(Observation, self).__init__(parameter, var_name_long, region,
Expand All @@ -58,6 +63,7 @@ def __init__(self, parameter, var_name_long, region,
self.setup_target_mask()

def create_obs_file(self):
''' Creates an object that will eventually output the netCDF file. '''
obs_mask_name = self.create_obs_mask_name()
self._obs_file = OBS(self.data_path, self.var,
self.obs_dict, self.obs_or_model,
Expand All @@ -67,6 +73,7 @@ def create_obs_file(self):
self._obs_file.case_id = self.parameter.case_id

def create_obs_mask_name(self):
''' Gets the name from the obs_mask, which is obtained from a netCDF file. '''
try:
obs_from_obs_dict = self.get_obs_from_obs_dict()
obs_mask = OBS(self.data_path, 'sftlf',
Expand All @@ -80,6 +87,8 @@ def create_obs_mask_name(self):
return obs_mask_name

def get_obs_from_obs_dict(self):
''' Returns the obsercation from the obsercation
dictionary for self.var and self.obs_or_model. '''
if isinstance(self.obs_dict[self.var][self.obs_or_model], (str, unicode)):
obs_from_obs_dict = \
self.obs_dict[self.var][self.obs_dict[self.var][self.obs_or_model]]
Expand All @@ -88,6 +97,7 @@ def get_obs_from_obs_dict(self):
return obs_from_obs_dict

def setup_target_mask(self):
''' Sets the attribute target_mask of self._obs_file. '''
if self.region is not None:
region_value = self.region.get('value', None)
if region_value is not None:
Expand All @@ -97,6 +107,8 @@ def setup_target_mask(self):
)

def get(self):
''' Gets the variable based on the region and level (if given) for
the file from data_path, which is defined in the initializer. '''
try:
if self.level is not None:
data_obs = self._obs_file.get(self.var,
Expand All @@ -115,14 +127,18 @@ def get(self):
self.var, self.obs_or_model, e)

def hash(self):
''' Return a hash of the file. '''
return self._obs_file.hash()

def file_path(self):
''' Return the path of the file. '''
return self._obs_file()

@staticmethod
# This must remain static b/c used before an Observation obj is created.
# This must remain static b/c used before an Observation object is created.
def setup_obs_list_from_parameter(parameter_obs_list, obs_dict, var):
''' If the data_set list from the parameter is
for observations, apply these special cases. '''
obs_list = parameter_obs_list
if 'all' in [x.lower() for x in obs_list]:
obs_list = 'all'
Expand Down
Loading

0 comments on commit b1f394a

Please sign in to comment.