Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof of concept: MeshIndexSet #6014

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 286 additions & 0 deletions lib/iris/experimental/ugrid/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1962,6 +1962,292 @@ def topology_dimension(self):
return self._metadata_manager.topology_dimension


class MeshIndexSet(Mesh):
def __init__(self, mesh, location, indices):
self.super_mesh = mesh
self.location = location
self.indices = indices

self._metadata_manager = metadata_manager_factory(MeshMetadata)

# topology_dimension is read-only, so assign directly to the metadata manager
self._metadata_manager.topology_dimension = mesh.topology_dimension

self.node_dimension = mesh.node_dimension
self.edge_dimension = mesh.edge_dimension
self.face_dimension = mesh.face_dimension

# assign the metadata to the metadata manager
self.standard_name = mesh.standard_name
self.long_name = mesh.long_name
self.var_name = mesh.var_name
self.units = mesh.units
self.attributes = mesh.attributes

self._coord_manager = _MeshIndexCoordinateManager(mesh, location, indices)
self._connectivity_manager = _MeshIndexConnectivityManager(
mesh, location, indices
)

def __eq__(self, other):
# TBD: this is a minimalist implementation and requires to be revisited
return id(self) == id(other)

def __ne__(self, other):
# TBD: this is a minimalist implementation and requires to be revisited
return id(self) != id(other)


class _MeshIndexManager:
def __init__(self, mesh, location, indices):
self.mesh = mesh
self.location = location
self.indices = indices

self.face_indices = self._calculate_face_indices()
self.edge_indices = self._calculate_edge_indices()
self.node_indices = self._calculate_node_indices()
self.node_index_dict = {
old_index: new_index
for new_index, old_index in enumerate(self.node_indices)
}

def _calculate_node_indices(self):
if self.location == "node":
return self.indices
elif self.location == "edge":
connectivity = self.mesh.edge_node_connectivity[self.indices]
node_set = list(set(connectivity.indices.compressed()))
node_set.sort()
return node_set
elif self.location == "face":
connectivity = self.mesh.face_node_connectivity[self.indices]
node_set = list(set(connectivity.indices.compressed()))
node_set.sort()
return node_set

def _calculate_edge_indices(self):
if self.location != "edge":
return None
return self.indices

def _calculate_face_indices(self):
if self.location != "face":
return None
return self.indices


class _MeshIndexCoordinateManager(_MeshIndexManager):
REQUIRED = (
"node_x",
"node_y",
)
OPTIONAL = (
"edge_x",
"edge_y",
"face_x",
"face_y",
)

def __init__(self, mesh, location, indices):
super().__init__(mesh, location, indices)
self.ALL = self.REQUIRED + self.OPTIONAL
self._members = {}
self._members = {member: getattr(self, member) for member in self.ALL}

def __eq__(self, other):
# TBD: this is a minimalist implementation and requires to be revisited
return id(self) == id(other)

def __ne__(self, other):
# TBD: this is a minimalist implementation and requires to be revisited
return id(self) != id(other)

@property
def node_x(self):
if "node_x" in self._members:
return self._members["node_x"]
else:
return self.mesh._coord_manager.node_x[self.node_indices]

@property
def node_y(self):
if "node_y" in self._members:
return self._members["node_y"]
else:
return self.mesh._coord_manager.node_y[self.node_indices]

@property
def edge_x(self):
if "edge_x" in self._members:
return self._members["edge_x"]
else:
return self.mesh._coord_manager.edge_x[self.edge_indices]

@property
def edge_y(self):
if "edge_y" in self._members:
return self._members["edge_y"]
else:
return self.mesh._coord_manager.edge_y[self.edge_indices]

@property
def face_x(self):
if "face_x" in self._members:
return self._members["face_x"]
else:
return self.mesh._coord_manager.face_x[self.face_indices]

@property
def face_y(self):
if "face_y" in self._members:
return self._members["face_y"]
else:
return self.mesh._coord_manager.face_y[self.face_indices]

