Skip to content

Commit

Permalink
heart, models.geodetic: introduce GeodeticDataset id for handling da…
Browse files Browse the repository at this point in the history
…taset names and components for GNSS correctly
  • Loading branch information
hvasbath committed May 2, 2024
1 parent 349ba91 commit a33f7fb
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 36 deletions.
2 changes: 1 addition & 1 deletion beat/apps/beat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2184,7 +2184,7 @@ def setup(parser):
except (UserIOWarning, KeyError):
raise ImportError("Full resolution data could not be loaded!")
elif isinstance(dataset, heart.GNSSCompoundComponent):
logger.info("Found GNSS Compound %s, importing to kite..." % dataset.name)
logger.info("Found GNSS Compound %s, importing to kite..." % dataset.id)
scene = dataset.to_kite_scene()
# scene.spool()
sandbox.setReferenceScene(scene)
Expand Down
4 changes: 2 additions & 2 deletions beat/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def do_import(self, dataset):
return dataset.covariance.data
else:
raise ValueError(
"Data covariance for dataset %s needs to be defined!" % dataset.name
"Data covariance for dataset %s needs to be defined!" % dataset.id
)

def do_non_toeplitz(self, dataset, result):
Expand All @@ -204,7 +204,7 @@ def do_non_toeplitz(self, dataset, result):
if num.isnan(scaling).any():
raise ValueError(
"Estimated Non-Toeplitz covariance matrix for dataset %s contains Nan! "
"Please increase 'max_dist_perc'!" % dataset.name
"Please increase 'max_dist_perc'!" % dataset.id
)

return scaling
Expand Down
8 changes: 8 additions & 0 deletions beat/heart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,10 @@ def __init__(self, **kwargs):
self.corrections = None
super(GeodeticDataset, self).__init__(**kwargs)

@property
def id(self):
return self.name

def get_corrections(self, hierarchicals, point=None):
"""
Needs to be specified on inherited dataset classes.
Expand Down Expand Up @@ -1177,6 +1181,10 @@ def __init__(self, **kwargs):
self._station2index = None
super(GNSSCompoundComponent, self).__init__(**kwargs)

@property
def id(self):
return "%s_%s" % (self.name, self.component)

def update_los_vector(self):
if self.component == "east":
c = num.array([0, 1, 0])
Expand Down
29 changes: 8 additions & 21 deletions beat/models/geodetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,24 +131,11 @@ def init_weights(self):
def n_t(self):
return len(self.datasets)

def get_all_dataset_names(self, hp_name):
def get_all_dataset_ids(self, hp_name):
"""
Return unique GNSS stations and radar acquisitions.
"""
names = []
for dataset in self.datasets:
if dataset.typ == hp_name.split("_")[1]:
if isinstance(dataset, heart.DiffIFG):
names.append(dataset.name)
elif isinstance(dataset, heart.GNSSCompoundComponent):
names.append(dataset.component)
else:
TypeError(
'Geodetic Dataset of class "%s" not '
"supported" % dataset.__class__.__name__
)

return names
return [dataset.id for dataset in self.datasets]

def analyse_noise(self, tpoint=None):
"""
Expand All @@ -166,7 +153,7 @@ def analyse_noise(self, tpoint=None):
for dataset, result in zip(self.datasets, results):
logger.info(
'Retrieving geodetic data-covariances with structure "%s" '
"for %s ..." % (self.config.noise_estimator.structure, dataset.name)
"for %s ..." % (self.config.noise_estimator.structure, dataset.id)
)

