Skip to content

Commit

Permalink
plotting.marginals: correlation_hist finalizing mixed source type sup…
Browse files Browse the repository at this point in the history
…port
  • Loading branch information
hvasbath committed Apr 30, 2024
1 parent 2db8d7c commit 3b8154a
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 75 deletions.
97 changes: 45 additions & 52 deletions beat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import logging
import os
from collections import OrderedDict
from typing import Dict as TDict
from typing import List as TList

import numpy as num
from pyrocko import gf, model, trace, util
from pyrocko.cake import load_model
from pyrocko.gf import RectangularSource as PyrockoRS
from pyrocko.gf.seismosizer import Cloneable, LocalEngine
from pyrocko.gf.seismosizer import Cloneable
from pyrocko.guts import (
ArgumentError,
Bool,
Expand All @@ -30,13 +32,11 @@
dump,
load,
)
from typing import Dict as TDict
from typing import List as TList
from pytensor import config as tconfig

from theano import config as tconfig

from beat import utility, bem
from beat import utility
from beat.covariance import available_noise_structures, available_noise_structures_2d
from beat.defaults import default_decimation_factors, defaults
from beat.heart import (
ArrivalTaper,
Filter,
Expand All @@ -45,20 +45,23 @@
ReferenceLocation,
_domain_choices,
)
from beat.defaults import default_decimation_factors, defaults
from beat.sources import (
RectangularSource,
stf_catalog,
source_catalog as geometry_source_catalog,
)
from beat.sources import RectangularSource, stf_catalog
from beat.sources import source_catalog as geometry_source_catalog
from beat.utility import check_point_keys, list2string

logger = logging.getLogger("config")


try:
from beat.bem import source_catalog as bem_source_catalog

bem_catalog = {"geodetic": bem_source_catalog}
except ImportError:
logger.warning(
"To enable 'bem' mode packages 'pygmsh' and 'cutde' need to be installed."
)
bem_catalog = {}
bem_source_catalog = {}


source_catalog = {}
Expand All @@ -68,8 +71,6 @@

guts_prefix = "beat"

logger = logging.getLogger("config")

stf_names = stf_catalog.keys()
all_source_names = list(source_catalog.keys()) + list(bem_source_catalog.keys())

Expand Down Expand Up @@ -1199,8 +1200,8 @@ def get_traction_field(self, discretized_sources):


class BEMConfig(MediumConfig):
nu = Float.T(default=0.25, help="Poisson's ratio")
mu = Float.T(default=33e9, help="Shear modulus [Pa]")
poissons_ratio = Float.T(default=0.25, help="Poisson's ratio")
shear_modulus = Float.T(default=33e9, help="Shear modulus [Pa]")
earth_model_name = String.T(default="homogeneous-elastic-halfspace")
mesh_size = Float.T(
default=0.5,
Expand All @@ -1227,11 +1228,10 @@ def get_parameter(variable, nvars=1, lower=1, upper=2):


class DatatypeParameterMapping(Object):

sources_variables = List.T(Dict.T(String.T(), Int.T()))
n_sources = Int.T()

def __init__(self, **kwargs):

Object.__init__(self, **kwargs)

self._mapping = None
Expand All @@ -1242,12 +1242,15 @@ def __getitem__(self, k):
self.point_to_sources_mapping()

if k not in self._mapping.keys():
raise KeyError(k)
raise KeyError("Parameters mapping does not contain parameters:", k)

return self._mapping[k]

def point_to_sources_mapping(self) -> TDict[str, TList[int]]:

"""
Mapping for mixed source setups. Maps source parameter names to source indexes.
Is used by utility.split_point to split the full point into subsource_points.
"""
if self._mapping is None:
start_idx = 0
total_variables = {}
Expand All @@ -1270,7 +1273,6 @@ def point_variable_names(self) -> TList[int]:
return self.point_to_sources_mapping().keys()

def total_variables_sizes(self) -> TDict[str, int]:

mapping = self.point_to_sources_mapping()
variables_sizes = {}
for variable, idxs in mapping.items():
Expand All @@ -1285,12 +1287,11 @@ class SourcesParameterMapping(Object):
"""

source_types = List.T(String.T(), default=[])
n_sources = List.T(String.T(), default=[])
n_sources = List.T(Int.T(), default=[])
datatypes = List.T(StringChoice.T(choices=_datatype_choices), default=[])
mappings = Dict.T(String.T(), DatatypeParameterMapping.T())

def __init__(self, **kwargs):

Object.__init__(self, **kwargs)

for datatype in self.datatypes:
Expand All @@ -1299,7 +1300,7 @@ def __init__(self, **kwargs):
def add(self, sources_variables: TDict = {}, datatype: str = "geodetic"):
if datatype in self.mappings:
self.mappings[datatype] = DatatypeParameterMapping(
sources_variables=sources_variables
sources_variables=sources_variables, n_sources=sum(self.n_sources)
)
else:
raise ValueError(
Expand Down Expand Up @@ -1509,36 +1510,29 @@ def get_random_variables(self):
Returns
-------
rvs : dict
variable random variables
random variable names and their kwargs
fixed_params : dict
fixed random parameters
"""
from pymc3 import Uniform

logger.debug("Optimization for %s sources", list2string(self.n_sources))

rvs = {}
fixed_params = {}
for param in self.priors.values():
if not num.array_equal(param.lower, param.upper):
shape = self.get_parameter_shape(param)
size = self.get_parameter_size(param)

kwargs = dict(
name=param.name,
shape=num.sum(shape),
lower=param.get_lower(shape),
upper=param.get_upper(shape),
testval=param.get_testvalue(shape),
shape=(num.sum(size),),
lower=param.get_lower(size),
upper=param.get_upper(size),
initval=param.get_testvalue(size),
transform=None,
dtype=tconfig.floatX,
)
try:
rvs[param.name] = Uniform(**kwargs)

except TypeError:
kwargs.pop("name")
rvs[param.name] = Uniform.dist(**kwargs)

rvs[param.name] = kwargs
else:
logger.info(
f"not solving for {param.name}, got fixed at {utility.list2string(param.lower.flatten())}"
Expand Down Expand Up @@ -1592,7 +1586,7 @@ def _validate_parameters(self, dict_name=None):
double_check.append(name)
else:
raise ValueError(
"Parameter %s not unique in %s!".format(name, dict_name)
"Parameter %s not unique in %s!" % (name, dict_name)
)

logger.info(f"All {dict_name} ok!")
Expand Down Expand Up @@ -1631,8 +1625,8 @@ def get_test_point(self):
"""
test_point = {}
for varname, var in self.priors.items():
shape = self.get_parameter_shape(var)
test_point[varname] = var.get_testvalue(shape)
size = self.get_parameter_size(var)
test_point[varname] = var.get_testvalue(size)

for varname, var in self.hyperparameters.items():
test_point[varname] = var.get_testvalue()
Expand All @@ -1642,20 +1636,20 @@ def get_test_point(self):

return test_point

def get_parameter_shape(self, param):
def get_parameter_size(self, param):
if self.mode == ffi_mode_str and param.name in hypo_vars:
shape = self.n_sources[0]
size = self.n_sources[0]
elif self.mode == ffi_mode_str and self.mode_config.npatches:
shape = self.mode_config.subfault_npatches
if len(shape) == 0:
shape = self.mode_config.npatches
size = self.mode_config.subfault_npatches
if len(size) == 0:
size = self.mode_config.npatches
elif self.mode in [ffi_mode_str, geometry_mode_str, bem_mode_str]:
shape = param.dimension
size = param.dimension

else:
raise TypeError(f"Mode not implemented: {self.mode}")

return shape
return size

def get_derived_variables_shapes(self):
"""
Expand Down Expand Up @@ -2173,7 +2167,6 @@ def init_dataset_config(config, datatype, mode):
c.project_dir = os.path.join(os.path.abspath(main_path), name)

if mode in [geometry_mode_str, bem_mode_str]:

for datatype in datatypes:
init_dataset_config(c, datatype=datatype, mode=mode)

Expand Down Expand Up @@ -2216,14 +2209,14 @@ def init_dataset_config(config, datatype, mode):
' "geometry" mode: "%s"!' % (source_types[0], geometry_source_type)
)

n_sources = gmc.problem_config.n_sources[0]
n_sources = gmc.problem_config.n_sources
point = {k: v.testvalue for k, v in gmc.problem_config.priors.items()}
point = utility.adjust_point_units(point)
source_points = utility.split_point(point, n_sources_total=n_sources)
source_points = utility.split_point(point, n_sources_total=n_sources[0])

reference_sources = init_reference_sources(
source_points,
n_sources,
n_sources[0],
geometry_source_type,
gmc.problem_config.stf_type,
event=gmc.event,
Expand Down
42 changes: 19 additions & 23 deletions beat/plotting/marginals.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def correlation_plot(
def correlation_plot_hist(
mtrace,
varnames=None,
source_param_dicts=None,
mapping=None,
figsize=None,
hist_color=None,
cmap=None,
Expand All @@ -627,8 +627,7 @@ def correlation_plot_hist(
Multitrace instance containing the sampling results
varnames : list of variable names
Variables to be plotted, if None all variables are plotted
source_param_dicts: list of dict
of parameters and indexes to trace arrays
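mapping : parameter mapping object
Must provide point_to_sources_mapping() and n_sources; it is used to split
the variables into per-source parameter dicts (exact type not visible here, TODO confirm)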
figsize : figure size tuple
If None, size is (12, num of variables * 2) inch
cmap : matplotlib colormap
Expand Down Expand Up @@ -679,17 +678,23 @@ def correlation_plot_hist(
figs = []
axes = []

print(source_param_dicts)
# min_source_ixs = {
# varname: int(min(idxs)) for varname, idxs in source_param_dicts.items()}
point_to_sources = mapping.point_to_sources_mapping()
source_param_dicts = utility.split_point(
point_to_sources,
point_to_sources=point_to_sources,
n_sources_total=sum(mapping.n_sources),
)
min_source_ixs = {
varname: int(min(idxs)) for varname, idxs in point_to_sources.items()
}

for source_i, param_dict in enumerate(source_param_dicts):
logger.info("for variables of source %i ..." % source_i)
hist_ylims = []
print(source_i, param_dict)
source_varnames = list(param_dict.keys())
weeded_source_varnames = [
varname for varname in source_varnames if varname in varnames
varname for varname in varnames if varname in source_varnames
]
nvar = len(weeded_source_varnames)

Expand All @@ -711,15 +716,10 @@ def correlation_plot_hist(

for i_k in range(nvar):
v_namea = weeded_source_varnames[i_k]
source_i_a = int(param_dict[v_namea]) # - min_source_ixs[v_namea]
source_i_a = int(param_dict[v_namea]) - min_source_ixs[v_namea]
print("source_i_a", source_i_a, v_namea)
try:
a = d[v_namea][:, source_i_a]
except IndexError:
source_i_a -= 1
a = d[v_namea][:, source_i_a]


a = d[v_namea][:, source_i_a]
for i_l in range(i_k, nvar):
ax = axs[i_l, i_k]
v_nameb = weeded_source_varnames[i_l]
Expand Down Expand Up @@ -753,7 +753,7 @@ def correlation_plot_hist(
xlim = ax.get_xlim()
hist_ylims.append(ax.get_ylim())
else:
source_i_b = int(param_dict[v_namea])
source_i_b = int(param_dict[v_nameb]) - min_source_ixs[v_nameb]
print("v_nameb", v_nameb, source_i_b)
try:
b = d[v_nameb][:, source_i_b]
Expand Down Expand Up @@ -820,7 +820,9 @@ def correlation_plot_hist(

if unify:
varnames_repeat_x = [
var_reap for varname in weeded_source_varnames for var_reap in (varname,) * nvar
var_reap
for varname in weeded_source_varnames
for var_reap in (varname,) * nvar
]
varnames_repeat_y = weeded_source_varnames * nvar
unitiesx = unify_tick_intervals(
Expand Down Expand Up @@ -1006,17 +1008,11 @@ def draw_correlation_hist(problem, plot_options):
datatype = problem.config.problem_config.datatypes[0]

mapping = problem.composites[datatype].mapping
point_to_sources = mapping.point_to_sources_mapping()
source_param_dicts = utility.split_point(
point_to_sources,
point_to_sources=point_to_sources,
n_sources_total=2,
)

figs, _ = correlation_plot_hist(
mtrace=stage.mtrace,
varnames=varnames,
source_param_dicts=source_param_dicts,
mapping=mapping,
cmap=plt.cm.gist_earth_r,
chains=None,
point=reference,
Expand Down

0 comments on commit 3b8154a

Please sign in to comment.