@property
def node_coords(self):
return MeshNodeCoords(node_x=self.node_x, node_y=self.node_y)

@property
def edge_coords(self):
return MeshEdgeCoords(edge_x=self.edge_x, edge_y=self.edge_y)

@property
def face_coords(self):
return MeshFaceCoords(face_x=self.face_x, face_y=self.face_y)

def filters(
self,
item=None,
standard_name=None,
long_name=None,
var_name=None,
attributes=None,
axis=None,
include_nodes=None,
include_edges=None,
include_faces=None,
):
# TBD: support coord_systems?

# Preserve original argument before modifying.
face_requested = include_faces

# Rationalise the tri-state behaviour.
args = [include_nodes, include_edges, include_faces]
state = not any(set(filter(lambda arg: arg is not None, args)))
include_nodes, include_edges, include_faces = map(
lambda arg: arg if arg is not None else state, args
)

def populated_coords(coords_tuple):
return list(filter(None, list(coords_tuple)))

members = []
if include_nodes:
members += populated_coords(self.node_coords)
if include_edges:
members += populated_coords(self.edge_coords)
if hasattr(self, "face_coords"):
if include_faces:
members += populated_coords(self.face_coords)
elif face_requested:
dmsg = "Ignoring request to filter non-existent 'face_coords'"
logger.debug(dmsg, extra=dict(cls=self.__class__.__name__))

result = metadata_filter(
members,
item=item,
standard_name=standard_name,
long_name=long_name,
var_name=var_name,
attributes=attributes,
axis=axis,
)

# Use the results to filter the _members dict for returning.
result_ids = [id(r) for r in result]
result_dict = {k: v for k, v in self._members.items() if id(v) in result_ids}
return result_dict

def filter(self, **kwargs):
# TODO: rationalise commonality with MeshConnectivityManager.filter and Cube.coord.
result = self.filters(**kwargs)

if len(result) > 1:
names = ", ".join(f"{member}={coord!r}" for member, coord in result.items())
emsg = (
f"Expected to find exactly 1 coordinate, but found {len(result)}. "
f"They were: {names}."
)
raise CoordinateNotFoundError(emsg)

if len(result) == 0:
item = kwargs["item"]
if item is not None:
if not isinstance(item, str):
item = item.name()
name = (
item
or kwargs["standard_name"]
or kwargs["long_name"]
or kwargs["var_name"]
or None
)
name = "" if name is None else f"{name!r} "
emsg = f"Expected to find exactly 1 {name}coordinate, but found none."
raise CoordinateNotFoundError(emsg)

return result


class _MeshIndexConnectivityManager(_MeshIndexManager):
@property
def edge_node(self):
if self.edge_indices is None:
return None
else:
connectivity = self.mesh.edge_node_connectivity[self.edge_indices]
connectivity_indices = np.vectorize(self.node_index_dict.get)(
connectivity.indices
)
connectivity = Connectivity(
connectivity_indices,
connectivity.cf_role,
standard_name=connectivity.standard_name,
long_name=connectivity.long_name,
var_name=connectivity.var_name,
units=connectivity.units,
attributes=connectivity.attributes,
start_index=connectivity.start_index,
location_axis=connectivity.location_axis,
)
return connectivity

@property
def face_node(self):
if self.face_indices is None:
return None
else:
connectivity = self.mesh.face_node_connectivity[self.face_indices]
connectivity_indices = np.vectorize(self.node_index_dict.get)(
connectivity.indices
)
connectivity = Connectivity(
connectivity_indices,
connectivity.cf_role,
standard_name=connectivity.standard_name,
long_name=connectivity.long_name,
var_name=connectivity.var_name,
units=connectivity.units,
attributes=connectivity.attributes,
start_index=connectivity.start_index,
location_axis=connectivity.location_axis,
)
return connectivity


class _Mesh1DCoordinateManager:
"""TBD: require clarity on coord_systems validation.

Expand Down
Loading