cov_d_geodetic = self.noise_analyser.get_data_covariance(
Expand Down Expand Up @@ -196,7 +183,7 @@ def get_hypersize(self, hp_name=""):
-------
int
"""
n_datasets = len(self.get_all_dataset_names(hp_name))
n_datasets = len(self.get_all_dataset_ids(hp_name))
if n_datasets == 0:
raise ConfigInconsistentError(
'Found no data for hyperparameter "%s". Please either load'
Expand Down Expand Up @@ -276,7 +263,7 @@ def get_filename(attr, ending="csv"):
return os.path.join(
results_path,
"{}_{}_{}.{}".format(
os.path.splitext(dataset.name)[0], attr, stage_number, ending
os.path.splitext(dataset.id)[0], attr, stage_number, ending
),
)

Expand Down Expand Up @@ -506,7 +493,7 @@ def get_variance_reductions(self, point, results=None, weights=None):
logger.debug("nom %f, denom %f" % (float(nom), float(denom)))
var_red = 1 - (nom / denom)

logger.debug("Variance reduction for %s is %f" % (dataset.name, var_red))
logger.debug("Variance reduction for %s is %f" % (dataset.id, var_red))

if 0:
from matplotlib import pyplot as plt
Expand All @@ -516,7 +503,7 @@ def get_variance_reductions(self, point, results=None, weights=None):
plt.colorbar(im)
plt.show()

var_reds[dataset.name] = var_red
var_reds[dataset.id] = var_red

return var_reds

Expand Down Expand Up @@ -548,7 +535,7 @@ def get_standardized_residuals(self, point, results=None, weights=None):
point, dataset, counter, hp_specific=hp_specific
)
choli = num.linalg.inv(dataset.covariance.chol(num.exp(hp * 2.0)))
stdz_residuals[dataset.name] = choli.dot(result.processed_res)
stdz_residuals[dataset.id] = choli.dot(result.processed_res)

return stdz_residuals

Expand Down
4 changes: 2 additions & 2 deletions beat/plotting/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ def plot_covariances(datasets, covariances):
if i_l == 0:
ax.set_ylabel("Sample idx")
ax.set_xlabel("Sample idx")
ax.set_title(dataset.name)
ax.set_title(dataset.id)

cbaxes = fig.add_axes([cbl, cbb, cbw, cbh])
cblabel = "Covariance [m²]"
Expand All @@ -896,7 +896,7 @@ def plot_covariances(datasets, covariances):
cbs.set_label(cblabel, fontsize=fontsize)
else:
logger.info(
'Did not find "%s" covariance component for %s', attr, dataset.name
'Did not find "%s" covariance component for %s', attr, dataset.id
)
fig.delaxes(ax)

Expand Down
20 changes: 10 additions & 10 deletions beat/plotting/geodetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,11 @@ def gnss_fits(problem, stage, plot_options):
bvar_reductions_comp = {}
for dataset in dataset_to_result.keys():
target_var_reds = []
target_bvar_red = bvar_reductions[dataset.name]
target_bvar_red = bvar_reductions[dataset.id]
target_var_reds.append(target_bvar_red)
bvar_reductions_comp[dataset.component] = target_bvar_red * 100.0
for var_reds in ens_var_reductions:
target_var_reds.append(var_reds[dataset.name])
target_var_reds.append(var_reds[dataset.id])

all_var_reductions[dataset.component] = num.array(target_var_reds) * 100.0

Expand Down Expand Up @@ -387,7 +387,7 @@ def gnss_fits(problem, stage, plot_options):
Z = 0

out_filename = "/tmp/histbounds.txt"
in_rows = num.atleast_2d(all_var_reductions[dataset.component]).T
in_rows = num.atleast_2d(var_reductions_ens).T

m.gmt.pshistogram(
in_rows=in_rows,
Expand Down Expand Up @@ -543,7 +543,7 @@ def scene_fits(problem, stage, plot_options):

if po.plot_projection == "individual":
for result, dataset in zip(bresults_tmp, composite.datasets):
result.processed_res = stdz_residuals[dataset.name]
result.processed_res = stdz_residuals[dataset.id]

bvar_reductions = composite.get_variance_reductions(
bpoint, weights=composite.weights, results=bresults_tmp
Expand Down Expand Up @@ -599,11 +599,11 @@ def scene_fits(problem, stage, plot_options):
all_var_reductions = {}
for dataset in dataset_to_result.keys():
target_var_reds = []
target_var_reds.append(bvar_reductions[dataset.name])
target_var_reds.append(bvar_reductions[dataset.id])
for var_reds in ens_var_reductions:
target_var_reds.append(var_reds[dataset.name])
target_var_reds.append(var_reds[dataset.id])

all_var_reductions[dataset.name] = num.array(target_var_reds) * 100.0
all_var_reductions[dataset.id] = num.array(target_var_reds) * 100.0

figures = []
axes = []
Expand Down Expand Up @@ -892,7 +892,7 @@ def draw_sources(ax, sources, scene, po, event, **kwargs):
vmin = -dcolims[tidx]
vmax = dcolims[tidx]
logger.debug(
"Variance of residual for %s is: %f", dataset.name, datavec.var()
"Variance of residual for %s is: %f", dataset.id, datavec.var()
)
else:
vmin = -colims[tidx]
Expand Down Expand Up @@ -956,8 +956,8 @@ def draw_sources(ax, sources, scene, po, event, **kwargs):
if po.nensemble > 1:
in_ax = plot_inset_hist(
axs[2],
data=num.atleast_2d(all_var_reductions[dataset.name]),
best_data=bvar_reductions[dataset.name] * 100.0,
data=num.atleast_2d(all_var_reductions[dataset.id]),
best_data=bvar_reductions[dataset.id] * 100.0,
linewidth=1.0,
bbox_to_anchor=(0.75, 0.775, 0.25, 0.225),
labelsize=6,
Expand Down

0 comments on commit a33f7fb

Please sign in to comment.