Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mov run #994

Merged
merged 18 commits into from
Nov 30, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
use xarray data array to plot map with cartopy, and adjust 1e20 as na…
…n to prevent error from shapely
  • Loading branch information
lee1043 committed Nov 30, 2023
commit 327d169382f622802a723bb03be8ed7df95bf5ff
144 changes: 81 additions & 63 deletions pcmdi_metrics/variability_mode/lib/plot_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import xarray as xr
from cartopy.feature import LAND as cartopy_land
from cartopy.feature import OCEAN as cartopy_ocean
from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER
Expand Down Expand Up @@ -96,41 +97,51 @@ def plot_map(
maskout = None

if mode in ["AMO", "AMO_teleconnection"]:
center_lon_global = 0
central_longitude = 0
else:
center_lon_global = 180
central_longitude = 180

# Convert cdms variable to xarray
lons = eof_Nth.getLongitude()
lats = eof_Nth.getLatitude()
data = np.array(eof_Nth)
lon = np.array(lons)
lat = np.array(lats)
lon, lat = np.meshgrid(lon, lat)
data_array = xr.DataArray(np.array(data), coords={'lon': lon[0, :], 'lat': lat[:, 0]}, dims=('lat', 'lon'))
data_array = data_array.where(data_array != 1e20, np.nan)

plot_map_cartopy(
eof_Nth,
data_array,
output_file_name,
title=plot_title,
proj=projection,
gridline=gridline,
levels=levels,
maskout=maskout,
center_lon_global=center_lon_global,
central_longitude=central_longitude,
debug=debug,
)


def plot_map_cartopy(
data,
filename,
data_array,
filename=None,
title=None,
gridline=True,
levels=None,
proj="PlateCarree",
data_area="global",
cmap="RdBu_r",
center_lon_global=180,
central_longitude=180,
maskout=None,
debug=False,
):
"""
Parameters
----------
data : trainsisent variable
2D cdms2 TransientVariable with lat/lon coordinates attached.
data : data_array
2D xarray DataArray with lat/lon coordinates attached.
filename : str
Output file name (it is okay to omit '.png')
title : str, optional
Expand All @@ -153,39 +164,41 @@ def plot_map_cartopy(

debug_print("plot_map_cartopy starts", debug)

lons = data.getLongitude()
lats = data.getLatitude()

min_lon = min(lons)
max_lon = max(lons)
min_lat = min(lats)
max_lat = max(lats)
lon = data_array.lon
lat = data_array.lat

# Determine the extent based on the longitude range where data exists
lon_min = lon.min().item()
lon_max = lon.max().item()
lat_min = lat.min().item()
lat_max = lat.max().item()

if debug:
print(min_lon, max_lon, min_lat, max_lat)
print(lon_min, lon_max, lat_min, lat_max)

debug_print("Central longitude setup starts", debug)
debug_print("proj: " + proj, debug)
# map types example:
# https://github.com/SciTools/cartopy-tutorial/blob/master/tutorial/projections_crs_and_terms.ipynb

if proj == "PlateCarree":
projection = ccrs.PlateCarree(central_longitude=center_lon_global)
projection = ccrs.PlateCarree(central_longitude=central_longitude)
elif proj == "Robinson":
projection = ccrs.Robinson(central_longitude=center_lon_global)
projection = ccrs.Robinson(central_longitude=central_longitude)
elif proj == "Stereo_north":
projection = ccrs.NorthPolarStereo()
elif proj == "Stereo_south":
projection = ccrs.SouthPolarStereo()
elif proj == "Lambert":
max_lat = min(max_lat, 80)
lat_max = min(lat_max, 80)
if debug:
print("revised maxlat:", max_lat)
central_longitude = (min_lon + max_lon) / 2.0
central_latitude = (min_lat + max_lat) / 2.0
print("revised maxlat:", lat_max)
central_longitude = (lon_min + lon_max) / 2.0
central_latitude = (lat_min + lat_max) / 2.0
projection = ccrs.AlbersEqualArea(
central_longitude=central_longitude,
central_latitude=central_latitude,
standard_parallels=(20, max_lat),
central_longitude=central_longitude,
central_latitude=central_latitude,
standard_parallels=(20, lat_max),
)
else:
print("Error: projection not defined!")
Expand All @@ -195,22 +208,10 @@ def plot_map_cartopy(
print("projection:", projection)

# Generate plot
debug_print("Generate plot starts", debug)
fig = plt.figure(figsize=(8, 6))
debug_print("fig done", debug)
# ax = plt.axes(projection=projection)
ax = plt.axes(projection=ccrs.NorthPolarStereo())
debug_print("ax done", debug)
im = ax.contourf(
lons,
lats,
np.array(data),
transform=ccrs.PlateCarree(),
cmap=cmap,
levels=levels,
extend="both",
)
debug_print("contourf done", debug)
fig, ax = plt.subplots(subplot_kw={'projection': projection}, figsize=(8, 6))
debug_print("fig, ax done", debug)

# Add coastlines
ax.coastlines()
debug_print("Generate plot completed", debug)

Expand Down Expand Up @@ -263,12 +264,15 @@ def plot_map_cartopy(
# the bottom left and go round anticlockwise, creating a boundary point
# every 1 degree so that the result is smooth:
# https://stackoverflow.com/questions/43463643/cartopy-albersequalarea-limit-region-using-lon-and-lat

vertices = [
(lon - 180, min_lat) for lon in range(int(min_lon), int(max_lon + 1), 1)
] + [(lon - 180, max_lat) for lon in range(int(max_lon), int(min_lon - 1), -1)]
(lon - 180, lat_min) for lon in range(int(lon_min), int(lon_max + 1), 1)
] + [(lon - 180, lat_max) for lon in range(int(lon_max), int(lon_min - 1), -1)]
boundary = mpath.Path(vertices)
ax.set_boundary(boundary, transform=ccrs.PlateCarree(central_longitude=180))
ax.set_extent([min_lon, max_lon, min_lat, max_lat], crs=ccrs.PlateCarree())
ax.set_boundary(boundary, transform=ccrs.PlateCarree(central_longitude=central_longitude))

ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())

if gridline:
gl = ax.gridlines(
draw_labels=True,
Expand All @@ -282,26 +286,26 @@ def plot_map_cartopy(
gl.xlocator = mticker.FixedLocator([120, 160, 200 - 360, 240 - 360])
gl.top_labels = False # suppress top labels
# suppress right labels
# gl.right_labels = False
gl.right_labels = False
for ea in gl.ylabel_artists:
right_label = ea.get_position()[0] > 0
if right_label:
ea.set_visible(False)

debug_print("projection completed", debug)

# Add title
plt.title(title, pad=15, fontsize=15)

# Add colorbar
posn = ax.get_position()
cbar_ax = fig.add_axes([0, 0, 0.1, 0.1])
cbar_ax.set_position([posn.x0 + posn.width + 0.01, posn.y0, 0.01, posn.height])
cbar = plt.colorbar(im, cax=cbar_ax)
cbar.ax.tick_params(labelsize=10)

if proj == "PlateCarree":
ax.set_aspect("auto", adjustable=None)

# Plot contours from the data
im = ax.contourf(
lon,
lat,
data_array,
levels=levels,
cmap=cmap,
extend="both",
transform=ccrs.PlateCarree(),
)
debug_print("contourf done", debug)

# Maskout
if maskout is not None:
if maskout == "land":
Expand All @@ -312,8 +316,22 @@ def plot_map_cartopy(
ax.add_feature(
cartopy_ocean, zorder=100, edgecolor="k", facecolor="lightgrey"
)
if proj == "PlateCarree":
ax.set_aspect("auto", adjustable=None)

# Add title
ax.set_title(title, pad=15, fontsize=15)

# Add colorbar
posn = ax.get_position()
cbar_ax = fig.add_axes([0, 0, 0.1, 0.1])
cbar_ax.set_position([posn.x0 + posn.width + 0.01, posn.y0, 0.01, posn.height])
cbar = plt.colorbar(im, cax=cbar_ax)
cbar.ax.tick_params(labelsize=10)

# Done, save figure
debug_print("plot done, save figure as " + filename, debug)
fig.savefig(filename)
plt.close("all")
if filename is not None:
debug_print("plot done, save figure as " + filename, debug)
fig.savefig(filename)

plt.close("all")